1 module lmpl4d.unpacker;
2
3 import lmpl4d.common;
4
/**
 * MessagePack deserializer reading from `Stream`, a buffer of bytes
 * (a `const(ubyte)[]` slice by default).
 *
 * `pos` is the read cursor into `buf`. On a type mismatch the throwing
 * scalar overloads rewind `pos` to the start of the offending value (via
 * `rollback`) before throwing; the `nothrow` overloads report failure
 * through their return value instead.
 */
struct Unpacker(Stream = const(ubyte)[]) if (isInputBuffer!(Stream, ubyte))
{
    Stream buf;  // source bytes
    size_t pos;  // index of the next byte to read

    this(Stream stream) { buf = stream; }

    version (betterC)
    {
        /*
         * betterC builds cannot throw: rollback only rewinds `pos` by one
         * header byte plus `size` payload bytes, and check does nothing.
         */
        void rollback(size_t size, string expected, Format actual = Format.NONE)
        {
            pos -= size + 1;
        }

        void check(size_t size = 1, size_t rewind = 0) {}
    }
    else
    {
        /**
         * Deserializes an enum, a null pointer, a character, an integer,
         * a floating-point number or a bool.
         *
         * Integers accept every MessagePack integer encoding (positive and
         * negative fixint, uint8..uint64, int8..int64) as long as the decoded
         * value is representable in T. Floating-point targets refuse encodings
         * that would lose precision (double/real into float, real into double).
         *
         * Throws:
         *   UnpackException if the buffer ends inside the value or a non-null
         *   pointer is found; MessagePackException (from rollback) on a
         *   mismatched or out-of-range value. In every case `pos` is left at
         *   the start of the value.
         */
        T unpack(T)()
        if (is(Unqual!T == enum) || isPointer!T || isSomeChar!T || isNumeric!T || is(Unqual!T == bool))
        {
            static if (is(Unqual!T == enum))
                return cast(T)unpack!(OriginalType!T);
            else static if (isPointer!T)
            {
                T val;
                if (unpackNil(val))
                    return val;
                throw new UnpackException("Can't deserialize a pointer that is not null");
            }
            else static if (is(Unqual!T == char))
                return cast(T)unpack!ubyte;
            else static if (is(Unqual!T == wchar))
                return cast(T)unpack!ushort;
            else static if (is(Unqual!T == dchar))
                return cast(T)unpack!uint;
            else static if (isNumeric!T || is(Unqual!T == bool))
            {
                check();
                const int header = read();
                static if (isIntegral!T)
                {
                    if (header <= 0x7f) // positive fixint
                        return cast(T)header;
                    if (header >= 0xe0) // negative fixint (-32 .. -1)
                    {
                        static if (T.min < 0)
                            return cast(T)cast(byte)header;
                        else
                            rollback(0, T.stringof, cast(Format)header);
                    }
                }
                switch (header)
                {
                static if (is(Unqual!T == bool))
                {
                case Format.TRUE:
                    return true;
                case Format.FALSE:
                    return false;
                }
                else static if (isIntegral!T)
                {
                case Format.UINT8:  return readIntegral!(T, ubyte)(Format.UINT8);
                case Format.UINT16: return readIntegral!(T, ushort)(Format.UINT16);
                case Format.UINT32: return readIntegral!(T, uint)(Format.UINT32);
                case Format.UINT64: return readIntegral!(T, ulong)(Format.UINT64);
                case Format.INT8:   return readIntegral!(T, byte)(Format.INT8);
                case Format.INT16:  return readIntegral!(T, short)(Format.INT16);
                case Format.INT32:  return readIntegral!(T, int)(Format.INT32);
                case Format.INT64:  return readIntegral!(T, long)(Format.INT64);
                }
                else static if (isFloatingPoint!T)
                {
                case Format.FLOAT:
                    check(uint.sizeof, 1);
                    _f fval;
                    fval.i = load!uint(read(uint.sizeof));
                    return fval.f;
                case Format.DOUBLE:
                    // check precision loss
                    static if (is(Unqual!T == float))
                        rollback(0, T.stringof, Format.DOUBLE);
                    check(ulong.sizeof, 1);
                    _d dval;
                    dval.i = load!ulong(read(ulong.sizeof));
                    return dval.f;
                case Format.REAL:
                    static if (!EnableReal)
                    {
                        rollback(0, "real is disabled", Format.REAL);
                        break;
                    }
                    else
                    {
                        // check precision loss
                        static if (is(Unqual!T == float) || is(Unqual!T == double))
                            rollback(0, T.stringof, Format.REAL);
                        version (NonX86)
                        {
                            // The payload is a 64-bit fraction plus a 16-bit
                            // sign/exponent word: exactly what is read below.
                            check(ulong.sizeof + ushort.sizeof, 1);
                            CustomFloat!80 tmp;

                            const frac = load!ulong (read(ulong.sizeof));
                            const exp = load!ushort(read(ushort.sizeof));

                            tmp.significand = frac;
                            tmp.exponent = exp & 0x7fff;
                            tmp.sign = (exp & 0x8000) != 0;

                            // NOTE: tmp.get!real is inf on non-x86 when deserialized value is larger than double.max.
                            return tmp.get!real;
                        }
                        else
                        {
                            _r tmp;
                            // Require only the bytes actually read, not real.sizeof
                            // (which includes padding, e.g. 16 on x86_64).
                            check(tmp.fraction.sizeof + tmp.exponent.sizeof, 1);

                            tmp.fraction = load!(typeof(tmp.fraction))(read(tmp.fraction.sizeof));
                            tmp.exponent = load!(typeof(tmp.exponent))(read(tmp.exponent.sizeof));

                            return tmp.f;
                        }
                    }
                }
                default: break;
                }
                rollback(0, T.stringof, cast(Format)header);
            }
            assert(0, "Unsupported type");
        }

        /*
         * Reads an integer payload stored as wire type V (its header byte has
         * already been consumed) and converts it to T.
         * Rolls back to the header and throws if the value does not fit in T.
         */
        private T readIntegral(T, V)(Format format)
        {
            check(V.sizeof, 1);
            static if (V.sizeof == 1)
                const V val = cast(V)read(); // reinterpret: INT8 0x80 is -128
            else
                const V val = load!V(read(V.sizeof));
            if (!fitsIn!T(val))
                rollback(V.sizeof, T.stringof, format);
            return cast(T)val;
        }

        /*
         * Rewinds `pos` by one header byte plus `size` payload bytes, then
         * throws a MessagePackException describing the mismatch.
         */
        void rollback(size_t size, string expected, Format actual = Format.NONE) {
            import std.conv : text;
            pos -= size + 1;
            throw new MessagePackException(text("Attempt to unpack with non-compatible type: ",
                actual ? text("expected = ", expected, ", got = ", actual) : expected));
        }

        /*
         * Throws UnpackException unless `size` more bytes can be read.
         * `rewind` already-consumed bytes (e.g. a header) are given back
         * first so the cursor is left at the start of the value.
         */
        void check(size_t size = 1, size_t rewind = 0) {
            if (!canRead(size))
            {
                pos -= rewind;
                throw new UnpackException("Insufficient buffer");
            }
        }
    }

