handshake_manager_test.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. package nebula
  2. import (
  3. "net"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
  9. var ips []uint32
  10. func Test_NewHandshakeManagerIndex(t *testing.T) {
  11. _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
  12. _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
  13. _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
  14. ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
  15. preferredRanges := []*net.IPNet{localrange}
  16. mainHM := NewHostMap("test", vpncidr, preferredRanges)
  17. blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
  18. now := time.Now()
  19. blah.NextInboundHandshakeTimerTick(now)
  20. var indexes = make([]uint32, 4)
  21. var hostinfo = make([]*HostInfo, len(indexes))
  22. for i := range indexes {
  23. hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
  24. }
  25. // Add four indexes
  26. for i := range indexes {
  27. err := blah.AddIndexHostInfo(hostinfo[i])
  28. assert.NoError(t, err)
  29. indexes[i] = hostinfo[i].localIndexId
  30. blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
  31. }
  32. // Confirm they are in the pending index list
  33. for _, v := range indexes {
  34. assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
  35. }
  36. // Adding something to pending should not affect the main hostmap
  37. assert.Len(t, mainHM.Indexes, 0)
  38. // Jump ahead 8 seconds
  39. for i := 1; i <= DefaultHandshakeRetries; i++ {
  40. next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
  41. blah.NextInboundHandshakeTimerTick(next_tick)
  42. }
  43. // Confirm they are still in the pending index list
  44. for _, v := range indexes {
  45. assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
  46. }
  47. // Jump ahead 4 more seconds
  48. next_tick := now.Add(12 * time.Second)
  49. blah.NextInboundHandshakeTimerTick(next_tick)
  50. // Confirm they have been removed
  51. for _, v := range indexes {
  52. assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
  53. }
  54. }
  55. func Test_NewHandshakeManagerVpnIP(t *testing.T) {
  56. _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
  57. _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
  58. _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
  59. ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
  60. preferredRanges := []*net.IPNet{localrange}
  61. mw := &mockEncWriter{}
  62. mainHM := NewHostMap("test", vpncidr, preferredRanges)
  63. blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
  64. now := time.Now()
  65. blah.NextOutboundHandshakeTimerTick(now, mw)
  66. // Add four "IPs" - which are just uint32s
  67. for _, v := range ips {
  68. blah.AddVpnIP(v)
  69. }
  70. // Adding something to pending should not affect the main hostmap
  71. assert.Len(t, mainHM.Hosts, 0)
  72. // Confirm they are in the pending index list
  73. for _, v := range ips {
  74. assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
  75. }
  76. // Jump ahead `HandshakeRetries` ticks
  77. cumulative := time.Duration(0)
  78. for i := 0; i <= DefaultHandshakeRetries+1; i++ {
  79. cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
  80. next_tick := now.Add(cumulative)
  81. //l.Infoln(next_tick)
  82. blah.NextOutboundHandshakeTimerTick(next_tick, mw)
  83. }
  84. // Confirm they are still in the pending index list
  85. for _, v := range ips {
  86. assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
  87. }
  88. // Jump ahead 1 more second
  89. cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
  90. next_tick := now.Add(cumulative)
  91. //l.Infoln(next_tick)
  92. blah.NextOutboundHandshakeTimerTick(next_tick, mw)
  93. // Confirm they have been removed
  94. for _, v := range ips {
  95. assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
  96. }
  97. }
  98. func Test_NewHandshakeManagerTrigger(t *testing.T) {
  99. _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
  100. _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
  101. _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
  102. ip := ip2int(net.ParseIP("172.1.1.2"))
  103. preferredRanges := []*net.IPNet{localrange}
  104. mw := &mockEncWriter{}
  105. mainHM := NewHostMap("test", vpncidr, preferredRanges)
  106. lh := &LightHouse{}
  107. blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
  108. now := time.Now()
  109. blah.NextOutboundHandshakeTimerTick(now, mw)
  110. assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
  111. blah.AddVpnIP(ip)
  112. assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
  113. // Trigger the same method the channel will
  114. blah.handleOutbound(ip, mw, true)
  115. // Make sure the trigger doesn't schedule another timer entry
  116. assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
  117. hi := blah.pendingHostMap.Hosts[ip]
  118. assert.Nil(t, hi.remote)
  119. lh.addrMap = map[uint32][]udpAddr{
  120. ip: {*NewUDPAddrFromString("10.1.1.1:4242")},
  121. }
  122. // This should trigger the hostmap to populate the hostinfo
  123. blah.handleOutbound(ip, mw, true)
  124. assert.NotNil(t, hi.remote)
  125. assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
  126. }
  127. func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
  128. for _, i := range tw.wheel {
  129. n := i.Head
  130. for n != nil {
  131. c++
  132. n = n.Next
  133. }
  134. }
  135. return c
  136. }
  137. func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
  138. _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
  139. _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
  140. _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
  141. vpnIP = ip2int(net.ParseIP("172.1.1.2"))
  142. preferredRanges := []*net.IPNet{localrange}
  143. mw := &mockEncWriter{}
  144. mainHM := NewHostMap("test", vpncidr, preferredRanges)
  145. blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
  146. now := time.Now()
  147. blah.NextOutboundHandshakeTimerTick(now, mw)
  148. hostinfo := blah.AddVpnIP(vpnIP)
  149. // Pretned we have an index too
  150. err := blah.AddIndexHostInfo(hostinfo)
  151. assert.NoError(t, err)
  152. blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
  153. assert.NotZero(t, hostinfo.localIndexId)
  154. assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
  155. // Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
  156. // but not main hostmap
  157. cumulative := time.Duration(0)
  158. for i := 1; i <= DefaultHandshakeRetries+2; i++ {
  159. cumulative += DefaultHandshakeTryInterval * time.Duration(i)
  160. next_tick := now.Add(cumulative)
  161. blah.NextOutboundHandshakeTimerTick(next_tick, mw)
  162. }
  163. /*
  164. for i := 0; i <= HandshakeRetries+1; i++ {
  165. next_tick := now.Add(cumulative)
  166. //l.Infoln(next_tick)
  167. blah.NextOutboundHandshakeTimerTick(next_tick)
  168. }
  169. */
  170. /*
  171. for i := 0; i <= HandshakeRetries+1; i++ {
  172. next_tick := now.Add(time.Duration(i) * time.Second)
  173. blah.NextOutboundHandshakeTimerTick(next_tick)
  174. }
  175. */
  176. /*
  177. cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
  178. next_tick := now.Add(cumulative)
  179. l.Infoln(cumulative, next_tick)
  180. blah.NextOutboundHandshakeTimerTick(next_tick)
  181. */
  182. assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
  183. assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
  184. }
  185. func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
  186. _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
  187. _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
  188. _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
  189. preferredRanges := []*net.IPNet{localrange}
  190. mainHM := NewHostMap("test", vpncidr, preferredRanges)
  191. blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
  192. now := time.Now()
  193. blah.NextInboundHandshakeTimerTick(now)
  194. hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
  195. err := blah.AddIndexHostInfo(hostinfo)
  196. assert.NoError(t, err)
  197. blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
  198. // Pretned we have an index too
  199. blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
  200. assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
  201. for i := 1; i <= DefaultHandshakeRetries+2; i++ {
  202. next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
  203. blah.NextInboundHandshakeTimerTick(next_tick)
  204. }
  205. next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
  206. blah.NextInboundHandshakeTimerTick(next_tick)
  207. assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
  208. assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
  209. }
  210. type mockEncWriter struct {
  211. }
  212. func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
  213. return
  214. }
  215. func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
  216. return
  217. }