123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- package util
- import (
- "errors"
- "reflect"
- "strconv"
- "unicode"
- )
// AMF3 type markers: the single byte that precedes, and identifies the type
// of, every encoded value. Values are listed explicitly in marker order.
const (
	AMF3_UNDEFINED     = 0x00
	AMF3_NULL          = 0x01
	AMF3_FALSE         = 0x02
	AMF3_TRUE          = 0x03
	AMF3_INTEGER       = 0x04
	AMF3_DOUBLE        = 0x05
	AMF3_STRING        = 0x06
	AMF3_XML_DOC       = 0x07
	AMF3_DATE          = 0x08
	AMF3_ARRAY         = 0x09
	AMF3_OBJECT        = 0x0a
	AMF3_XML           = 0x0b
	AMF3_BYTE_ARRAY    = 0x0c
	AMF3_VECTOR_INT    = 0x0d
	AMF3_VECTOR_UINT   = 0x0e
	AMF3_VECTOR_DOUBLE = 0x0f
	AMF3_VECTOR_OBJECT = 0x10
	AMF3_DICTIONARY    = 0x11
)
// AMF3 encodes and decodes values in the AMF3 format. The byte buffer and the
// primitive read/write methods used by its methods (ReadByte, ReadN,
// WriteByte, Write, Buffer, ...) are not declared in this type, so they
// presumably come from the embedded AMF value — confirm against its
// declaration. MarshalAMF3s shows how to set up a fresh encoder.
type AMF3 struct {
	AMF
	scEnc        map[string]int  // encoder string table: string -> reference index
	scDec        []string        // decoder string table, indexed by reference
	ocEnc        map[uintptr]int // encoder object table: object address -> reference index
	ocDec        []any           // decoder object table, indexed by reference
	reservStruct bool            // keep struct field names as declared instead of lower-casing their first rune
}
- func (amf *AMF3) readString() (string, error) {
- index, err := amf.readU29()
- if err != nil {
- return "", err
- }
- ret := ""
- if (index & 0x01) == 0 {
- ret = amf.scDec[int(index>>1)]
- } else {
- index >>= 1
- ret = string(amf.ReadN(int(index)))
- }
- if ret != "" {
- amf.scDec = append(amf.scDec, ret)
- }
- return ret, nil
- }
- func (amf *AMF3) Unmarshal() (obj any, err error) {
- defer func() {
- if e := recover(); e != nil {
- err = errors.New("amf3 unmarshal error")
- }
- }()
- switch amf.ReadByte() {
- case AMF3_NULL:
- return nil, nil
- case AMF3_FALSE:
- return false, nil
- case AMF3_TRUE:
- return true, nil
- case AMF3_INTEGER:
- return amf.readU29()
- case AMF3_DOUBLE:
- return amf.ReadFloat64(), nil
- case AMF3_STRING:
- return amf.readString()
- case AMF3_OBJECT:
- index, err := amf.readU29()
- if err != nil {
- return nil, err
- }
- if (index & 0x01) == 0 {
- return amf.ocDec[int(index>>1)], nil
- }
- if index != 0x0b {
- return nil, errors.New("invalid object type")
- }
- if amf.ReadByte() != 0x01 {
- return nil, errors.New("type object not allowed")
- }
- ret := make(map[string]any)
- amf.ocDec = append(amf.ocDec, ret)
- for {
- key, err := amf.readString()
- if err != nil {
- return nil, err
- }
- if key == "" {
- break
- }
- ret[key], err = amf.Unmarshal()
- if err != nil {
- return nil, err
- }
- }
- return ret, nil
- }
- return nil, errors.New("amf3 unmarshal error")
- }
- func (amf *AMF3) writeString(s string) error {
- index, ok := amf.scEnc[s]
- if ok {
- amf.writeU29(uint32(index << 1))
- return nil
- }
- err := amf.writeU29(uint32((len(s) << 1) | 0x01))
- if err != nil {
- return err
- }
- if s != "" {
- amf.scEnc[s] = len(amf.scEnc)
- }
- amf.WriteString(s)
- return nil
- }
- func (amf *AMF3) readU29() (uint32, error) {
- var ret uint32 = 0
- for i := 0; i < 4; i++ {
- b := amf.ReadByte()
- if i != 3 {
- ret = (ret << 7) | uint32(b&0x7f)
- if (b & 0x80) == 0 {
- break
- }
- } else {
- ret = (ret << 8) | uint32(b)
- }
- }
- return ret, nil
- }
- func (amf *AMF3) writeU29(value uint32) error {
- switch {
- case value < 0x80:
- amf.WriteByte(byte(value))
- case value < 0x4000:
- amf.Write([]byte{byte((value >> 7) | 0x80), byte(value & 0x7f)})
- case value < 0x200000:
- amf.Write([]byte{byte((value >> 14) | 0x80), byte((value >> 7) | 0x80), byte(value & 0x7f)})
- case value < 0x20000000:
- amf.Write([]byte{byte((value >> 22) | 0x80), byte((value >> 15) | 0x80), byte((value >> 7) | 0x80), byte(value & 0xff)})
- default:
- return errors.New("u29 over flow")
- }
- return nil
- }
- func (amf *AMF3) Marshals(v ...any) []byte {
- for _, vv := range v {
- amf.Marshal(vv)
- }
- return amf.Buffer
- }
- func MarshalAMF3s(v ...any) []byte {
- var amf AMF3
- amf.ocEnc = make(map[uintptr]int)
- amf.scEnc = make(map[string]int)
- return amf.Marshals(v...)
- }
- func (amf *AMF3) Marshal(v any) []byte {
- if v == nil {
- amf.WriteByte(AMF3_NULL)
- return amf.Buffer
- }
- switch vv := v.(type) {
- case string:
- amf.WriteByte(AMF3_STRING)
- amf.writeString(vv)
- case bool:
- if vv {
- amf.WriteByte(AMF3_TRUE)
- } else {
- amf.WriteByte(AMF3_FALSE)
- }
- case int, int8, int16, int32, int64:
- var value int64
- reflect.ValueOf(&value).Elem().Set(reflect.ValueOf(vv).Convert(reflect.TypeOf(value)))
- if value < -0xfffffff {
- if value > -0x7fffffff {
- return amf.Marshal(float64(value))
- }
- return amf.Marshal(strconv.FormatInt(value, 10))
- }
- amf.WriteByte(AMF3_INTEGER)
- amf.writeU29(uint32(value))
- case uint, uint8, uint16, uint32, uint64:
- var value uint64
- reflect.ValueOf(&value).Elem().Set(reflect.ValueOf(vv).Convert(reflect.TypeOf(value)))
- if value >= 0x20000000 {
- if value <= 0xffffffff {
- return amf.Marshal(float64(value))
- }
- return amf.Marshal(strconv.FormatUint(value, 10))
- }
- amf.WriteByte(AMF3_INTEGER)
- amf.writeU29(uint32(value))
- case float32:
- amf.Marshal(float64(vv))
- case float64:
- amf.WriteByte(AMF3_DOUBLE)
- amf.WriteFloat64(vv)
- case map[string]any:
- amf.WriteByte(AMF3_OBJECT)
- index, ok := amf.ocEnc[reflect.ValueOf(vv).Pointer()]
- if ok {
- index <<= 1
- amf.writeU29(uint32(index << 1))
- return nil
- }
- amf.WriteByte(0x0b)
- err := amf.writeString("")
- if err != nil {
- return nil
- }
- for k, v := range vv {
- err = amf.writeString(k)
- if err != nil {
- return nil
- }
- amf.Marshal(v)
- }
- amf.writeString("")
- default:
- v := reflect.ValueOf(vv)
- if !v.IsValid() {
- amf.WriteByte(AMF3_NULL)
- return amf.Buffer
- }
- switch v.Kind() {
- case reflect.Ptr:
- if v.IsNil() {
- amf.WriteByte(AMF3_NULL)
- return amf.Buffer
- }
- vv := reflect.Indirect(v)
- if vv.Kind() == reflect.Struct {
- amf.WriteByte(AMF3_OBJECT)
- index, ok := amf.ocEnc[v.Pointer()]
- if ok {
- index <<= 1
- amf.writeU29(uint32(index << 1))
- return nil
- }
- amf.WriteByte(0x0b)
- err := amf.writeString("")
- if err != nil {
- return nil
- }
- t := vv.Type()
- for i := 0; i < t.NumField(); i++ {
- f := t.Field(i)
- key := amf.getFieldName(f)
- if key == "" {
- continue
- }
- err = amf.writeString(key)
- if err != nil {
- return nil
- }
- fv := v.FieldByName(f.Name)
- if fv.Kind() == reflect.Struct {
- fv = fv.Addr()
- }
- amf.Marshal(fv.Interface())
- }
- amf.writeString("")
- }
- }
- }
- return amf.Buffer
- }
- func (amf *AMF3) getFieldName(f reflect.StructField) string {
- chars := []rune(f.Name)
- if unicode.IsLower(chars[0]) {
- return ""
- }
- name := f.Tag.Get("amf.name")
- if name != "" {
- return name
- }
- if !amf.reservStruct {
- chars[0] = unicode.ToLower(chars[0])
- return string(chars)
- }
- return f.Name
- }
|