    /**
     * Non-throwing scalar/Tuple deserialization.
     *
     * Tuples are read as a MessagePack array through unpackArray.
     *
     * Returns: the decoded value, or `defValue` when the buffer is too short,
     * the encoding does not match T, or the value does not fit in T. On
     * failure `pos` is restored to its value on entry.
     */
    T unpack(T)(T defValue) nothrow
    if (is(Unqual!T == enum) || isPointer!T || isTuple!T || isSomeChar!T || isNumeric!T || is(Unqual!T == bool))
    {
        static if (is(Unqual!T == enum))
            return cast(T)unpack(cast(OriginalType!T)defValue);
        else static if (isPointer!T)
        {
            T val;
            return unpackNil(val) ? val : defValue;
        }
        else static if (isTuple!T)
        {
            T val;
            // A partially decoded tuple is not a valid result: fall back to defValue.
            return unpackArray!(T.Types)(val.expand) ? val : defValue;
        }
        else static if (is(Unqual!T == char))
            return cast(T)unpack(cast(ubyte)defValue);
        else static if (is(Unqual!T == wchar))
            return cast(T)unpack(cast(ushort)defValue);
        else static if (is(Unqual!T == dchar))
            return cast(T)unpack(cast(uint)defValue);
        else static if (isNumeric!T || is(Unqual!T == bool))
        {
            const start = pos;
            Unqual!T value;
            if (tryUnpackScalar(value))
                return value;
            pos = start; // do not leave a half-consumed value behind
            return defValue;
        }
    }

