decode_map.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package msgpack
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "github.com/vmihailenco/msgpack/codes"
  7. )
  8. const mapElemsAllocLimit = 1e4
  9. var mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
  10. var mapStringStringType = mapStringStringPtrType.Elem()
  11. var mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
  12. var mapStringInterfaceType = mapStringInterfacePtrType.Elem()
  13. var errInvalidCode = errors.New("invalid code")
  14. func decodeMapValue(d *Decoder, v reflect.Value) error {
  15. size, err := d.DecodeMapLen()
  16. if err != nil {
  17. return err
  18. }
  19. typ := v.Type()
  20. if size == -1 {
  21. v.Set(reflect.Zero(typ))
  22. return nil
  23. }
  24. if v.IsNil() {
  25. v.Set(reflect.MakeMap(typ))
  26. }
  27. if size == 0 {
  28. return nil
  29. }
  30. return decodeMapValueSize(d, v, size)
  31. }
  32. func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error {
  33. typ := v.Type()
  34. keyType := typ.Key()
  35. valueType := typ.Elem()
  36. for i := 0; i < size; i++ {
  37. mk := reflect.New(keyType).Elem()
  38. if err := d.DecodeValue(mk); err != nil {
  39. return err
  40. }
  41. mv := reflect.New(valueType).Elem()
  42. if err := d.DecodeValue(mv); err != nil {
  43. return err
  44. }
  45. v.SetMapIndex(mk, mv)
  46. }
  47. return nil
  48. }
  49. // DecodeMapLen decodes map length. Length is -1 when map is nil.
  50. func (d *Decoder) DecodeMapLen() (int, error) {
  51. c, err := d.readCode()
  52. if err != nil {
  53. return 0, err
  54. }
  55. if codes.IsExt(c) {
  56. if err = d.skipExtHeader(c); err != nil {
  57. return 0, err
  58. }
  59. c, err = d.readCode()
  60. if err != nil {
  61. return 0, err
  62. }
  63. }
  64. return d.mapLen(c)
  65. }
  66. func (d *Decoder) mapLen(c codes.Code) (int, error) {
  67. size, err := d._mapLen(c)
  68. err = expandInvalidCodeMapLenError(c, err)
  69. return size, err
  70. }
  71. func (d *Decoder) _mapLen(c codes.Code) (int, error) {
  72. if c == codes.Nil {
  73. return -1, nil
  74. }
  75. if c >= codes.FixedMapLow && c <= codes.FixedMapHigh {
  76. return int(c & codes.FixedMapMask), nil
  77. }
  78. if c == codes.Map16 {
  79. size, err := d.uint16()
  80. return int(size), err
  81. }
  82. if c == codes.Map32 {
  83. size, err := d.uint32()
  84. return int(size), err
  85. }
  86. return 0, errInvalidCode
  87. }
  88. func expandInvalidCodeMapLenError(c codes.Code, err error) error {
  89. if err == errInvalidCode {
  90. return fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
  91. }
  92. return err
  93. }
  94. func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
  95. mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
  96. return d.decodeMapStringStringPtr(mptr)
  97. }
  98. func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
  99. size, err := d.DecodeMapLen()
  100. if err != nil {
  101. return err
  102. }
  103. if size == -1 {
  104. *ptr = nil
  105. return nil
  106. }
  107. m := *ptr
  108. if m == nil {
  109. *ptr = make(map[string]string, min(size, mapElemsAllocLimit))
  110. m = *ptr
  111. }
  112. for i := 0; i < size; i++ {
  113. mk, err := d.DecodeString()
  114. if err != nil {
  115. return err
  116. }
  117. mv, err := d.DecodeString()
  118. if err != nil {
  119. return err
  120. }
  121. m[mk] = mv
  122. }
  123. return nil
  124. }
  125. func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
  126. ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
  127. return d.decodeMapStringInterfacePtr(ptr)
  128. }
  129. func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
  130. n, err := d.DecodeMapLen()
  131. if err != nil {
  132. return err
  133. }
  134. if n == -1 {
  135. *ptr = nil
  136. return nil
  137. }
  138. m := *ptr
  139. if m == nil {
  140. *ptr = make(map[string]interface{}, min(n, mapElemsAllocLimit))
  141. m = *ptr
  142. }
  143. for i := 0; i < n; i++ {
  144. mk, err := d.DecodeString()
  145. if err != nil {
  146. return err
  147. }
  148. mv, err := d.decodeInterfaceCond()
  149. if err != nil {
  150. return err
  151. }
  152. m[mk] = mv
  153. }
  154. return nil
  155. }
  156. func (d *Decoder) DecodeMap() (interface{}, error) {
  157. if d.decodeMapFunc != nil {
  158. return d.decodeMapFunc(d)
  159. }
  160. size, err := d.DecodeMapLen()
  161. if err != nil {
  162. return nil, err
  163. }
  164. if size == -1 {
  165. return nil, nil
  166. }
  167. if size == 0 {
  168. return make(map[string]interface{}), nil
  169. }
  170. code, err := d.PeekCode()
  171. if err != nil {
  172. return nil, err
  173. }
  174. if codes.IsString(code) || codes.IsBin(code) {
  175. return d.decodeMapStringInterfaceSize(size)
  176. }
  177. key, err := d.decodeInterfaceCond()
  178. if err != nil {
  179. return nil, err
  180. }
  181. value, err := d.decodeInterfaceCond()
  182. if err != nil {
  183. return nil, err
  184. }
  185. keyType := reflect.TypeOf(key)
  186. valueType := reflect.TypeOf(value)
  187. mapType := reflect.MapOf(keyType, valueType)
  188. mapValue := reflect.MakeMap(mapType)
  189. mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
  190. size--
  191. err = decodeMapValueSize(d, mapValue, size)
  192. if err != nil {
  193. return nil, err
  194. }
  195. return mapValue.Interface(), nil
  196. }
  197. func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) {
  198. m := make(map[string]interface{}, min(size, mapElemsAllocLimit))
  199. for i := 0; i < size; i++ {
  200. mk, err := d.DecodeString()
  201. if err != nil {
  202. return nil, err
  203. }
  204. mv, err := d.decodeInterfaceCond()
  205. if err != nil {
  206. return nil, err
  207. }
  208. m[mk] = mv
  209. }
  210. return m, nil
  211. }
  212. func (d *Decoder) skipMap(c codes.Code) error {
  213. n, err := d.mapLen(c)
  214. if err != nil {
  215. return err
  216. }
  217. for i := 0; i < n; i++ {
  218. if err := d.Skip(); err != nil {
  219. return err
  220. }
  221. if err := d.Skip(); err != nil {
  222. return err
  223. }
  224. }
  225. return nil
  226. }
  227. func decodeStructValue(d *Decoder, v reflect.Value) error {
  228. c, err := d.readCode()
  229. if err != nil {
  230. return err
  231. }
  232. var isArray bool
  233. n, err := d._mapLen(c)
  234. if err != nil {
  235. var err2 error
  236. n, err2 = d.arrayLen(c)
  237. if err2 != nil {
  238. return expandInvalidCodeMapLenError(c, err)
  239. }
  240. isArray = true
  241. }
  242. if n == -1 {
  243. if err = mustSet(v); err != nil {
  244. return err
  245. }
  246. v.Set(reflect.Zero(v.Type()))
  247. return nil
  248. }
  249. var fields *fields
  250. if d.useJSONTag {
  251. fields = jsonStructs.Fields(v.Type())
  252. } else {
  253. fields = structs.Fields(v.Type())
  254. }
  255. if isArray {
  256. for i, f := range fields.List {
  257. if i >= n {
  258. break
  259. }
  260. if err := f.DecodeValue(d, v); err != nil {
  261. return err
  262. }
  263. }
  264. // Skip extra values.
  265. for i := len(fields.List); i < n; i++ {
  266. if err := d.Skip(); err != nil {
  267. return err
  268. }
  269. }
  270. return nil
  271. }
  272. for i := 0; i < n; i++ {
  273. name, err := d.DecodeString()
  274. if err != nil {
  275. return err
  276. }
  277. if f := fields.Table[name]; f != nil {
  278. if err := f.DecodeValue(d, v); err != nil {
  279. return err
  280. }
  281. } else {
  282. if err := d.Skip(); err != nil {
  283. return err
  284. }
  285. }
  286. }
  287. return nil
  288. }