timeout_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package nebula
  2. import (
  3. "net/netip"
  4. "testing"
  5. "time"
  6. "github.com/slackhq/nebula/firewall"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. func TestNewTimerWheel(t *testing.T) {
  10. // Make sure we get an object we expect
  11. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  12. assert.Equal(t, 12, tw.wheelLen)
  13. assert.Equal(t, 0, tw.current)
  14. assert.Nil(t, tw.lastTick)
  15. assert.Equal(t, time.Second*1, tw.tickDuration)
  16. assert.Equal(t, time.Second*10, tw.wheelDuration)
  17. assert.Len(t, tw.wheel, 12)
  18. // Assert the math is correct
  19. tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10)
  20. assert.Equal(t, 5, tw.wheelLen)
  21. tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10)
  22. assert.Equal(t, 7, tw.wheelLen)
  23. // Test empty purge of non nil items
  24. i, ok := tw.Purge()
  25. assert.Equal(t, firewall.Packet{}, i)
  26. assert.False(t, ok)
  27. // Test empty purges of nil items
  28. tw2 := NewTimerWheel[*int](time.Second, time.Second*10)
  29. i2, ok := tw2.Purge()
  30. assert.Nil(t, i2)
  31. assert.False(t, ok)
  32. }
  33. func TestTimerWheel_findWheel(t *testing.T) {
  34. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  35. assert.Len(t, tw.wheel, 12)
  36. // Current + tick + 1 since we don't know how far into current we are
  37. assert.Equal(t, 2, tw.findWheel(time.Second*1))
  38. // Scale up to min duration
  39. assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
  40. // Make sure we hit that last index
  41. assert.Equal(t, 11, tw.findWheel(time.Second*10))
  42. // Scale down to max duration
  43. assert.Equal(t, 11, tw.findWheel(time.Second*11))
  44. tw.current = 1
  45. // Make sure we account for the current position properly
  46. assert.Equal(t, 3, tw.findWheel(time.Second*1))
  47. assert.Equal(t, 0, tw.findWheel(time.Second*10))
  48. }
  49. func TestTimerWheel_Add(t *testing.T) {
  50. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  51. fp1 := firewall.Packet{}
  52. tw.Add(fp1, time.Second*1)
  53. // Make sure we set head and tail properly
  54. assert.NotNil(t, tw.wheel[2])
  55. assert.Equal(t, fp1, tw.wheel[2].Head.Item)
  56. assert.Nil(t, tw.wheel[2].Head.Next)
  57. assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
  58. assert.Nil(t, tw.wheel[2].Tail.Next)
  59. // Make sure we only modify head
  60. fp2 := firewall.Packet{}
  61. tw.Add(fp2, time.Second*1)
  62. assert.Equal(t, fp2, tw.wheel[2].Head.Item)
  63. assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
  64. assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
  65. assert.Nil(t, tw.wheel[2].Tail.Next)
  66. // Make sure we use free'd items first
  67. tw.itemCache = &TimeoutItem[firewall.Packet]{}
  68. tw.itemsCached = 1
  69. tw.Add(fp2, time.Second*1)
  70. assert.Nil(t, tw.itemCache)
  71. assert.Equal(t, 0, tw.itemsCached)
  72. // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
  73. for min := time.Duration(1); min < 100; min++ {
  74. for max := min; max < 100; max++ {
  75. tw = NewTimerWheel[firewall.Packet](min, max)
  76. for current := 0; current < tw.wheelLen; current++ {
  77. tw.current = current
  78. for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
  79. tick := tw.findWheel(timeout)
  80. if tick >= tw.wheelLen {
  81. t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
  82. }
  83. }
  84. }
  85. }
  86. }
  87. }
  88. func TestTimerWheel_Purge(t *testing.T) {
  89. // First advance should set the lastTick and do nothing else
  90. tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
  91. assert.Nil(t, tw.lastTick)
  92. tw.Advance(time.Now())
  93. assert.NotNil(t, tw.lastTick)
  94. assert.Equal(t, 0, tw.current)
  95. fps := []firewall.Packet{
  96. {LocalAddr: netip.MustParseAddr("0.0.0.1")},
  97. {LocalAddr: netip.MustParseAddr("0.0.0.2")},
  98. {LocalAddr: netip.MustParseAddr("0.0.0.3")},
  99. {LocalAddr: netip.MustParseAddr("0.0.0.4")},
  100. }
  101. tw.Add(fps[0], time.Second*1)
  102. tw.Add(fps[1], time.Second*1)
  103. tw.Add(fps[2], time.Second*2)
  104. tw.Add(fps[3], time.Second*2)
  105. ta := time.Now().Add(time.Second * 3)
  106. lastTick := *tw.lastTick
  107. tw.Advance(ta)
  108. assert.Equal(t, 3, tw.current)
  109. assert.True(t, tw.lastTick.After(lastTick))
  110. // Make sure we get all 4 packets back
  111. for i := 0; i < 4; i++ {
  112. p, has := tw.Purge()
  113. assert.True(t, has)
  114. assert.Equal(t, fps[i], p)
  115. }
  116. // Make sure there aren't any leftover
  117. _, ok := tw.Purge()
  118. assert.False(t, ok)
  119. assert.Nil(t, tw.expired.Head)
  120. assert.Nil(t, tw.expired.Tail)
  121. // Make sure we cached the free'd items
  122. assert.Equal(t, 4, tw.itemsCached)
  123. ci := tw.itemCache
  124. for i := 0; i < 4; i++ {
  125. assert.NotNil(t, ci)
  126. ci = ci.Next
  127. }
  128. assert.Nil(t, ci)
  129. // Let's make sure we roll over properly
  130. ta = ta.Add(time.Second * 5)
  131. tw.Advance(ta)
  132. assert.Equal(t, 8, tw.current)
  133. ta = ta.Add(time.Second * 2)
  134. tw.Advance(ta)
  135. assert.Equal(t, 10, tw.current)
  136. ta = ta.Add(time.Second * 1)
  137. tw.Advance(ta)
  138. assert.Equal(t, 11, tw.current)
  139. ta = ta.Add(time.Second * 1)
  140. tw.Advance(ta)
  141. assert.Equal(t, 0, tw.current)
  142. }