    /*
     * Decodes one bool / integer / floating-point value into `result`.
     * Returns false on truncated input, a mismatched header or a value that
     * does not fit T; `pos` is then left wherever decoding stopped, so the
     * caller is responsible for restoring it.
     */
    private bool tryUnpackScalar(T)(ref T result) nothrow
    {
        if (!canRead)
            return false;
        const int header = read();
        static if (isIntegral!T)
        {
            if (header <= 0x7f) // positive fixint
            {
                result = cast(T)header;
                return true;
            }
            if (header >= 0xe0) // negative fixint (-32 .. -1)
            {
                static if (T.min < 0)
                {
                    result = cast(T)cast(byte)header;
                    return true;
                }
                else
                    return false;
            }
        }
        switch (header)
        {
        static if (is(T == bool))
        {
        case Format.TRUE:
            result = true;
            return true;
        case Format.FALSE:
            result = false;
            return true;
        }
        else static if (isIntegral!T)
        {
        case Format.UINT8:  return tryReadIntegral!(T, ubyte)(result);
        case Format.UINT16: return tryReadIntegral!(T, ushort)(result);
        case Format.UINT32: return tryReadIntegral!(T, uint)(result);
        case Format.UINT64: return tryReadIntegral!(T, ulong)(result);
        case Format.INT8:   return tryReadIntegral!(T, byte)(result);
        case Format.INT16:  return tryReadIntegral!(T, short)(result);
        case Format.INT32:  return tryReadIntegral!(T, int)(result);
        case Format.INT64:  return tryReadIntegral!(T, long)(result);
        }
        else static if (isFloatingPoint!T)
        {
        case Format.FLOAT:
            if (!canRead(uint.sizeof))
                return false;
            _f fval;
            fval.i = load!uint(read(uint.sizeof));
            result = fval.f;
            return true;
        case Format.DOUBLE:
            static if (is(T == float))
                return false; // would lose precision
            else
            {
                if (!canRead(ulong.sizeof))
                    return false;
                _d dval = void;
                dval.i = load!ulong(read(ulong.sizeof));
                result = dval.f;
                return true;
            }
        case Format.REAL:
            static if (!EnableReal || is(T == float) || is(T == double))
                return false; // real disabled, or would lose precision
            else
            {
                version (NonX86)
                {
                    if (!canRead(ulong.sizeof + ushort.sizeof))
                        return false;
                    CustomFloat!80 tmp;

                    const frac = load!ulong (read(ulong.sizeof));
                    const exp = load!ushort(read(ushort.sizeof));

                    tmp.significand = frac;
                    tmp.exponent = exp & 0x7fff;
                    tmp.sign = (exp & 0x8000) != 0;

                    // NOTE: tmp.get!real is inf on non-x86 when deserialized value is larger than double.max.
                    result = tmp.get!real;
                    return true;
                }
                else
                {
                    _r tmp = void;
                    if (!canRead(tmp.fraction.sizeof + tmp.exponent.sizeof))
                        return false;

                    tmp.fraction = load!(typeof(tmp.fraction))(read(tmp.fraction.sizeof));
                    tmp.exponent = load!(typeof(tmp.exponent))(read(tmp.exponent.sizeof));

                    result = tmp.f;
                    return true;
                }
            }
        }
        default:
            return false;
        }
    }

