middleware_test.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package jwt
  2. import (
  3. "context"
  4. "sync"
  5. "testing"
  6. "time"
  7. "crypto/subtle"
  8. jwt "github.com/dgrijalva/jwt-go"
  9. "github.com/go-kit/kit/endpoint"
  10. )
  11. type customClaims struct {
  12. MyProperty string `json:"my_property"`
  13. jwt.StandardClaims
  14. }
  15. func (c customClaims) VerifyMyProperty(p string) bool {
  16. return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
  17. }
  18. var (
  19. kid = "kid"
  20. key = []byte("test_signing_key")
  21. myProperty = "some value"
  22. method = jwt.SigningMethodHS256
  23. invalidMethod = jwt.SigningMethodRS256
  24. mapClaims = jwt.MapClaims{"user": "go-kit"}
  25. standardClaims = jwt.StandardClaims{Audience: "go-kit"}
  26. myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
  27. // Signed tokens generated at https://jwt.io/
  28. signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
  29. standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
  30. customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
  31. invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
  32. malformedKey = "malformed.jwt.token"
  33. )
  34. func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
  35. ctx, err := signer(context.Background(), struct{}{})
  36. if err != nil {
  37. t.Fatalf("Signer returned error: %s", err)
  38. }
  39. token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string)
  40. if !ok {
  41. t.Fatal("Token did not exist in context")
  42. }
  43. if token != expectedKey {
  44. t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
  45. }
  46. }
  47. func TestNewSigner(t *testing.T) {
  48. e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
  49. signer := NewSigner(kid, key, method, mapClaims)(e)
  50. signingValidator(t, signer, signedKey)
  51. signer = NewSigner(kid, key, method, standardClaims)(e)
  52. signingValidator(t, signer, standardSignedKey)
  53. signer = NewSigner(kid, key, method, myCustomClaims)(e)
  54. signingValidator(t, signer, customSignedKey)
  55. }
  56. func TestJWTParser(t *testing.T) {
  57. e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
  58. keys := func(token *jwt.Token) (interface{}, error) {
  59. return key, nil
  60. }
  61. parser := NewParser(keys, method, MapClaimsFactory)(e)
  62. // No Token is passed into the parser
  63. _, err := parser(context.Background(), struct{}{})
  64. if err == nil {
  65. t.Error("Parser should have returned an error")
  66. }
  67. if err != ErrTokenContextMissing {
  68. t.Errorf("unexpected error returned, expected: %s got: %s", ErrTokenContextMissing, err)
  69. }
  70. // Invalid Token is passed into the parser
  71. ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey)
  72. _, err = parser(ctx, struct{}{})
  73. if err == nil {
  74. t.Error("Parser should have returned an error")
  75. }
  76. // Invalid Method is used in the parser
  77. badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
  78. ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
  79. _, err = badParser(ctx, struct{}{})
  80. if err == nil {
  81. t.Error("Parser should have returned an error")
  82. }
  83. if err != ErrUnexpectedSigningMethod {
  84. t.Errorf("unexpected error returned, expected: %s got: %s", ErrUnexpectedSigningMethod, err)
  85. }
  86. // Invalid key is used in the parser
  87. invalidKeys := func(token *jwt.Token) (interface{}, error) {
  88. return []byte("bad"), nil
  89. }
  90. badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
  91. ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
  92. _, err = badParser(ctx, struct{}{})
  93. if err == nil {
  94. t.Error("Parser should have returned an error")
  95. }
  96. // Correct token is passed into the parser
  97. ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
  98. ctx1, err := parser(ctx, struct{}{})
  99. if err != nil {
  100. t.Fatalf("Parser returned error: %s", err)
  101. }
  102. cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
  103. if !ok {
  104. t.Fatal("Claims were not passed into context correctly")
  105. }
  106. if cl["user"] != mapClaims["user"] {
  107. t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
  108. }
  109. // Test for malformed token error response
  110. parser = NewParser(keys, method, StandardClaimsFactory)(e)
  111. ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
  112. ctx1, err = parser(ctx, struct{}{})
  113. if want, have := ErrTokenMalformed, err; want != have {
  114. t.Fatalf("Expected %+v, got %+v", want, have)
  115. }
  116. // Test for expired token error response
  117. parser = NewParser(keys, method, StandardClaimsFactory)(e)
  118. expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
  119. token, err := expired.SignedString(key)
  120. if err != nil {
  121. t.Fatalf("Unable to Sign Token: %+v", err)
  122. }
  123. ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
  124. ctx1, err = parser(ctx, struct{}{})
  125. if want, have := ErrTokenExpired, err; want != have {
  126. t.Fatalf("Expected %+v, got %+v", want, have)
  127. }
  128. // Test for not activated token error response
  129. parser = NewParser(keys, method, StandardClaimsFactory)(e)
  130. notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
  131. token, err = notactive.SignedString(key)
  132. if err != nil {
  133. t.Fatalf("Unable to Sign Token: %+v", err)
  134. }
  135. ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
  136. ctx1, err = parser(ctx, struct{}{})
  137. if want, have := ErrTokenNotActive, err; want != have {
  138. t.Fatalf("Expected %+v, got %+v", want, have)
  139. }
  140. // test valid standard claims token
  141. parser = NewParser(keys, method, StandardClaimsFactory)(e)
  142. ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
  143. ctx1, err = parser(ctx, struct{}{})
  144. if err != nil {
  145. t.Fatalf("Parser returned error: %s", err)
  146. }
  147. stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
  148. if !ok {
  149. t.Fatal("Claims were not passed into context correctly")
  150. }
  151. if !stdCl.VerifyAudience("go-kit", true) {
  152. t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
  153. }
  154. // test valid customized claims token
  155. parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
  156. ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
  157. ctx1, err = parser(ctx, struct{}{})
  158. if err != nil {
  159. t.Fatalf("Parser returned error: %s", err)
  160. }
  161. custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
  162. if !ok {
  163. t.Fatal("Claims were not passed into context correctly")
  164. }
  165. if !custCl.VerifyAudience("go-kit", true) {
  166. t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
  167. }
  168. if !custCl.VerifyMyProperty(myProperty) {
  169. t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
  170. }
  171. }
  172. func TestIssue562(t *testing.T) {
  173. var (
  174. kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
  175. e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
  176. key = JWTTokenContextKey
  177. val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
  178. ctx = context.WithValue(context.Background(), key, val)
  179. )
  180. wg := sync.WaitGroup{}
  181. for i := 0; i < 100; i++ {
  182. wg.Add(1)
  183. go func() {
  184. defer wg.Done()
  185. e(ctx, struct{}{}) // fatal error: concurrent map read and map write
  186. }()
  187. }
  188. wg.Wait()
  189. }