2
0

control_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package nebula
  2. import (
  3. "net"
  4. "reflect"
  5. "testing"
  6. "time"
  7. "github.com/sirupsen/logrus"
  8. "github.com/slackhq/nebula/cert"
  9. "github.com/slackhq/nebula/iputil"
  10. "github.com/slackhq/nebula/test"
  11. "github.com/slackhq/nebula/udp"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestControl_GetHostInfoByVpnIp(t *testing.T) {
  15. l := test.NewLogger()
  16. // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
  17. // To properly ensure we are not exposing core memory to the caller
  18. hm := newHostMap(l, &net.IPNet{})
  19. hm.preferredRanges.Store(&[]*net.IPNet{})
  20. remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
  21. remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
  22. ipNet := net.IPNet{
  23. IP: net.IPv4(1, 2, 3, 4),
  24. Mask: net.IPMask{255, 255, 255, 0},
  25. }
  26. ipNet2 := net.IPNet{
  27. IP: net.ParseIP("1:2:3:4:5:6:7:8"),
  28. Mask: net.IPMask{255, 255, 255, 0},
  29. }
  30. crt := &cert.NebulaCertificate{
  31. Details: cert.NebulaCertificateDetails{
  32. Name: "test",
  33. Ips: []*net.IPNet{&ipNet},
  34. Subnets: []*net.IPNet{},
  35. Groups: []string{"default-group"},
  36. NotBefore: time.Unix(1, 0),
  37. NotAfter: time.Unix(2, 0),
  38. PublicKey: []byte{5, 6, 7, 8},
  39. IsCA: false,
  40. Issuer: "the-issuer",
  41. InvertedGroups: map[string]struct{}{"default-group": {}},
  42. },
  43. Signature: []byte{1, 2, 1, 2, 1, 3},
  44. }
  45. remotes := NewRemoteList(nil)
  46. remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
  47. remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
  48. hm.unlockedAddHostInfo(&HostInfo{
  49. remote: remote1,
  50. remotes: remotes,
  51. ConnectionState: &ConnectionState{
  52. peerCert: crt,
  53. },
  54. remoteIndexId: 200,
  55. localIndexId: 201,
  56. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  57. relayState: RelayState{
  58. relays: map[iputil.VpnIp]struct{}{},
  59. relayForByIp: map[iputil.VpnIp]*Relay{},
  60. relayForByIdx: map[uint32]*Relay{},
  61. },
  62. }, &Interface{})
  63. hm.unlockedAddHostInfo(&HostInfo{
  64. remote: remote1,
  65. remotes: remotes,
  66. ConnectionState: &ConnectionState{
  67. peerCert: nil,
  68. },
  69. remoteIndexId: 200,
  70. localIndexId: 201,
  71. vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
  72. relayState: RelayState{
  73. relays: map[iputil.VpnIp]struct{}{},
  74. relayForByIp: map[iputil.VpnIp]*Relay{},
  75. relayForByIdx: map[uint32]*Relay{},
  76. },
  77. }, &Interface{})
  78. c := Control{
  79. f: &Interface{
  80. hostMap: hm,
  81. },
  82. l: logrus.New(),
  83. }
  84. thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
  85. expectedInfo := ControlHostInfo{
  86. VpnIp: net.IPv4(1, 2, 3, 4).To4(),
  87. LocalIndex: 201,
  88. RemoteIndex: 200,
  89. RemoteAddrs: []*udp.Addr{remote2, remote1},
  90. Cert: crt.Copy(),
  91. MessageCounter: 0,
  92. CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
  93. CurrentRelaysToMe: []iputil.VpnIp{},
  94. CurrentRelaysThroughMe: []iputil.VpnIp{},
  95. }
  96. // Make sure we don't have any unexpected fields
  97. assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
  98. test.AssertDeepCopyEqual(t, &expectedInfo, thi)
  99. // Make sure we don't panic if the host info doesn't have a cert yet
  100. assert.NotPanics(t, func() {
  101. thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
  102. })
  103. }
  104. func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
  105. val := reflect.ValueOf(actualStruct).Elem()
  106. fields := make([]string, val.NumField())
  107. for i := 0; i < val.NumField(); i++ {
  108. fields[i] = val.Type().Field(i).Name
  109. }
  110. assert.Equal(t, expected, fields)
  111. }