control_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package nebula
  2. import (
  3. "net"
  4. "net/netip"
  5. "reflect"
  6. "testing"
  7. "time"
  8. "github.com/sirupsen/logrus"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/test"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestControl_GetHostInfoByVpnIp(t *testing.T) {
  14. l := test.NewLogger()
  15. // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
  16. // To properly ensure we are not exposing core memory to the caller
  17. hm := newHostMap(l, netip.Prefix{})
  18. hm.preferredRanges.Store(&[]netip.Prefix{})
  19. remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
  20. remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
  21. ipNet := net.IPNet{
  22. IP: remote1.Addr().AsSlice(),
  23. Mask: net.IPMask{255, 255, 255, 0},
  24. }
  25. ipNet2 := net.IPNet{
  26. IP: remote2.Addr().AsSlice(),
  27. Mask: net.IPMask{255, 255, 255, 0},
  28. }
  29. crt := &cert.NebulaCertificate{
  30. Details: cert.NebulaCertificateDetails{
  31. Name: "test",
  32. Ips: []*net.IPNet{&ipNet},
  33. Subnets: []*net.IPNet{},
  34. Groups: []string{"default-group"},
  35. NotBefore: time.Unix(1, 0),
  36. NotAfter: time.Unix(2, 0),
  37. PublicKey: []byte{5, 6, 7, 8},
  38. IsCA: false,
  39. Issuer: "the-issuer",
  40. InvertedGroups: map[string]struct{}{"default-group": {}},
  41. },
  42. Signature: []byte{1, 2, 1, 2, 1, 3},
  43. }
  44. remotes := NewRemoteList(nil)
  45. remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
  46. remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
  47. vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
  48. assert.True(t, ok)
  49. hm.unlockedAddHostInfo(&HostInfo{
  50. remote: remote1,
  51. remotes: remotes,
  52. ConnectionState: &ConnectionState{
  53. peerCert: crt,
  54. },
  55. remoteIndexId: 200,
  56. localIndexId: 201,
  57. vpnIp: vpnIp,
  58. relayState: RelayState{
  59. relays: map[netip.Addr]struct{}{},
  60. relayForByIp: map[netip.Addr]*Relay{},
  61. relayForByIdx: map[uint32]*Relay{},
  62. },
  63. }, &Interface{})
  64. vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
  65. assert.True(t, ok)
  66. hm.unlockedAddHostInfo(&HostInfo{
  67. remote: remote1,
  68. remotes: remotes,
  69. ConnectionState: &ConnectionState{
  70. peerCert: nil,
  71. },
  72. remoteIndexId: 200,
  73. localIndexId: 201,
  74. vpnIp: vpnIp2,
  75. relayState: RelayState{
  76. relays: map[netip.Addr]struct{}{},
  77. relayForByIp: map[netip.Addr]*Relay{},
  78. relayForByIdx: map[uint32]*Relay{},
  79. },
  80. }, &Interface{})
  81. c := Control{
  82. f: &Interface{
  83. hostMap: hm,
  84. },
  85. l: logrus.New(),
  86. }
  87. thi := c.GetHostInfoByVpnIp(vpnIp, false)
  88. expectedInfo := ControlHostInfo{
  89. VpnIp: vpnIp,
  90. LocalIndex: 201,
  91. RemoteIndex: 200,
  92. RemoteAddrs: []netip.AddrPort{remote2, remote1},
  93. Cert: crt.Copy(),
  94. MessageCounter: 0,
  95. CurrentRemote: remote1,
  96. CurrentRelaysToMe: []netip.Addr{},
  97. CurrentRelaysThroughMe: []netip.Addr{},
  98. }
  99. // Make sure we don't have any unexpected fields
  100. assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
  101. assert.EqualValues(t, &expectedInfo, thi)
  102. //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
  103. //test.AssertDeepCopyEqual(t, &expectedInfo, thi)
  104. // Make sure we don't panic if the host info doesn't have a cert yet
  105. assert.NotPanics(t, func() {
  106. thi = c.GetHostInfoByVpnIp(vpnIp2, false)
  107. })
  108. }
  109. func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
  110. val := reflect.ValueOf(actualStruct).Elem()
  111. fields := make([]string, val.NumField())
  112. for i := 0; i < val.NumField(); i++ {
  113. fields[i] = val.Type().Field(i).Name
  114. }
  115. assert.Equal(t, expected, fields)
  116. }