handshake_manager_test.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package nebula
  2. import (
  3. "net/netip"
  4. "testing"
  5. "time"
  6. "github.com/slackhq/nebula/cert"
  7. "github.com/slackhq/nebula/header"
  8. "github.com/slackhq/nebula/test"
  9. "github.com/slackhq/nebula/udp"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. )
  13. func Test_NewHandshakeManagerVpnIp(t *testing.T) {
  14. l := test.NewLogger()
  15. localrange := netip.MustParsePrefix("10.1.1.1/24")
  16. ip := netip.MustParseAddr("172.1.1.2")
  17. preferredRanges := []netip.Prefix{localrange}
  18. mainHM := newHostMap(l)
  19. mainHM.preferredRanges.Store(&preferredRanges)
  20. lh := newTestLighthouse()
  21. psk, err := NewPsk(PskAccepting, nil)
  22. require.NoError(t, err)
  23. cs := &CertState{
  24. defaultVersion: cert.Version1,
  25. privateKey: []byte{},
  26. v1Cert: &dummyCert{version: cert.Version1},
  27. v1HandshakeBytes: []byte{},
  28. psk: psk,
  29. }
  30. blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
  31. blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l}
  32. blah.f.pki.cs.Store(cs)
  33. now := time.Now()
  34. blah.NextOutboundHandshakeTimerTick(now)
  35. i := blah.StartHandshake(ip, nil)
  36. i2 := blah.StartHandshake(ip, nil)
  37. assert.Same(t, i, i2)
  38. i.remotes = NewRemoteList([]netip.Addr{}, nil)
  39. // Adding something to pending should not affect the main hostmap
  40. assert.Empty(t, mainHM.Hosts)
  41. // Confirm they are in the pending index list
  42. assert.Contains(t, blah.vpnIps, ip)
  43. // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
  44. for i := 1; i <= DefaultHandshakeRetries+1; i++ {
  45. now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
  46. blah.NextOutboundHandshakeTimerTick(now)
  47. }
  48. // Confirm they are still in the pending index list
  49. assert.Contains(t, blah.vpnIps, ip)
  50. // Tick 1 more time, a minute will certainly flush it out
  51. blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute))
  52. // Confirm they have been removed
  53. assert.NotContains(t, blah.vpnIps, ip)
  54. }
  55. func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
  56. for _, i := range tw.t.wheel {
  57. n := i.Head
  58. for n != nil {
  59. c++
  60. n = n.Next
  61. }
  62. }
  63. return c
  64. }
  // mockEncWriter is a stateless test double; its methods in this file either do nothing
  // or return fixed values.
  type mockEncWriter struct{}
  67. func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
  68. return
  69. }
  70. func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
  71. return
  72. }
  73. func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
  74. return
  75. }
  76. func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
  77. func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
  78. return nil
  79. }
  80. func (mw *mockEncWriter) GetCertState() *CertState {
  81. return &CertState{defaultVersion: cert.Version2}
  82. }