timeout_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package nebula
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/slackhq/nebula/firewall"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. func TestNewTimerWheel(t *testing.T) {
  9. // Make sure we get an object we expect
  10. tw := NewTimerWheel(time.Second, time.Second*10)
  11. assert.Equal(t, 12, tw.wheelLen)
  12. assert.Equal(t, 0, tw.current)
  13. assert.Nil(t, tw.lastTick)
  14. assert.Equal(t, time.Second*1, tw.tickDuration)
  15. assert.Equal(t, time.Second*10, tw.wheelDuration)
  16. assert.Len(t, tw.wheel, 12)
  17. // Assert the math is correct
  18. tw = NewTimerWheel(time.Second*3, time.Second*10)
  19. assert.Equal(t, 5, tw.wheelLen)
  20. tw = NewTimerWheel(time.Second*120, time.Minute*10)
  21. assert.Equal(t, 7, tw.wheelLen)
  22. }
  23. func TestTimerWheel_findWheel(t *testing.T) {
  24. tw := NewTimerWheel(time.Second, time.Second*10)
  25. assert.Len(t, tw.wheel, 12)
  26. // Current + tick + 1 since we don't know how far into current we are
  27. assert.Equal(t, 2, tw.findWheel(time.Second*1))
  28. // Scale up to min duration
  29. assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
  30. // Make sure we hit that last index
  31. assert.Equal(t, 11, tw.findWheel(time.Second*10))
  32. // Scale down to max duration
  33. assert.Equal(t, 11, tw.findWheel(time.Second*11))
  34. tw.current = 1
  35. // Make sure we account for the current position properly
  36. assert.Equal(t, 3, tw.findWheel(time.Second*1))
  37. assert.Equal(t, 0, tw.findWheel(time.Second*10))
  38. }
  39. func TestTimerWheel_Add(t *testing.T) {
  40. tw := NewTimerWheel(time.Second, time.Second*10)
  41. fp1 := firewall.Packet{}
  42. tw.Add(fp1, time.Second*1)
  43. // Make sure we set head and tail properly
  44. assert.NotNil(t, tw.wheel[2])
  45. assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
  46. assert.Nil(t, tw.wheel[2].Head.Next)
  47. assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
  48. assert.Nil(t, tw.wheel[2].Tail.Next)
  49. // Make sure we only modify head
  50. fp2 := firewall.Packet{}
  51. tw.Add(fp2, time.Second*1)
  52. assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
  53. assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
  54. assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
  55. assert.Nil(t, tw.wheel[2].Tail.Next)
  56. // Make sure we use free'd items first
  57. tw.itemCache = &TimeoutItem{}
  58. tw.itemsCached = 1
  59. tw.Add(fp2, time.Second*1)
  60. assert.Nil(t, tw.itemCache)
  61. assert.Equal(t, 0, tw.itemsCached)
  62. // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
  63. for min := time.Duration(1); min < 100; min++ {
  64. for max := min; max < 100; max++ {
  65. tw = NewTimerWheel(min, max)
  66. for current := 0; current < tw.wheelLen; current++ {
  67. tw.current = current
  68. for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
  69. tick := tw.findWheel(timeout)
  70. if tick >= tw.wheelLen {
  71. t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
  72. }
  73. }
  74. }
  75. }
  76. }
  77. }
  78. func TestTimerWheel_Purge(t *testing.T) {
  79. // First advance should set the lastTick and do nothing else
  80. tw := NewTimerWheel(time.Second, time.Second*10)
  81. assert.Nil(t, tw.lastTick)
  82. tw.advance(time.Now())
  83. assert.NotNil(t, tw.lastTick)
  84. assert.Equal(t, 0, tw.current)
  85. fps := []firewall.Packet{
  86. {LocalIP: 1},
  87. {LocalIP: 2},
  88. {LocalIP: 3},
  89. {LocalIP: 4},
  90. }
  91. tw.Add(fps[0], time.Second*1)
  92. tw.Add(fps[1], time.Second*1)
  93. tw.Add(fps[2], time.Second*2)
  94. tw.Add(fps[3], time.Second*2)
  95. ta := time.Now().Add(time.Second * 3)
  96. lastTick := *tw.lastTick
  97. tw.advance(ta)
  98. assert.Equal(t, 3, tw.current)
  99. assert.True(t, tw.lastTick.After(lastTick))
  100. // Make sure we get all 4 packets back
  101. for i := 0; i < 4; i++ {
  102. p, has := tw.Purge()
  103. assert.True(t, has)
  104. assert.Equal(t, fps[i], p)
  105. }
  106. // Make sure there aren't any leftover
  107. _, ok := tw.Purge()
  108. assert.False(t, ok)
  109. assert.Nil(t, tw.expired.Head)
  110. assert.Nil(t, tw.expired.Tail)
  111. // Make sure we cached the free'd items
  112. assert.Equal(t, 4, tw.itemsCached)
  113. ci := tw.itemCache
  114. for i := 0; i < 4; i++ {
  115. assert.NotNil(t, ci)
  116. ci = ci.Next
  117. }
  118. assert.Nil(t, ci)
  119. // Lets make sure we roll over properly
  120. ta = ta.Add(time.Second * 5)
  121. tw.advance(ta)
  122. assert.Equal(t, 8, tw.current)
  123. ta = ta.Add(time.Second * 2)
  124. tw.advance(ta)
  125. assert.Equal(t, 10, tw.current)
  126. ta = ta.Add(time.Second * 1)
  127. tw.advance(ta)
  128. assert.Equal(t, 11, tw.current)
  129. ta = ta.Add(time.Second * 1)
  130. tw.advance(ta)
  131. assert.Equal(t, 0, tw.current)
  132. }