    /*
     * Reads an integer payload stored as wire type V (header already
     * consumed) into `result` when it is representable in T.
     * Returns: false if the buffer is too short or the value is out of range.
     */
    private bool tryReadIntegral(T, V)(ref T result) nothrow
    {
        if (!canRead(V.sizeof))
            return false;
        static if (V.sizeof == 1)
            const V val = cast(V)read(); // reinterpret: INT8 0x80 is -128
        else
            const V val = load!V(read(V.sizeof));
        if (!fitsIn!T(val))
            return false;
        result = cast(T)val;
        return true;
    }

    /*
     * True when the integer `v` is representable in the integral type T.
     * Negative values are handled explicitly so that signed/unsigned
     * comparisons never wrap around (e.g. -1 is not accepted as ulong.max).
     */
    private static bool fitsIn(T, V)(const V v) pure nothrow @nogc @safe
    {
        static if (V.min < 0)
        {
            if (v < 0)
            {
                static if (T.min < 0)
                    return v >= T.min;
                else
                    return false;
            }
        }
        return cast(ulong)v <= cast(ulong)T.max;
    }

    /// Unpacks each of `objects` in turn; returns `this` for chaining.
    ref typeof(this) unpack(Types...)(ref Types objects) if (Types.length > 1)
    {
        foreach (i, T; Types)
            objects[i] = unpack!T;
        return this;
    }

    /**
     * Deserializes a dynamic or static array. Raw byte-like element types
     * are sliced straight out of the buffer; other elements are unpacked one
     * by one.
     *
     * Throws: MessagePackException on nil into a static array, on a length
     * larger than the buffer, or on a length different from a static
     * array's T.length; UnpackException if the buffer is too short.
     */
    T unpack(T)() if (isSomeArray!T)
    {
        if (checkNil()) {
            static if (isStaticArray!T) {
                pos++;
                rollback(0, "static array", Format.NIL);
            }
            else {
                T array = void;
                unpackNil(array);
                return array;
            }
        }
        alias U = typeof(T.init[0]);
        enum RawBytes = isRawByte!U;
        const start = pos; // position of the header, used to rewind on errors
        static if (RawBytes)
            auto length = beginRaw();
        else
            auto length = beginArray();
        version (betterC) {} else {
            static if (__traits(compiles, buf.length))
                if (pos + length > buf.length) {
                    import std.conv : text;
                    pos = start;
                    throw new MessagePackException(text("Invalid array size in byte stream: Length (", length,
                        ") is larger than internal buffer size (", buf.length, ")"));
                }
        }
        static if (isStaticArray!T)
        {
            if (length != T.length)
            {
                // rollback rewinds exactly one byte from here, i.e. back to `start`.
                pos = start + 1;
                rollback(0, "static array was given but the length is mismatched");
                return T.init; // only reached in betterC, where rollback cannot throw
            }
            T array = void;
        }
        else
        {
            import std.array : uninitializedArray;
            T array = uninitializedArray!T(length);
        }
        if (length == 0)
            return array;
        static if (RawBytes) {
            // Only the payload is still needed: beginRaw already consumed the header.
            check(length, pos - start);
            static if (isStaticArray!T)
                array = (cast(U[])read(length))[0 .. T.length];
            else
                array = cast(T)read(length);
        } else
            foreach (ref a; array)
                a = unpack!U;
        return array;
    }

