bucket_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package trafficshape
import (
	"errors"
	"runtime"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)
  22. func TestBucket(t *testing.T) {
  23. t.Parallel()
  24. b := NewBucket(10, 10*time.Millisecond)
  25. defer b.Close()
  26. if got, want := b.Capacity(), int64(10); got != want {
  27. t.Fatalf("b.Capacity(): got %d, want %d", got, want)
  28. }
  29. n, err := b.Fill(func(remaining int64) (int64, error) {
  30. if want := int64(10); remaining != want {
  31. t.Errorf("remaining: got %d, want %d", remaining, want)
  32. }
  33. return 5, nil
  34. })
  35. if err != nil {
  36. t.Fatalf("Fill(): got %v, want no error", err)
  37. }
  38. if got, want := n, int64(5); got != want {
  39. t.Fatalf("n: got %d, want %d", got, want)
  40. }
  41. n, err = b.Fill(func(remaining int64) (int64, error) {
  42. if want := int64(5); remaining != want {
  43. t.Errorf("remaining: got %d, want %d", remaining, want)
  44. }
  45. return 5, nil
  46. })
  47. if err != nil {
  48. t.Fatalf("Fill(): got %v, want no error", err)
  49. }
  50. if got, want := n, int64(5); got != want {
  51. t.Fatalf("n: got %d, want %d", got, want)
  52. }
  53. n, err = b.Fill(func(remaining int64) (int64, error) {
  54. t.Fatal("Fill: executed func when full, want skipped")
  55. return 0, nil
  56. })
  57. if err != nil {
  58. t.Fatalf("Fill(): got %v, want no error", err)
  59. }
  60. // Wait for the bucket to drain.
  61. for {
  62. if atomic.LoadInt64(&b.fill) == 0 {
  63. break
  64. }
  65. // Allow for a goroutine switch, required for GOMAXPROCS = 1.
  66. runtime.Gosched()
  67. }
  68. wanterr := errors.New("fill function error")
  69. n, err = b.Fill(func(remaining int64) (int64, error) {
  70. if want := int64(10); remaining != want {
  71. t.Errorf("remaining: got %d, want %d", remaining, want)
  72. }
  73. return 0, wanterr
  74. })
  75. if err != wanterr {
  76. t.Fatalf("Fill(): got %v, want %v", err, wanterr)
  77. }
  78. if got, want := n, int64(0); got != want {
  79. t.Fatalf("n: got %d, want %d", got, want)
  80. }
  81. }
  82. func TestBucketClosed(t *testing.T) {
  83. t.Parallel()
  84. b := NewBucket(0, time.Millisecond)
  85. b.Close()
  86. if _, err := b.Fill(nil); err != errFillClosedBucket {
  87. t.Errorf("Fill(): got %v, want errFillClosedBucket", err)
  88. }
  89. if _, err := b.FillThrottle(nil); err != errFillClosedBucket {
  90. t.Errorf("FillThrottle(): got %v, want errFillClosedBucket", err)
  91. }
  92. }
  93. func TestBucketOverflow(t *testing.T) {
  94. t.Parallel()
  95. b := NewBucket(10, 10*time.Millisecond)
  96. defer b.Close()
  97. n, err := b.Fill(func(remaining int64) (int64, error) {
  98. return 11, nil
  99. })
  100. if err != nil {
  101. t.Fatalf("Fill(): got %v, want no error", err)
  102. }
  103. n, err = b.Fill(func(int64) (int64, error) {
  104. t.Fatal("Fill: executed func when full, want skipped")
  105. return 0, nil
  106. })
  107. if err != ErrBucketOverflow {
  108. t.Fatalf("Fill(): got %v, want ErrBucketOverflow", err)
  109. }
  110. if got, want := n, int64(0); got != want {
  111. t.Fatalf("n: got %d, want %d", got, want)
  112. }
  113. }
  114. func TestBucketThrottle(t *testing.T) {
  115. t.Parallel()
  116. b := NewBucket(50, 50*time.Millisecond)
  117. defer b.Close()
  118. closec := make(chan struct{})
  119. errc := make(chan error, 1)
  120. fill := func() {
  121. for {
  122. select {
  123. case <-closec:
  124. return
  125. default:
  126. if _, err := b.FillThrottle(func(remaining int64) (int64, error) {
  127. if remaining < 10 {
  128. return remaining, nil
  129. }
  130. return 10, nil
  131. }); err != nil {
  132. select {
  133. case errc <- err:
  134. default:
  135. }
  136. }
  137. }
  138. }
  139. }
  140. for i := 0; i < 5; i++ {
  141. go fill()
  142. }
  143. time.Sleep(time.Second)
  144. close(closec)
  145. select {
  146. case err := <-errc:
  147. t.Fatalf("FillThrottle: got %v, want no error", err)
  148. default:
  149. }
  150. }
  151. func TestBucketFillThrottleCloseBeforeTick(t *testing.T) {
  152. t.Parallel()
  153. b := NewBucket(0, time.Minute)
  154. time.AfterFunc(time.Second, func() { b.Close() })
  155. if _, err := b.FillThrottle(func(int64) (int64, error) {
  156. t.Fatal("FillThrottle(): executed func after close, want skipped")
  157. return 0, nil
  158. }); err != errFillClosedBucket {
  159. t.Errorf("b.FillThrottle(): got nil, want errFillClosedBucket")
  160. }
  161. }