handshake_manager_test.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package nebula
  2. import (
  3. "net/netip"
  4. "testing"
  5. "time"
  6. "github.com/slackhq/nebula/cert"
  7. "github.com/slackhq/nebula/header"
  8. "github.com/slackhq/nebula/test"
  9. "github.com/slackhq/nebula/udp"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func Test_NewHandshakeManagerVpnIp(t *testing.T) {
  13. l := test.NewLogger()
  14. localrange := netip.MustParsePrefix("10.1.1.1/24")
  15. ip := netip.MustParseAddr("172.1.1.2")
  16. preferredRanges := []netip.Prefix{localrange}
  17. mainHM := newHostMap(l)
  18. mainHM.preferredRanges.Store(&preferredRanges)
  19. lh := newTestLighthouse()
  20. cs := &CertState{
  21. initiatingVersion: cert.Version1,
  22. privateKey: []byte{},
  23. v1Cert: &dummyCert{version: cert.Version1},
  24. v1HandshakeBytes: []byte{},
  25. }
  26. blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
  27. blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l}
  28. blah.f.pki.cs.Store(cs)
  29. now := time.Now()
  30. blah.NextOutboundHandshakeTimerTick(now)
  31. i := blah.StartHandshake(ip, nil)
  32. i2 := blah.StartHandshake(ip, nil)
  33. assert.Same(t, i, i2)
  34. i.remotes = NewRemoteList([]netip.Addr{}, nil)
  35. // Adding something to pending should not affect the main hostmap
  36. assert.Empty(t, mainHM.Hosts)
  37. // Confirm they are in the pending index list
  38. assert.Contains(t, blah.vpnIps, ip)
  39. // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
  40. for i := 1; i <= DefaultHandshakeRetries+1; i++ {
  41. now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
  42. blah.NextOutboundHandshakeTimerTick(now)
  43. }
  44. // Confirm they are still in the pending index list
  45. assert.Contains(t, blah.vpnIps, ip)
  46. // Tick 1 more time, a minute will certainly flush it out
  47. blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute))
  48. // Confirm they have been removed
  49. assert.NotContains(t, blah.vpnIps, ip)
  50. }
  51. func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
  52. for _, i := range tw.t.wheel {
  53. n := i.Head
  54. for n != nil {
  55. c++
  56. n = n.Next
  57. }
  58. }
  59. return c
  60. }
  61. type mockEncWriter struct {
  62. }
  63. func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
  64. return
  65. }
  66. func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
  67. return
  68. }
  69. func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
  70. return
  71. }
  72. func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
  73. func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
  74. return nil
  75. }
  76. func (mw *mockEncWriter) GetCertState() *CertState {
  77. return &CertState{initiatingVersion: cert.Version2}
  78. }