3
0

control_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package nebula
  2. import (
  3. "net"
  4. "reflect"
  5. "testing"
  6. "time"
  7. "github.com/sirupsen/logrus"
  8. "github.com/slackhq/nebula/cert"
  9. "github.com/slackhq/nebula/util"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestControl_GetHostInfoByVpnIP(t *testing.T) {
  13. // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
  14. // To properly ensure we are not exposing core memory to the caller
  15. hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
  16. remote1 := NewUDPAddr(100, 4444)
  17. remote2 := NewUDPAddr(101, 4444)
  18. ipNet := net.IPNet{
  19. IP: net.IPv4(1, 2, 3, 4),
  20. Mask: net.IPMask{255, 255, 255, 0},
  21. }
  22. ipNet2 := net.IPNet{
  23. IP: net.IPv4(1, 2, 3, 5),
  24. Mask: net.IPMask{255, 255, 255, 0},
  25. }
  26. crt := &cert.NebulaCertificate{
  27. Details: cert.NebulaCertificateDetails{
  28. Name: "test",
  29. Ips: []*net.IPNet{&ipNet},
  30. Subnets: []*net.IPNet{},
  31. Groups: []string{"default-group"},
  32. NotBefore: time.Unix(1, 0),
  33. NotAfter: time.Unix(2, 0),
  34. PublicKey: []byte{5, 6, 7, 8},
  35. IsCA: false,
  36. Issuer: "the-issuer",
  37. InvertedGroups: map[string]struct{}{"default-group": {}},
  38. },
  39. Signature: []byte{1, 2, 1, 2, 1, 3},
  40. }
  41. remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
  42. hm.Add(ip2int(ipNet.IP), &HostInfo{
  43. remote: remote1,
  44. Remotes: remotes,
  45. ConnectionState: &ConnectionState{
  46. peerCert: crt,
  47. },
  48. remoteIndexId: 200,
  49. localIndexId: 201,
  50. hostId: ip2int(ipNet.IP),
  51. })
  52. hm.Add(ip2int(ipNet2.IP), &HostInfo{
  53. remote: remote1,
  54. Remotes: remotes,
  55. ConnectionState: &ConnectionState{
  56. peerCert: nil,
  57. },
  58. remoteIndexId: 200,
  59. localIndexId: 201,
  60. hostId: ip2int(ipNet2.IP),
  61. })
  62. c := Control{
  63. f: &Interface{
  64. hostMap: hm,
  65. },
  66. l: logrus.New(),
  67. }
  68. thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
  69. expectedInfo := ControlHostInfo{
  70. VpnIP: net.IPv4(1, 2, 3, 4).To4(),
  71. LocalIndex: 201,
  72. RemoteIndex: 200,
  73. RemoteAddrs: []udpAddr{*remote1, *remote2},
  74. CachedPackets: 0,
  75. Cert: crt.Copy(),
  76. MessageCounter: 0,
  77. CurrentRemote: *NewUDPAddr(100, 4444),
  78. }
  79. // Make sure we don't have any unexpected fields
  80. assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
  81. util.AssertDeepCopyEqual(t, &expectedInfo, thi)
  82. // Make sure we don't panic if the host info doesn't have a cert yet
  83. assert.NotPanics(t, func() {
  84. thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
  85. })
  86. }
  87. func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
  88. val := reflect.ValueOf(actualStruct).Elem()
  89. fields := make([]string, val.NumField())
  90. for i := 0; i < val.NumField(); i++ {
  91. fields[i] = val.Type().Field(i).Name
  92. }
  93. assert.Equal(t, expected, fields)
  94. }