control_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package nebula
  2. import (
  3. "net"
  4. "net/netip"
  5. "reflect"
  6. "testing"
  7. "github.com/sirupsen/logrus"
  8. "github.com/slackhq/nebula/cert"
  9. "github.com/slackhq/nebula/test"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestControl_GetHostInfoByVpnIp(t *testing.T) {
  13. //TODO: CERT-V2 with multiple certificate versions we have a problem with this test
  14. // Some certs versions have different characteristics and each version implements their own Copy() func
  15. // which means this is not a good place to test for exposing memory
  16. l := test.NewLogger()
  17. // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
  18. // To properly ensure we are not exposing core memory to the caller
  19. hm := newHostMap(l)
  20. hm.preferredRanges.Store(&[]netip.Prefix{})
  21. remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
  22. remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
  23. ipNet := net.IPNet{
  24. IP: remote1.Addr().AsSlice(),
  25. Mask: net.IPMask{255, 255, 255, 0},
  26. }
  27. ipNet2 := net.IPNet{
  28. IP: remote2.Addr().AsSlice(),
  29. Mask: net.IPMask{255, 255, 255, 0},
  30. }
  31. remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
  32. remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
  33. remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
  34. vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
  35. assert.True(t, ok)
  36. crt := &dummyCert{}
  37. hm.unlockedAddHostInfo(&HostInfo{
  38. remote: remote1,
  39. remotes: remotes,
  40. ConnectionState: &ConnectionState{
  41. peerCert: &cert.CachedCertificate{Certificate: crt},
  42. },
  43. remoteIndexId: 200,
  44. localIndexId: 201,
  45. vpnAddrs: []netip.Addr{vpnIp},
  46. relayState: RelayState{
  47. relays: nil,
  48. relayForByAddr: map[netip.Addr]*Relay{},
  49. relayForByIdx: map[uint32]*Relay{},
  50. },
  51. }, &Interface{})
  52. vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
  53. assert.True(t, ok)
  54. hm.unlockedAddHostInfo(&HostInfo{
  55. remote: remote1,
  56. remotes: remotes,
  57. ConnectionState: &ConnectionState{
  58. peerCert: nil,
  59. },
  60. remoteIndexId: 200,
  61. localIndexId: 201,
  62. vpnAddrs: []netip.Addr{vpnIp2},
  63. relayState: RelayState{
  64. relays: nil,
  65. relayForByAddr: map[netip.Addr]*Relay{},
  66. relayForByIdx: map[uint32]*Relay{},
  67. },
  68. }, &Interface{})
  69. c := Control{
  70. f: &Interface{
  71. hostMap: hm,
  72. },
  73. l: logrus.New(),
  74. }
  75. thi := c.GetHostInfoByVpnAddr(vpnIp, false)
  76. expectedInfo := ControlHostInfo{
  77. VpnAddrs: []netip.Addr{vpnIp},
  78. LocalIndex: 201,
  79. RemoteIndex: 200,
  80. RemoteAddrs: []netip.AddrPort{remote2, remote1},
  81. Cert: crt.Copy(),
  82. MessageCounter: 0,
  83. CurrentRemote: remote1,
  84. CurrentRelaysToMe: []netip.Addr{},
  85. CurrentRelaysThroughMe: []netip.Addr{},
  86. }
  87. // Make sure we don't have any unexpected fields
  88. assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
  89. assert.Equal(t, &expectedInfo, thi)
  90. test.AssertDeepCopyEqual(t, &expectedInfo, thi)
  91. // Make sure we don't panic if the host info doesn't have a cert yet
  92. assert.NotPanics(t, func() {
  93. thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
  94. })
  95. }
  96. func assertFields(t *testing.T, expected []string, actualStruct any) {
  97. val := reflect.ValueOf(actualStruct).Elem()
  98. fields := make([]string, val.NumField())
  99. for i := 0; i < val.NumField(); i++ {
  100. fields[i] = val.Type().Field(i).Name
  101. }
  102. assert.Equal(t, expected, fields)
  103. }