    /**
     * Non-throwing array deserialization into `array`.
     *
     * A decoded nil sets a dynamic array to null; a decoded empty array
     * leaves a dynamic array empty.
     *
     * Returns: true on success. false on nil into a static array, on an
     * exhausted buffer, on a length that cannot be satisfied or that differs
     * from a static array's T.length, or when an element fails to unpack;
     * in the latter cases `pos` is restored to its value on entry.
     */
    bool unpack(T)(ref T array) nothrow if (isSomeArray!T)
    {
        import std.array : uninitializedArray;

        alias U = typeof(T.init[0]);
        const spos = pos;
        if (checkNil()) {
            static if (isStaticArray!T)
                return false;
            else
                return unpackNil(array);
        }
        if (!canRead)
            return false;

        enum RawBytes = isRawByte!U;
        static if (RawBytes)
            auto length = beginRaw();
        else
            auto length = beginArray();
        version (betterC) {} else {
            // Same sanity check as the throwing overload.
            static if (__traits(compiles, buf.length))
                if (pos + length > buf.length) {
                    pos = spos;
                    return false;
                }
        }
        static if (isStaticArray!T)
            if (length != T.length) {
                pos = spos;
                return false;
            }
        if (length == 0) {
            static if (!isStaticArray!T)
                array = array[0 .. 0]; // decoded empty: drop stale contents
            return true;
        }
        // Raw payloads are sliced from the buffer below, so only element-wise
        // arrays need storage of their own.
        static if (!isStaticArray!T && !RawBytes)
            if (array.length != length)
                array = uninitializedArray!T(length);
        static if (RawBytes) {
            if (!canRead(length)) {
                pos = spos;
                return false;
            }
            static if (isStaticArray!T)
                array = (cast(U[])read(length))[0 .. T.length];
            else
                array = cast(T)read(length);
        } else {
            foreach (ref a; array) {
                // unpack!U may throw; this overload must not.
                try {
                    a = unpack!U;
                } catch (Exception) {
                    pos = spos;
                    return false;
                }
            }
        }
        return true;
    }

    /// ditto
    T unpack(T)() if (isAssociativeArray!T)
    {
        alias K = typeof(T.init.keys[0]),
              V = typeof(T.init.values[0]);
        T array;

        if (unpackNil(array))
            return array;

        auto length = beginMap();
        if (length == 0)
            return array;

        foreach (i; 0 .. length) {
            K k = unpack!K;
            array[k] = unpack!V;
        }

        return array;
    }

    /**
     * Deserializes the container object and assigns to each argument.
     *
     * These methods check the length: if the number of arguments differs
     * from the length of the deserialized container, or an element fails to
     * unpack, `pos` is rolled back to the container header and false is
     * returned.
     *
     * In unpackMap, the number of arguments must be even.
     *
     * Params:
     *  objects = the references of objects to assign.
     *
     * Returns: true on success.
     */
    bool unpackArray(Types...)(ref Types objects) nothrow if (Types.length > 0)
    {
        const spos = pos; // before the header, so a rollback undoes it too
        if (beginArray() != Types.length) {
            // the number of deserialized objects is mismatched
            pos = spos;
            return false;
        }

        foreach (i, T; Types) {
            try {
                objects[i] = unpack!T;
            } catch (Exception) {
                pos = spos;
                return false;
            }
        }

        return true;
    }

    /// ditto
    bool unpackMap(Types...)(ref Types objects) nothrow if (Types.length > 1)
    {
        static assert(Types.length % 2 == 0, "The number of arguments must be even");

        const spos = pos; // before the header, so a rollback undoes it too
        if (beginMap() != Types.length / 2) {
            // the number of deserialized objects is mismatched
            pos = spos;
            return false;
        }

        foreach (i, T; Types) {
            try {
                objects[i] = unpack!T;
            } catch (Exception) {
                pos = spos;
                return false;
            }
        }

        return true;
    }

