123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- package msgpack
- import (
- "errors"
- "fmt"
- "reflect"
- "github.com/vmihailenco/msgpack/codes"
- )
// mapElemsAllocLimit caps the capacity hint passed to make when a map is
// pre-allocated from a length read off the wire, so a large declared length
// does not by itself trigger a large up-front allocation.
const mapElemsAllocLimit = 1e4

// Cached reflect types for the map shapes that have dedicated fast-path
// decoders, together with their pointer types (used for Addr().Convert).
var (
	mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
	mapStringStringType    = mapStringStringPtrType.Elem()

	mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
)

// errInvalidCode is the sentinel returned by _mapLen when a code is not a
// map header; callers turn it into a descriptive error that names the code.
var errInvalidCode = errors.New("invalid code")
- func decodeMapValue(d *Decoder, v reflect.Value) error {
- size, err := d.DecodeMapLen()
- if err != nil {
- return err
- }
- typ := v.Type()
- if size == -1 {
- v.Set(reflect.Zero(typ))
- return nil
- }
- if v.IsNil() {
- v.Set(reflect.MakeMap(typ))
- }
- if size == 0 {
- return nil
- }
- return decodeMapValueSize(d, v, size)
- }
- func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error {
- typ := v.Type()
- keyType := typ.Key()
- valueType := typ.Elem()
- for i := 0; i < size; i++ {
- mk := reflect.New(keyType).Elem()
- if err := d.DecodeValue(mk); err != nil {
- return err
- }
- mv := reflect.New(valueType).Elem()
- if err := d.DecodeValue(mv); err != nil {
- return err
- }
- v.SetMapIndex(mk, mv)
- }
- return nil
- }
- // DecodeMapLen decodes map length. Length is -1 when map is nil.
- func (d *Decoder) DecodeMapLen() (int, error) {
- c, err := d.readCode()
- if err != nil {
- return 0, err
- }
- if codes.IsExt(c) {
- if err = d.skipExtHeader(c); err != nil {
- return 0, err
- }
- c, err = d.readCode()
- if err != nil {
- return 0, err
- }
- }
- return d.mapLen(c)
- }
- func (d *Decoder) mapLen(c codes.Code) (int, error) {
- size, err := d._mapLen(c)
- err = expandInvalidCodeMapLenError(c, err)
- return size, err
- }
- func (d *Decoder) _mapLen(c codes.Code) (int, error) {
- if c == codes.Nil {
- return -1, nil
- }
- if c >= codes.FixedMapLow && c <= codes.FixedMapHigh {
- return int(c & codes.FixedMapMask), nil
- }
- if c == codes.Map16 {
- size, err := d.uint16()
- return int(size), err
- }
- if c == codes.Map32 {
- size, err := d.uint32()
- return int(size), err
- }
- return 0, errInvalidCode
- }
- func expandInvalidCodeMapLenError(c codes.Code, err error) error {
- if err == errInvalidCode {
- return fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
- }
- return err
- }
- func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
- mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
- return d.decodeMapStringStringPtr(mptr)
- }
- func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
- size, err := d.DecodeMapLen()
- if err != nil {
- return err
- }
- if size == -1 {
- *ptr = nil
- return nil
- }
- m := *ptr
- if m == nil {
- *ptr = make(map[string]string, min(size, mapElemsAllocLimit))
- m = *ptr
- }
- for i := 0; i < size; i++ {
- mk, err := d.DecodeString()
- if err != nil {
- return err
- }
- mv, err := d.DecodeString()
- if err != nil {
- return err
- }
- m[mk] = mv
- }
- return nil
- }
- func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
- ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
- return d.decodeMapStringInterfacePtr(ptr)
- }
- func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
- n, err := d.DecodeMapLen()
- if err != nil {
- return err
- }
- if n == -1 {
- *ptr = nil
- return nil
- }
- m := *ptr
- if m == nil {
- *ptr = make(map[string]interface{}, min(n, mapElemsAllocLimit))
- m = *ptr
- }
- for i := 0; i < n; i++ {
- mk, err := d.DecodeString()
- if err != nil {
- return err
- }
- mv, err := d.decodeInterfaceCond()
- if err != nil {
- return err
- }
- m[mk] = mv
- }
- return nil
- }
- func (d *Decoder) DecodeMap() (interface{}, error) {
- if d.decodeMapFunc != nil {
- return d.decodeMapFunc(d)
- }
- size, err := d.DecodeMapLen()
- if err != nil {
- return nil, err
- }
- if size == -1 {
- return nil, nil
- }
- if size == 0 {
- return make(map[string]interface{}), nil
- }
- code, err := d.PeekCode()
- if err != nil {
- return nil, err
- }
- if codes.IsString(code) || codes.IsBin(code) {
- return d.decodeMapStringInterfaceSize(size)
- }
- key, err := d.decodeInterfaceCond()
- if err != nil {
- return nil, err
- }
- value, err := d.decodeInterfaceCond()
- if err != nil {
- return nil, err
- }
- keyType := reflect.TypeOf(key)
- valueType := reflect.TypeOf(value)
- mapType := reflect.MapOf(keyType, valueType)
- mapValue := reflect.MakeMap(mapType)
- mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
- size--
- err = decodeMapValueSize(d, mapValue, size)
- if err != nil {
- return nil, err
- }
- return mapValue.Interface(), nil
- }
- func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) {
- m := make(map[string]interface{}, min(size, mapElemsAllocLimit))
- for i := 0; i < size; i++ {
- mk, err := d.DecodeString()
- if err != nil {
- return nil, err
- }
- mv, err := d.decodeInterfaceCond()
- if err != nil {
- return nil, err
- }
- m[mk] = mv
- }
- return m, nil
- }
- func (d *Decoder) skipMap(c codes.Code) error {
- n, err := d.mapLen(c)
- if err != nil {
- return err
- }
- for i := 0; i < n; i++ {
- if err := d.Skip(); err != nil {
- return err
- }
- if err := d.Skip(); err != nil {
- return err
- }
- }
- return nil
- }
- func decodeStructValue(d *Decoder, v reflect.Value) error {
- c, err := d.readCode()
- if err != nil {
- return err
- }
- var isArray bool
- n, err := d._mapLen(c)
- if err != nil {
- var err2 error
- n, err2 = d.arrayLen(c)
- if err2 != nil {
- return expandInvalidCodeMapLenError(c, err)
- }
- isArray = true
- }
- if n == -1 {
- if err = mustSet(v); err != nil {
- return err
- }
- v.Set(reflect.Zero(v.Type()))
- return nil
- }
- var fields *fields
- if d.useJSONTag {
- fields = jsonStructs.Fields(v.Type())
- } else {
- fields = structs.Fields(v.Type())
- }
- if isArray {
- for i, f := range fields.List {
- if i >= n {
- break
- }
- if err := f.DecodeValue(d, v); err != nil {
- return err
- }
- }
- // Skip extra values.
- for i := len(fields.List); i < n; i++ {
- if err := d.Skip(); err != nil {
- return err
- }
- }
- return nil
- }
- for i := 0; i < n; i++ {
- name, err := d.DecodeString()
- if err != nil {
- return err
- }
- if f := fields.Table[name]; f != nil {
- if err := f.DecodeValue(d, v); err != nil {
- return err
- }
- } else {
- if err := d.Skip(); err != nil {
- return err
- }
- }
- }
- return nil
- }
|