timeout_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package nebula
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/stretchr/testify/assert"
  6. )
  7. func TestNewTimerWheel(t *testing.T) {
  8. // Make sure we get an object we expect
  9. tw := NewTimerWheel(time.Second, time.Second*10)
  10. assert.Equal(t, 11, tw.wheelLen)
  11. assert.Equal(t, 0, tw.current)
  12. assert.Nil(t, tw.lastTick)
  13. assert.Equal(t, time.Second*1, tw.tickDuration)
  14. assert.Equal(t, time.Second*10, tw.wheelDuration)
  15. assert.Len(t, tw.wheel, 11)
  16. // Assert the math is correct
  17. tw = NewTimerWheel(time.Second*3, time.Second*10)
  18. assert.Equal(t, 4, tw.wheelLen)
  19. tw = NewTimerWheel(time.Second*120, time.Minute*10)
  20. assert.Equal(t, 6, tw.wheelLen)
  21. }
  22. func TestTimerWheel_findWheel(t *testing.T) {
  23. tw := NewTimerWheel(time.Second, time.Second*10)
  24. assert.Len(t, tw.wheel, 11)
  25. // Current + tick + 1 since we don't know how far into current we are
  26. assert.Equal(t, 2, tw.findWheel(time.Second*1))
  27. // Scale up to min duration
  28. assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
  29. // Make sure we hit that last index
  30. assert.Equal(t, 0, tw.findWheel(time.Second*10))
  31. // Scale down to max duration
  32. assert.Equal(t, 0, tw.findWheel(time.Second*11))
  33. tw.current = 1
  34. // Make sure we account for the current position properly
  35. assert.Equal(t, 3, tw.findWheel(time.Second*1))
  36. assert.Equal(t, 1, tw.findWheel(time.Second*10))
  37. }
  38. func TestTimerWheel_Add(t *testing.T) {
  39. tw := NewTimerWheel(time.Second, time.Second*10)
  40. fp1 := FirewallPacket{}
  41. tw.Add(fp1, time.Second*1)
  42. // Make sure we set head and tail properly
  43. assert.NotNil(t, tw.wheel[2])
  44. assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
  45. assert.Nil(t, tw.wheel[2].Head.Next)
  46. assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
  47. assert.Nil(t, tw.wheel[2].Tail.Next)
  48. // Make sure we only modify head
  49. fp2 := FirewallPacket{}
  50. tw.Add(fp2, time.Second*1)
  51. assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
  52. assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
  53. assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
  54. assert.Nil(t, tw.wheel[2].Tail.Next)
  55. // Make sure we use free'd items first
  56. tw.itemCache = &TimeoutItem{}
  57. tw.itemsCached = 1
  58. tw.Add(fp2, time.Second*1)
  59. assert.Nil(t, tw.itemCache)
  60. assert.Equal(t, 0, tw.itemsCached)
  61. }
  62. func TestTimerWheel_Purge(t *testing.T) {
  63. // First advance should set the lastTick and do nothing else
  64. tw := NewTimerWheel(time.Second, time.Second*10)
  65. assert.Nil(t, tw.lastTick)
  66. tw.advance(time.Now())
  67. assert.NotNil(t, tw.lastTick)
  68. assert.Equal(t, 0, tw.current)
  69. fps := []FirewallPacket{
  70. {LocalIP: 1},
  71. {LocalIP: 2},
  72. {LocalIP: 3},
  73. {LocalIP: 4},
  74. }
  75. tw.Add(fps[0], time.Second*1)
  76. tw.Add(fps[1], time.Second*1)
  77. tw.Add(fps[2], time.Second*2)
  78. tw.Add(fps[3], time.Second*2)
  79. ta := time.Now().Add(time.Second * 3)
  80. lastTick := *tw.lastTick
  81. tw.advance(ta)
  82. assert.Equal(t, 3, tw.current)
  83. assert.True(t, tw.lastTick.After(lastTick))
  84. // Make sure we get all 4 packets back
  85. for i := 0; i < 4; i++ {
  86. p, has := tw.Purge()
  87. assert.True(t, has)
  88. assert.Equal(t, fps[i], p)
  89. }
  90. // Make sure there aren't any leftover
  91. _, ok := tw.Purge()
  92. assert.False(t, ok)
  93. assert.Nil(t, tw.expired.Head)
  94. assert.Nil(t, tw.expired.Tail)
  95. // Make sure we cached the free'd items
  96. assert.Equal(t, 4, tw.itemsCached)
  97. ci := tw.itemCache
  98. for i := 0; i < 4; i++ {
  99. assert.NotNil(t, ci)
  100. ci = ci.Next
  101. }
  102. assert.Nil(t, ci)
  103. // Lets make sure we roll over properly
  104. ta = ta.Add(time.Second * 5)
  105. tw.advance(ta)
  106. assert.Equal(t, 8, tw.current)
  107. ta = ta.Add(time.Second * 2)
  108. tw.advance(ta)
  109. assert.Equal(t, 10, tw.current)
  110. ta = ta.Add(time.Second * 1)
  111. tw.advance(ta)
  112. assert.Equal(t, 0, tw.current)
  113. }