    /**
     * Deserializes a map into the associative array `array` (key type K,
     * value type V). Lvalues are updated in place, so a nil or initially
     * empty AA is visible to the caller.
     *
     * Returns: true on success; false (with `pos` rolled back to the map
     * header) if a key or value fails to unpack.
     */
    bool unpackAA(K, V)(auto ref V[K] array) nothrow
    {
        if (unpackNil(array))
            return true;

        const spos = pos;
        const length = beginMap();
        foreach (_; 0 .. length) {
            try {
                K k = unpack!K;
                array[k] = unpack!V;
            } catch (Exception) {
                pos = spos;
                return false;
            }
        }

        return true;
    }

    /*
     * Reads a container header. Fix-formats in s1..s2 carry the length in
     * their low bits; otherwise the 16/32-bit forms are f and f + 1 (for
     * raw: BIN8/STR8, BIN16/RAW16, BIN32/RAW32).
     *
     * Returns: the container length; 0 for nil, for an exhausted buffer or
     * for a truncated length field. Any other header is a fatal assert(0).
     */
    size_t begin(int s1 = 0xa0, int s2 = 0xbf, Format f = Format.ARRAY16)() nothrow
    {
        enum Raw = s1 == 0xa0 && s2 == 0xbf;
        if (!canRead) return 0;
        int header = read();

        if (s1 <= header && header <= s2)
            return header & (s2 - s1);
        switch (header) {
        static if (Raw) {
        case Format.BIN8, Format.STR8:
            if (!canRead(ubyte.sizeof)) return 0;
            return read();
        case Format.BIN16, Format.RAW16:
        } else {
        case f:
        }
            if (!canRead(ushort.sizeof)) return 0;
            return load!ushort(read(ushort.sizeof));
        static if (Raw) {
        case Format.BIN32, Format.RAW32:
        } else {
        case cast(Format)(f + 1):
        }
            if (!canRead(uint.sizeof)) return 0;
            return load!uint(read(uint.sizeof));
        case Format.NIL:
            return 0;
        default:
            pos--;
            import std.conv : text;
            assert(0, text("Attempt to unpack with non-compatible type: expected = ",
                f.stringof, ", got = ", header));
        }
    }

    /*
     * Deserializes type-information of raw type.
     */
    alias beginRaw = begin!();
    /**
     * Deserializes the type-information of container.
     *
     * These methods don't deserialize contents.
     * You need to call unpack method to deserialize contents at your own risk.
     *
     * Returns:
     *  the container size.
     */
    alias beginArray = begin!(0x90, 0x9f);

    /// ditto
    alias beginMap = begin!(0x80, 0x8f, Format.MAP16);

    version (NoPackingStruct) {}
    else {
        /**
         * Deserializes a struct packed as an array of its packed fields.
         * An empty array or nil yields T.init.
         */
        T unpack(T)() if (is(Unqual!T == struct)) {
            T val;
            if (auto len = beginArray()) {
                if (len != NumOfSerializingMembers!T) {
                    rollback(calculateSize(len), "the number of struct fields is mismatched");
                    return val; // only reached in betterC, where rollback cannot throw
                }

                foreach (i, ref member; val.tupleof)
                    static if (isPackedField!(T.tupleof[i]))
                        member = unpack!(typeof(member));
            }
            return val;
        }

        /**
         * Deserializes a struct into `obj`.
         * Returns: false (with the header rolled back) when the field count
         * does not match; true otherwise, including for an empty array/nil.
         */
        bool unpackObj(T)(ref T obj) if (is(Unqual!T == struct)) {
            auto len = beginArray();
            if (len == 0)
                return true;
            if (len != NumOfSerializingMembers!T) {
                pos -= calculateSize(len) + 1;
                return false; // the number of struct fields is mismatched
            }

            foreach (i, ref member; obj.tupleof)
                static if (isPackedField!(T.tupleof[i]))
                    member = unpack!(typeof(member));
            return true;
        }
    }

nothrow:

