// cors_test.go

  1. package cors
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "regexp"
  6. "strings"
  7. "testing"
  8. )
  9. var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  10. w.Write([]byte("bar"))
  11. })
  12. var allHeaders = []string{
  13. "Vary",
  14. "Access-Control-Allow-Origin",
  15. "Access-Control-Allow-Methods",
  16. "Access-Control-Allow-Headers",
  17. "Access-Control-Allow-Credentials",
  18. "Access-Control-Max-Age",
  19. "Access-Control-Expose-Headers",
  20. }
  21. func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) {
  22. for _, name := range allHeaders {
  23. got := strings.Join(resHeaders[name], ", ")
  24. want := expHeaders[name]
  25. if got != want {
  26. t.Errorf("Response header %q = %q, want %q", name, got, want)
  27. }
  28. }
  29. }
  30. func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) {
  31. if responseCode != res.Code {
  32. t.Errorf("assertResponse: expected response code to be %d but got %d. ", responseCode, res.Code)
  33. }
  34. }
  35. func TestSpec(t *testing.T) {
  36. cases := []struct {
  37. name string
  38. options Options
  39. method string
  40. reqHeaders map[string]string
  41. resHeaders map[string]string
  42. }{
  43. {
  44. "NoConfig",
  45. Options{
  46. // Intentionally left blank.
  47. },
  48. "GET",
  49. map[string]string{},
  50. map[string]string{
  51. "Vary": "Origin",
  52. },
  53. },
  54. {
  55. "MatchAllOrigin",
  56. Options{
  57. AllowedOrigins: []string{"*"},
  58. },
  59. "GET",
  60. map[string]string{
  61. "Origin": "http://foobar.com",
  62. },
  63. map[string]string{
  64. "Vary": "Origin",
  65. "Access-Control-Allow-Origin": "*",
  66. },
  67. },
  68. {
  69. "MatchAllOriginWithCredentials",
  70. Options{
  71. AllowedOrigins: []string{"*"},
  72. AllowCredentials: true,
  73. },
  74. "GET",
  75. map[string]string{
  76. "Origin": "http://foobar.com",
  77. },
  78. map[string]string{
  79. "Vary": "Origin",
  80. "Access-Control-Allow-Origin": "*",
  81. "Access-Control-Allow-Credentials": "true",
  82. },
  83. },
  84. {
  85. "AllowedOrigin",
  86. Options{
  87. AllowedOrigins: []string{"http://foobar.com"},
  88. },
  89. "GET",
  90. map[string]string{
  91. "Origin": "http://foobar.com",
  92. },
  93. map[string]string{
  94. "Vary": "Origin",
  95. "Access-Control-Allow-Origin": "http://foobar.com",
  96. },
  97. },
  98. {
  99. "WildcardOrigin",
  100. Options{
  101. AllowedOrigins: []string{"http://*.bar.com"},
  102. },
  103. "GET",
  104. map[string]string{
  105. "Origin": "http://foo.bar.com",
  106. },
  107. map[string]string{
  108. "Vary": "Origin",
  109. "Access-Control-Allow-Origin": "http://foo.bar.com",
  110. },
  111. },
  112. {
  113. "DisallowedOrigin",
  114. Options{
  115. AllowedOrigins: []string{"http://foobar.com"},
  116. },
  117. "GET",
  118. map[string]string{
  119. "Origin": "http://barbaz.com",
  120. },
  121. map[string]string{
  122. "Vary": "Origin",
  123. },
  124. },
  125. {
  126. "DisallowedWildcardOrigin",
  127. Options{
  128. AllowedOrigins: []string{"http://*.bar.com"},
  129. },
  130. "GET",
  131. map[string]string{
  132. "Origin": "http://foo.baz.com",
  133. },
  134. map[string]string{
  135. "Vary": "Origin",
  136. },
  137. },
  138. {
  139. "AllowedOriginFuncMatch",
  140. Options{
  141. AllowOriginFunc: func(o string) bool {
  142. return regexp.MustCompile("^http://foo").MatchString(o)
  143. },
  144. },
  145. "GET",
  146. map[string]string{
  147. "Origin": "http://foobar.com",
  148. },
  149. map[string]string{
  150. "Vary": "Origin",
  151. "Access-Control-Allow-Origin": "http://foobar.com",
  152. },
  153. },
  154. {
  155. "AllowOriginRequestFuncMatch",
  156. Options{
  157. AllowOriginRequestFunc: func(r *http.Request, o string) bool {
  158. return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret"
  159. },
  160. },
  161. "GET",
  162. map[string]string{
  163. "Origin": "http://foobar.com",
  164. "Authorization": "secret",
  165. },
  166. map[string]string{
  167. "Vary": "Origin",
  168. "Access-Control-Allow-Origin": "http://foobar.com",
  169. },
  170. },
  171. {
  172. "AllowOriginRequestFuncNotMatch",
  173. Options{
  174. AllowOriginRequestFunc: func(r *http.Request, o string) bool {
  175. return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret"
  176. },
  177. },
  178. "GET",
  179. map[string]string{
  180. "Origin": "http://foobar.com",
  181. "Authorization": "not-secret",
  182. },
  183. map[string]string{
  184. "Vary": "Origin",
  185. },
  186. },
  187. {
  188. "MaxAge",
  189. Options{
  190. AllowedOrigins: []string{"http://example.com/"},
  191. AllowedMethods: []string{"GET"},
  192. MaxAge: 10,
  193. },
  194. "OPTIONS",
  195. map[string]string{
  196. "Origin": "http://example.com/",
  197. "Access-Control-Request-Method": "GET",
  198. },
  199. map[string]string{
  200. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  201. "Access-Control-Allow-Origin": "http://example.com/",
  202. "Access-Control-Allow-Methods": "GET",
  203. "Access-Control-Max-Age": "10",
  204. },
  205. },
  206. {
  207. "AllowedMethod",
  208. Options{
  209. AllowedOrigins: []string{"http://foobar.com"},
  210. AllowedMethods: []string{"PUT", "DELETE"},
  211. },
  212. "OPTIONS",
  213. map[string]string{
  214. "Origin": "http://foobar.com",
  215. "Access-Control-Request-Method": "PUT",
  216. },
  217. map[string]string{
  218. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  219. "Access-Control-Allow-Origin": "http://foobar.com",
  220. "Access-Control-Allow-Methods": "PUT",
  221. },
  222. },
  223. {
  224. "DisallowedMethod",
  225. Options{
  226. AllowedOrigins: []string{"http://foobar.com"},
  227. AllowedMethods: []string{"PUT", "DELETE"},
  228. },
  229. "OPTIONS",
  230. map[string]string{
  231. "Origin": "http://foobar.com",
  232. "Access-Control-Request-Method": "PATCH",
  233. },
  234. map[string]string{
  235. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  236. },
  237. },
  238. {
  239. "AllowedHeaders",
  240. Options{
  241. AllowedOrigins: []string{"http://foobar.com"},
  242. AllowedHeaders: []string{"X-Header-1", "x-header-2"},
  243. },
  244. "OPTIONS",
  245. map[string]string{
  246. "Origin": "http://foobar.com",
  247. "Access-Control-Request-Method": "GET",
  248. "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
  249. },
  250. map[string]string{
  251. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  252. "Access-Control-Allow-Origin": "http://foobar.com",
  253. "Access-Control-Allow-Methods": "GET",
  254. "Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
  255. },
  256. },
  257. {
  258. "DefaultAllowedHeaders",
  259. Options{
  260. AllowedOrigins: []string{"http://foobar.com"},
  261. AllowedHeaders: []string{},
  262. },
  263. "OPTIONS",
  264. map[string]string{
  265. "Origin": "http://foobar.com",
  266. "Access-Control-Request-Method": "GET",
  267. "Access-Control-Request-Headers": "X-Requested-With",
  268. },
  269. map[string]string{
  270. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  271. "Access-Control-Allow-Origin": "http://foobar.com",
  272. "Access-Control-Allow-Methods": "GET",
  273. "Access-Control-Allow-Headers": "X-Requested-With",
  274. },
  275. },
  276. {
  277. "AllowedWildcardHeader",
  278. Options{
  279. AllowedOrigins: []string{"http://foobar.com"},
  280. AllowedHeaders: []string{"*"},
  281. },
  282. "OPTIONS",
  283. map[string]string{
  284. "Origin": "http://foobar.com",
  285. "Access-Control-Request-Method": "GET",
  286. "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
  287. },
  288. map[string]string{
  289. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  290. "Access-Control-Allow-Origin": "http://foobar.com",
  291. "Access-Control-Allow-Methods": "GET",
  292. "Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
  293. },
  294. },
  295. {
  296. "DisallowedHeader",
  297. Options{
  298. AllowedOrigins: []string{"http://foobar.com"},
  299. AllowedHeaders: []string{"X-Header-1", "x-header-2"},
  300. },
  301. "OPTIONS",
  302. map[string]string{
  303. "Origin": "http://foobar.com",
  304. "Access-Control-Request-Method": "GET",
  305. "Access-Control-Request-Headers": "X-Header-3, X-Header-1",
  306. },
  307. map[string]string{
  308. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  309. },
  310. },
  311. {
  312. "OriginHeader",
  313. Options{
  314. AllowedOrigins: []string{"http://foobar.com"},
  315. },
  316. "OPTIONS",
  317. map[string]string{
  318. "Origin": "http://foobar.com",
  319. "Access-Control-Request-Method": "GET",
  320. "Access-Control-Request-Headers": "origin",
  321. },
  322. map[string]string{
  323. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  324. "Access-Control-Allow-Origin": "http://foobar.com",
  325. "Access-Control-Allow-Methods": "GET",
  326. "Access-Control-Allow-Headers": "Origin",
  327. },
  328. },
  329. {
  330. "ExposedHeader",
  331. Options{
  332. AllowedOrigins: []string{"http://foobar.com"},
  333. ExposedHeaders: []string{"X-Header-1", "x-header-2"},
  334. },
  335. "GET",
  336. map[string]string{
  337. "Origin": "http://foobar.com",
  338. },
  339. map[string]string{
  340. "Vary": "Origin",
  341. "Access-Control-Allow-Origin": "http://foobar.com",
  342. "Access-Control-Expose-Headers": "X-Header-1, X-Header-2",
  343. },
  344. },
  345. {
  346. "AllowedCredentials",
  347. Options{
  348. AllowedOrigins: []string{"http://foobar.com"},
  349. AllowCredentials: true,
  350. },
  351. "OPTIONS",
  352. map[string]string{
  353. "Origin": "http://foobar.com",
  354. "Access-Control-Request-Method": "GET",
  355. },
  356. map[string]string{
  357. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  358. "Access-Control-Allow-Origin": "http://foobar.com",
  359. "Access-Control-Allow-Methods": "GET",
  360. "Access-Control-Allow-Credentials": "true",
  361. },
  362. },
  363. {
  364. "OptionPassthrough",
  365. Options{
  366. OptionsPassthrough: true,
  367. },
  368. "OPTIONS",
  369. map[string]string{
  370. "Origin": "http://foobar.com",
  371. "Access-Control-Request-Method": "GET",
  372. },
  373. map[string]string{
  374. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  375. "Access-Control-Allow-Origin": "*",
  376. "Access-Control-Allow-Methods": "GET",
  377. },
  378. },
  379. {
  380. "NonPreflightOptions",
  381. Options{
  382. AllowedOrigins: []string{"http://foobar.com"},
  383. },
  384. "OPTIONS",
  385. map[string]string{
  386. "Origin": "http://foobar.com",
  387. },
  388. map[string]string{
  389. "Vary": "Origin",
  390. "Access-Control-Allow-Origin": "http://foobar.com",
  391. },
  392. },
  393. }
  394. for i := range cases {
  395. tc := cases[i]
  396. t.Run(tc.name, func(t *testing.T) {
  397. s := New(tc.options)
  398. req, _ := http.NewRequest(tc.method, "http://example.com/foo", nil)
  399. for name, value := range tc.reqHeaders {
  400. req.Header.Add(name, value)
  401. }
  402. t.Run("Handler", func(t *testing.T) {
  403. res := httptest.NewRecorder()
  404. s.Handler(testHandler).ServeHTTP(res, req)
  405. assertHeaders(t, res.Header(), tc.resHeaders)
  406. })
  407. t.Run("HandlerFunc", func(t *testing.T) {
  408. res := httptest.NewRecorder()
  409. s.HandlerFunc(res, req)
  410. assertHeaders(t, res.Header(), tc.resHeaders)
  411. })
  412. t.Run("Negroni", func(t *testing.T) {
  413. res := httptest.NewRecorder()
  414. s.ServeHTTP(res, req, testHandler)
  415. assertHeaders(t, res.Header(), tc.resHeaders)
  416. })
  417. })
  418. }
  419. }
  420. func TestDebug(t *testing.T) {
  421. s := New(Options{
  422. Debug: true,
  423. })
  424. if s.Log == nil {
  425. t.Error("Logger not created when debug=true")
  426. }
  427. }
  428. func TestDefault(t *testing.T) {
  429. s := Default()
  430. if s.Log != nil {
  431. t.Error("c.log should be nil when Default")
  432. }
  433. if !s.allowedOriginsAll {
  434. t.Error("c.allowedOriginsAll should be true when Default")
  435. }
  436. if s.allowedHeaders == nil {
  437. t.Error("c.allowedHeaders should be nil when Default")
  438. }
  439. if s.allowedMethods == nil {
  440. t.Error("c.allowedMethods should be nil when Default")
  441. }
  442. }
  443. func TestHandlePreflightInvalidOriginAbortion(t *testing.T) {
  444. s := New(Options{
  445. AllowedOrigins: []string{"http://foo.com"},
  446. })
  447. res := httptest.NewRecorder()
  448. req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
  449. req.Header.Add("Origin", "http://example.com/")
  450. s.handlePreflight(res, req)
  451. assertHeaders(t, res.Header(), map[string]string{
  452. "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
  453. })
  454. }
  455. func TestHandlePreflightNoOptionsAbortion(t *testing.T) {
  456. s := New(Options{
  457. // Intentionally left blank.
  458. })
  459. res := httptest.NewRecorder()
  460. req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
  461. s.handlePreflight(res, req)
  462. assertHeaders(t, res.Header(), map[string]string{})
  463. }
  464. func TestHandleActualRequestInvalidOriginAbortion(t *testing.T) {
  465. s := New(Options{
  466. AllowedOrigins: []string{"http://foo.com"},
  467. })
  468. res := httptest.NewRecorder()
  469. req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
  470. req.Header.Add("Origin", "http://example.com/")
  471. s.handleActualRequest(res, req)
  472. assertHeaders(t, res.Header(), map[string]string{
  473. "Vary": "Origin",
  474. })
  475. }
  476. func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) {
  477. s := New(Options{
  478. AllowedMethods: []string{"POST"},
  479. AllowCredentials: true,
  480. })
  481. res := httptest.NewRecorder()
  482. req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
  483. req.Header.Add("Origin", "http://example.com/")
  484. s.handleActualRequest(res, req)
  485. assertHeaders(t, res.Header(), map[string]string{
  486. "Vary": "Origin",
  487. })
  488. }
  489. func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) {
  490. s := New(Options{
  491. // Intentionally left blank.
  492. })
  493. s.allowedMethods = []string{}
  494. if s.isMethodAllowed("") {
  495. t.Error("IsMethodAllowed should return false when c.allowedMethods is nil.")
  496. }
  497. }
  498. func TestIsMethodAllowedReturnsTrueWithOptions(t *testing.T) {
  499. s := New(Options{
  500. // Intentionally left blank.
  501. })
  502. if !s.isMethodAllowed("OPTIONS") {
  503. t.Error("IsMethodAllowed should return true when c.allowedMethods is nil.")
  504. }
  505. }