control_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package nebula
  2. import (
  3. "net"
  4. "reflect"
  5. "testing"
  6. "time"
  7. "github.com/sirupsen/logrus"
  8. "github.com/slackhq/nebula/cert"
  9. "github.com/slackhq/nebula/util"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestControl_GetHostInfoByVpnIP(t *testing.T) {
  13. l := NewTestLogger()
  14. // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
  15. // To properly ensure we are not exposing core memory to the caller
  16. hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
  17. remote1 := NewUDPAddr(int2ip(100), 4444)
  18. remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
  19. ipNet := net.IPNet{
  20. IP: net.IPv4(1, 2, 3, 4),
  21. Mask: net.IPMask{255, 255, 255, 0},
  22. }
  23. ipNet2 := net.IPNet{
  24. IP: net.ParseIP("1:2:3:4:5:6:7:8"),
  25. Mask: net.IPMask{255, 255, 255, 0},
  26. }
  27. crt := &cert.NebulaCertificate{
  28. Details: cert.NebulaCertificateDetails{
  29. Name: "test",
  30. Ips: []*net.IPNet{&ipNet},
  31. Subnets: []*net.IPNet{},
  32. Groups: []string{"default-group"},
  33. NotBefore: time.Unix(1, 0),
  34. NotAfter: time.Unix(2, 0),
  35. PublicKey: []byte{5, 6, 7, 8},
  36. IsCA: false,
  37. Issuer: "the-issuer",
  38. InvertedGroups: map[string]struct{}{"default-group": {}},
  39. },
  40. Signature: []byte{1, 2, 1, 2, 1, 3},
  41. }
  42. remotes := NewRemoteList()
  43. remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
  44. remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
  45. hm.Add(ip2int(ipNet.IP), &HostInfo{
  46. remote: remote1,
  47. remotes: remotes,
  48. ConnectionState: &ConnectionState{
  49. peerCert: crt,
  50. },
  51. remoteIndexId: 200,
  52. localIndexId: 201,
  53. hostId: ip2int(ipNet.IP),
  54. })
  55. hm.Add(ip2int(ipNet2.IP), &HostInfo{
  56. remote: remote1,
  57. remotes: remotes,
  58. ConnectionState: &ConnectionState{
  59. peerCert: nil,
  60. },
  61. remoteIndexId: 200,
  62. localIndexId: 201,
  63. hostId: ip2int(ipNet2.IP),
  64. })
  65. c := Control{
  66. f: &Interface{
  67. hostMap: hm,
  68. },
  69. l: logrus.New(),
  70. }
  71. thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
  72. expectedInfo := ControlHostInfo{
  73. VpnIP: net.IPv4(1, 2, 3, 4).To4(),
  74. LocalIndex: 201,
  75. RemoteIndex: 200,
  76. RemoteAddrs: []*udpAddr{remote2, remote1},
  77. CachedPackets: 0,
  78. Cert: crt.Copy(),
  79. MessageCounter: 0,
  80. CurrentRemote: NewUDPAddr(int2ip(100), 4444),
  81. }
  82. // Make sure we don't have any unexpected fields
  83. assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
  84. util.AssertDeepCopyEqual(t, &expectedInfo, thi)
  85. // Make sure we don't panic if the host info doesn't have a cert yet
  86. assert.NotPanics(t, func() {
  87. thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
  88. })
  89. }
  90. func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
  91. val := reflect.ValueOf(actualStruct).Elem()
  92. fields := make([]string, val.NumField())
  93. for i := 0; i < val.NumField(); i++ {
  94. fields[i] = val.Type().Field(i).Name
  95. }
  96. assert.Equal(t, expected, fields)
  97. }