    /**
     * Unpacks an EXT value into $(D type) and $(D data).
     * Returns: true on success; on failure `pos` is restored and false
     * is returned.
     */
    bool unpackExt(T)(ref byte type, ref T data) if (isOutputBuffer!(T, ubyte))
    {
        if (!canRead) return false;
        int header = read();

        uint len = void;
        uint rollbackLen = 1; // bytes consumed so far: header (+ length field)
        if (header >= Format.EXT && header <= Format.EXT + 4) {
            // fixext 1/2/4/8/16: the payload size is encoded in the header
            len = 1 << (header - Format.EXT);
        } else
            // Dynamic length
            switch (header)
            {
            case Format.EXT8:
                if (!canRead(1 + 1)) {
                    pos--;
                    return false;
                }
                len = read();
                rollbackLen++;
                break;
            case Format.EXT16:
                if (!canRead(2 + 1)) {
                    pos--;
                    return false;
                }
                len = load!ushort(read(2));
                rollbackLen += 2;
                break;
            case Format.EXT32:
                if (!canRead(4 + 1)) {
                    pos--;
                    return false;
                }
                len = load!uint(read(4));
                rollbackLen += 4;
                break;
            default:
                pos--;
                return false;
            }

        // Type byte + payload, computed in size_t so that a 0xFFFFFFFF length
        // cannot wrap around to a tiny size.
        if (len >= size_t.max || !canRead(cast(size_t)len + 1)) {
            pos -= rollbackLen;
            return false;
        }

        // Read type
        type = read();
        // Read and check data
        AOutputBuf!T(data) ~= read(len);
        return true;
    }

    /*
     * Reads value from buffer and advances offset.
     */
    ubyte read()
    {
        return buf[pos++];
    }

    /*
     * Returns the next `size` bytes as a slice of the buffer and advances
     * the offset past them.
     */
    auto read(size_t size)
    {
        auto result = buf[pos .. pos + size];
        pos += size;
        return result;
    }

    /*
     * Reading test: true if `size` more bytes are available (always true
     * for streams without a length). Written so that `pos + size` cannot
     * overflow.
     *
     * Params:
     *  size = the size to read.
     */
    bool canRead(size_t size = 1) const
    {
        static if (__traits(compiles, buf.length))
            return pos <= buf.length && size <= buf.length - pos;
        else
            return true;
    }

    /*
     * Next object is nil?
     *
     * Returns:
     *  true if next object is nil.
     */
    bool checkNil()
    {
        return canRead && buf[pos] == Format.NIL;
    }

    /*
     * Deserializes nil object and assigns to $(D_PARAM value).
     *
     * Params:
     *  value = the reference of value to assign.
     *
     * Returns:
     *  true if next object is nil (it is then consumed).
     */
    bool unpackNil(T)(ref T value)
    {
        if (!canRead)
            return false;

        if (buf[pos] == Format.NIL) {
            value = null;
            pos++;
            return true;
        }
        return false;
    }

