fields.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package gstruct
import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)
  12. //MatchAllFields succeeds if every field of a struct matches the field matcher associated with
  13. //it, and every element matcher is matched.
  14. // actual := struct{
  15. // A int
  16. // B []bool
  17. // C string
  18. // }{
  19. // A: 5,
  20. // B: []bool{true, false},
  21. // C: "foo",
  22. // }
  23. //
  24. // Expect(actual).To(MatchAllFields(Fields{
  25. // "A": Equal(5),
  26. // "B": ConsistOf(true, false),
  27. // "C": Equal("foo"),
  28. // }))
  29. func MatchAllFields(fields Fields) types.GomegaMatcher {
  30. return &FieldsMatcher{
  31. Fields: fields,
  32. }
  33. }
  34. //MatchFields succeeds if each element of a struct matches the field matcher associated with
  35. //it. It can ignore extra fields and/or missing fields.
  36. // actual := struct{
  37. // A int
  38. // B []bool
  39. // C string
  40. // }{
  41. // A: 5,
  42. // B: []bool{true, false},
  43. // C: "foo",
  44. // }
  45. //
  46. // Expect(actual).To(MatchFields(IgnoreExtras, Fields{
  47. // "A": Equal(5),
  48. // "B": ConsistOf(true, false),
  49. // }))
  50. // Expect(actual).To(MatchFields(IgnoreMissing, Fields{
  51. // "A": Equal(5),
  52. // "B": ConsistOf(true, false),
  53. // "C": Equal("foo"),
  54. // "D": Equal("extra"),
  55. // }))
  56. func MatchFields(options Options, fields Fields) types.GomegaMatcher {
  57. return &FieldsMatcher{
  58. Fields: fields,
  59. IgnoreExtras: options&IgnoreExtras != 0,
  60. IgnoreMissing: options&IgnoreMissing != 0,
  61. }
  62. }
  63. type FieldsMatcher struct {
  64. // Matchers for each field.
  65. Fields Fields
  66. // Whether to ignore extra elements or consider it an error.
  67. IgnoreExtras bool
  68. // Whether to ignore missing elements or consider it an error.
  69. IgnoreMissing bool
  70. // State.
  71. failures []error
  72. }
  73. // Field name to matcher.
  74. type Fields map[string]types.GomegaMatcher
  75. func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
  76. if reflect.TypeOf(actual).Kind() != reflect.Struct {
  77. return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
  78. }
  79. m.failures = m.matchFields(actual)
  80. if len(m.failures) > 0 {
  81. return false, nil
  82. }
  83. return true, nil
  84. }
  85. func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
  86. val := reflect.ValueOf(actual)
  87. typ := val.Type()
  88. fields := map[string]bool{}
  89. for i := 0; i < val.NumField(); i++ {
  90. fieldName := typ.Field(i).Name
  91. fields[fieldName] = true
  92. err := func() (err error) {
  93. // This test relies heavily on reflect, which tends to panic.
  94. // Recover here to provide more useful error messages in that case.
  95. defer func() {
  96. if r := recover(); r != nil {
  97. err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
  98. }
  99. }()
  100. matcher, expected := m.Fields[fieldName]
  101. if !expected {
  102. if !m.IgnoreExtras {
  103. return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
  104. }
  105. return nil
  106. }
  107. field := val.Field(i).Interface()
  108. match, err := matcher.Match(field)
  109. if err != nil {
  110. return err
  111. } else if !match {
  112. if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
  113. return errorsutil.AggregateError(nesting.Failures())
  114. }
  115. return errors.New(matcher.FailureMessage(field))
  116. }
  117. return nil
  118. }()
  119. if err != nil {
  120. errs = append(errs, errorsutil.Nest("."+fieldName, err))
  121. }
  122. }
  123. for field := range m.Fields {
  124. if !fields[field] && !m.IgnoreMissing {
  125. errs = append(errs, fmt.Errorf("missing expected field %s", field))
  126. }
  127. }
  128. return errs
  129. }
  130. func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
  131. failures := make([]string, len(m.failures))
  132. for i := range m.failures {
  133. failures[i] = m.failures[i].Error()
  134. }
  135. return format.Message(reflect.TypeOf(actual).Name(),
  136. fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(failures, "\n")))
  137. }
  138. func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
  139. return format.Message(actual, "not to match fields")
  140. }
  141. func (m *FieldsMatcher) Failures() []error {
  142. return m.failures
  143. }