timeout_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package nebula
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/slackhq/nebula/firewall"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. func TestNewTimerWheel(t *testing.T) {
  9. // Make sure we get an object we expect
  10. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  11. assert.Equal(t, 12, tw.wheelLen)
  12. assert.Equal(t, 0, tw.current)
  13. assert.Nil(t, tw.lastTick)
  14. assert.Equal(t, time.Second*1, tw.tickDuration)
  15. assert.Equal(t, time.Second*10, tw.wheelDuration)
  16. assert.Len(t, tw.wheel, 12)
  17. // Assert the math is correct
  18. tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10)
  19. assert.Equal(t, 5, tw.wheelLen)
  20. tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10)
  21. assert.Equal(t, 7, tw.wheelLen)
  22. // Test empty purge of non nil items
  23. i, ok := tw.Purge()
  24. assert.Equal(t, firewall.Packet{}, i)
  25. assert.False(t, ok)
  26. // Test empty purges of nil items
  27. tw2 := NewTimerWheel[*int](time.Second, time.Second*10)
  28. i2, ok := tw2.Purge()
  29. assert.Nil(t, i2)
  30. assert.False(t, ok)
  31. }
  32. func TestTimerWheel_findWheel(t *testing.T) {
  33. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  34. assert.Len(t, tw.wheel, 12)
  35. // Current + tick + 1 since we don't know how far into current we are
  36. assert.Equal(t, 2, tw.findWheel(time.Second*1))
  37. // Scale up to min duration
  38. assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
  39. // Make sure we hit that last index
  40. assert.Equal(t, 11, tw.findWheel(time.Second*10))
  41. // Scale down to max duration
  42. assert.Equal(t, 11, tw.findWheel(time.Second*11))
  43. tw.current = 1
  44. // Make sure we account for the current position properly
  45. assert.Equal(t, 3, tw.findWheel(time.Second*1))
  46. assert.Equal(t, 0, tw.findWheel(time.Second*10))
  47. }
  48. func TestTimerWheel_Add(t *testing.T) {
  49. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  50. fp1 := firewall.Packet{}
  51. tw.Add(fp1, time.Second*1)
  52. // Make sure we set head and tail properly
  53. assert.NotNil(t, tw.wheel[2])
  54. assert.Equal(t, fp1, tw.wheel[2].Head.Item)
  55. assert.Nil(t, tw.wheel[2].Head.Next)
  56. assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
  57. assert.Nil(t, tw.wheel[2].Tail.Next)
  58. // Make sure we only modify head
  59. fp2 := firewall.Packet{}
  60. tw.Add(fp2, time.Second*1)
  61. assert.Equal(t, fp2, tw.wheel[2].Head.Item)
  62. assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
  63. assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
  64. assert.Nil(t, tw.wheel[2].Tail.Next)
  65. // Make sure we use free'd items first
  66. tw.itemCache = &TimeoutItem[firewall.Packet]{}
  67. tw.itemsCached = 1
  68. tw.Add(fp2, time.Second*1)
  69. assert.Nil(t, tw.itemCache)
  70. assert.Equal(t, 0, tw.itemsCached)
  71. // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
  72. for min := time.Duration(1); min < 100; min++ {
  73. for max := min; max < 100; max++ {
  74. tw = NewTimerWheel[firewall.Packet](min, max)
  75. for current := 0; current < tw.wheelLen; current++ {
  76. tw.current = current
  77. for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
  78. tick := tw.findWheel(timeout)
  79. if tick >= tw.wheelLen {
  80. t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
  81. }
  82. }
  83. }
  84. }
  85. }
  86. }
  87. func TestTimerWheel_Purge(t *testing.T) {
  88. // First advance should set the lastTick and do nothing else
  89. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  90. assert.Nil(t, tw.lastTick)
  91. tw.Advance(time.Now())
  92. assert.NotNil(t, tw.lastTick)
  93. assert.Equal(t, 0, tw.current)
  94. fps := []firewall.Packet{
  95. {LocalIP: 1},
  96. {LocalIP: 2},
  97. {LocalIP: 3},
  98. {LocalIP: 4},
  99. }
  100. tw.Add(fps[0], time.Second*1)
  101. tw.Add(fps[1], time.Second*1)
  102. tw.Add(fps[2], time.Second*2)
  103. tw.Add(fps[3], time.Second*2)
  104. ta := time.Now().Add(time.Second * 3)
  105. lastTick := *tw.lastTick
  106. tw.Advance(ta)
  107. assert.Equal(t, 3, tw.current)
  108. assert.True(t, tw.lastTick.After(lastTick))
  109. // Make sure we get all 4 packets back
  110. for i := 0; i < 4; i++ {
  111. p, has := tw.Purge()
  112. assert.True(t, has)
  113. assert.Equal(t, fps[i], p)
  114. }
  115. // Make sure there aren't any leftover
  116. _, ok := tw.Purge()
  117. assert.False(t, ok)
  118. assert.Nil(t, tw.expired.Head)
  119. assert.Nil(t, tw.expired.Tail)
  120. // Make sure we cached the free'd items
  121. assert.Equal(t, 4, tw.itemsCached)
  122. ci := tw.itemCache
  123. for i := 0; i < 4; i++ {
  124. assert.NotNil(t, ci)
  125. ci = ci.Next
  126. }
  127. assert.Nil(t, ci)
  128. // Let's make sure we roll over properly
  129. ta = ta.Add(time.Second * 5)
  130. tw.Advance(ta)
  131. assert.Equal(t, 8, tw.current)
  132. ta = ta.Add(time.Second * 2)
  133. tw.Advance(ta)
  134. assert.Equal(t, 10, tw.current)
  135. ta = ta.Add(time.Second * 1)
  136. tw.Advance(ta)
  137. assert.Equal(t, 11, tw.current)
  138. ta = ta.Add(time.Second * 1)
  139. tw.Advance(ta)
  140. assert.Equal(t, 0, tw.current)
  141. }