    /*
     * Loads $(D_PARAM T) type value from $(D_PARAM buffer).
     *
     * Params:
     *  buffer = the serialized contents (at least T.sizeof bytes).
     *
     * Returns:
     *  the Endian-converted value.
     */
    package T load(T)(in ubyte[] buf)
    {
        static if (isIntegral!T && T.sizeof == 2)
            enum bit = 16;
        else static if (isIntegral!T && T.sizeof == 4)
            enum bit = 32;
        else static if (isIntegral!T && T.sizeof == 8)
            enum bit = 64;
        else static assert(0, "Unsupported type");
        // Copy byte-wise into an aligned local: `buf` is an arbitrary slice of
        // the stream, and dereferencing it as a T* may be misaligned, which
        // is undefined on strict-alignment targets. The slice also bounds-checks.
        T raw = void;
        (cast(ubyte*)&raw)[0 .. T.sizeof] = buf[0 .. T.sizeof];
        return convertEndianTo!bit(raw);
    }
}
724
725 version(unittest)
726 import
727 lmpl4d.packer,
728 std.exception;
729
unittest
{
    // Booleans: both values inside a single tuple.
    {
        mixin DefinePacker;
        Tuple!(bool, bool) test = tuple(true, false);
        packer.pack(test);
        mixin TestUnpacker;
    }
    // Unsigned integers: the largest value of every width.
    {
        mixin DefinePacker;
        Tuple!(ubyte, ushort, uint, ulong) test = tuple(ubyte.max, ushort.max, uint.max, ulong.max);
        packer.pack(test);
        mixin TestUnpacker;
    }
    // Signed integers: the smallest value of every width.
    {
        mixin DefinePacker;
        Tuple!(byte, short, int, long) test = tuple(byte.min, short.min, int.min, long.min);
        packer.pack(test);
        mixin TestUnpacker;
    }
}
unittest
{
    // Floating point round trip; `real` is only exercised when it is both
    // enabled and wider than double, otherwise double stands in for it.
    {
        mixin DefinePacker;

        static if (real.sizeof != double.sizeof && EnableReal)
            alias R = real;
        else
            alias R = double;

        Tuple!(float, double, R) test = tuple(float.min_normal, double.max, cast(real)R.min_normal);
        packer.pack(test);

        mixin TestUnpacker;
    }
    // Enums: an anonymous float-based one and a named ulong-based one.
    {
        enum : float { D = 0.5 }
        enum E : ulong { U = 100 }

        mixin DefinePacker;

        E e = E.U, resultE;
        float f = D, resultF;

        packer.pack(D, e);

        mixin TestUnpacker;

        unpacker.unpack(resultF, resultE);
        assert(resultF == f);
        assert(resultE == e);
    }
}
796
version (NoPackingStruct) {}
else unittest
{
    // A struct whose second field is excluded from serialization.
    struct Test
    {
        string f1;
        @nonPacked int f2;
    }

    mixin DefinePacker;

    Test sample = Test("foo", 10);
    auto bytes = packer.pack(sample).buf[];
    auto reader = Unpacker!()(bytes);
    const Test decoded = reader.unpack!Test;

    // The packed field survives the round trip...
    assert(decoded.f1 == sample.f1);
    // ...while the @nonPacked one comes back default-initialised.
    assert(decoded.f2 != sample.f2);
    assert(decoded.f2 == int.init);

    // Packing Test.init with a fresh packer produces fewer bytes than `sample` did.
    auto freshStore = Array!ubyte();
    auto freshPacker = Packer!(Array!ubyte)(freshStore);
    assert(freshPacker.pack(Test.init).buf.length < bytes.length);
}
821
unittest
{
    import std.conv : text;

    // Containers: dynamic array, associative array, string and two static arrays.
    {
        mixin DefinePacker;

        Tuple!(ulong[], int[uint], string, bool[2], char[2]) test =
            tuple([1UL, 2], [3U:4, 5:6, 7:8], "MessagePack", [true, false], "D!");
        packer.pack(test);

        mixin TestUnpacker;
    }

    // EXT values: the lengths hit every fixext size as well as the
    // EXT8 / EXT16 / EXT32 encodings.
    foreach (extLen; AliasSeq!(1, 2, 3, 4, 5, 8, 9, 16, 32, 512, 2^^16))
    {
        mixin DefinePacker;

        auto payload = new ubyte[extLen];
        payload.fillData;
        packer.packExt(7, payload);

        byte kind;
        ubyte[] roundTrip;
        auto reader = Unpacker!()(packer.buf[]);

        assert(reader.unpackExt(kind, roundTrip));
        assert(kind == 7, text("type: ", kind));
        assert(payload == roundTrip, text(payload, "\nExpected: ", roundTrip));
    }
}