helpers_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. // +build e2e_testing
  2. package e2e
  3. import (
  4. "crypto/rand"
  5. "encoding/binary"
  6. "fmt"
  7. "io"
  8. "net"
  9. "testing"
  10. "time"
  11. "github.com/google/gopacket"
  12. "github.com/google/gopacket/layers"
  13. "github.com/sirupsen/logrus"
  14. "github.com/slackhq/nebula"
  15. "github.com/slackhq/nebula/cert"
  16. "github.com/stretchr/testify/assert"
  17. "golang.org/x/crypto/curve25519"
  18. "golang.org/x/crypto/ed25519"
  19. "gopkg.in/yaml.v2"
  20. )
// m is shorthand for a string-keyed map of arbitrary values. newSimpleServer
// uses it to build the nested config structure that is marshalled to YAML.
type m map[string]interface{}
  22. // newSimpleServer creates a nebula instance with many assumptions
  23. func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, listenAddr *net.UDPAddr, vpnIp *net.IPNet) *nebula.Control {
  24. l := logrus.New()
  25. _, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIp, nil, []string{})
  26. caB, err := caCrt.MarshalToPEM()
  27. if err != nil {
  28. panic(err)
  29. }
  30. mc := m{
  31. "pki": m{
  32. "ca": string(caB),
  33. "cert": string(myPEM),
  34. "key": string(myPrivKey),
  35. },
  36. //"tun": m{"disabled": true},
  37. "firewall": m{
  38. "outbound": []m{{
  39. "proto": "any",
  40. "port": "any",
  41. "host": "any",
  42. }},
  43. "inbound": []m{{
  44. "proto": "any",
  45. "port": "any",
  46. "host": "any",
  47. }},
  48. },
  49. "listen": m{
  50. "host": listenAddr.IP.String(),
  51. "port": listenAddr.Port,
  52. },
  53. "logging": m{
  54. "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
  55. "level": "info",
  56. },
  57. }
  58. cb, err := yaml.Marshal(mc)
  59. if err != nil {
  60. panic(err)
  61. }
  62. config := nebula.NewConfig(l)
  63. config.LoadString(string(cb))
  64. control, err := nebula.Main(config, false, "e2e-test", l, nil)
  65. if err != nil {
  66. panic(err)
  67. }
  68. return control
  69. }
  70. // newTestCaCert will generate a CA cert
  71. func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
  72. pub, priv, err := ed25519.GenerateKey(rand.Reader)
  73. if before.IsZero() {
  74. before = time.Now().Add(time.Second * -60).Round(time.Second)
  75. }
  76. if after.IsZero() {
  77. after = time.Now().Add(time.Second * 60).Round(time.Second)
  78. }
  79. nc := &cert.NebulaCertificate{
  80. Details: cert.NebulaCertificateDetails{
  81. Name: "test ca",
  82. NotBefore: time.Unix(before.Unix(), 0),
  83. NotAfter: time.Unix(after.Unix(), 0),
  84. PublicKey: pub,
  85. IsCA: true,
  86. InvertedGroups: make(map[string]struct{}),
  87. },
  88. }
  89. if len(ips) > 0 {
  90. nc.Details.Ips = ips
  91. }
  92. if len(subnets) > 0 {
  93. nc.Details.Subnets = subnets
  94. }
  95. if len(groups) > 0 {
  96. nc.Details.Groups = groups
  97. }
  98. err = nc.Sign(priv)
  99. if err != nil {
  100. panic(err)
  101. }
  102. pem, err := nc.MarshalToPEM()
  103. if err != nil {
  104. panic(err)
  105. }
  106. return nc, pub, priv, pem
  107. }
  108. // newTestCert will generate a signed certificate with the provided details.
  109. // Expiry times are defaulted if you do not pass them in
  110. func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
  111. issuer, err := ca.Sha256Sum()
  112. if err != nil {
  113. panic(err)
  114. }
  115. if before.IsZero() {
  116. before = time.Now().Add(time.Second * -60).Round(time.Second)
  117. }
  118. if after.IsZero() {
  119. after = time.Now().Add(time.Second * 60).Round(time.Second)
  120. }
  121. pub, rawPriv := x25519Keypair()
  122. nc := &cert.NebulaCertificate{
  123. Details: cert.NebulaCertificateDetails{
  124. Name: name,
  125. Ips: []*net.IPNet{ip},
  126. Subnets: subnets,
  127. Groups: groups,
  128. NotBefore: time.Unix(before.Unix(), 0),
  129. NotAfter: time.Unix(after.Unix(), 0),
  130. PublicKey: pub,
  131. IsCA: false,
  132. Issuer: issuer,
  133. InvertedGroups: make(map[string]struct{}),
  134. },
  135. }
  136. err = nc.Sign(key)
  137. if err != nil {
  138. panic(err)
  139. }
  140. pem, err := nc.MarshalToPEM()
  141. if err != nil {
  142. panic(err)
  143. }
  144. return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
  145. }
  146. func x25519Keypair() ([]byte, []byte) {
  147. var pubkey, privkey [32]byte
  148. if _, err := io.ReadFull(rand.Reader, privkey[:]); err != nil {
  149. panic(err)
  150. }
  151. curve25519.ScalarBaseMult(&pubkey, &privkey)
  152. return pubkey[:], privkey[:]
  153. }
  154. func ip2int(ip []byte) uint32 {
  155. if len(ip) == 16 {
  156. return binary.BigEndian.Uint32(ip[12:16])
  157. }
  158. return binary.BigEndian.Uint32(ip)
  159. }
  160. func int2ip(nn uint32) net.IP {
  161. ip := make(net.IP, 4)
  162. binary.BigEndian.PutUint32(ip, nn)
  163. return ip
  164. }
  165. func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
  166. // Get both host infos
  167. hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
  168. assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
  169. hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
  170. assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
  171. // Check that both vpn and real addr are correct
  172. assert.Equal(t, vpnIpB, hBinA.VpnIP, "HostA VpnIp is wrong in controlB")
  173. assert.Equal(t, vpnIpA, hAinB.VpnIP, "HostB VpnIp is wrong in controlA")
  174. assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "HostA remote ip is wrong in controlB")
  175. assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "HostB remote ip is wrong in controlA")
  176. assert.Equal(t, uint16(addrA.Port), hBinA.CurrentRemote.Port, "HostA remote ip is wrong in controlB")
  177. assert.Equal(t, uint16(addrB.Port), hAinB.CurrentRemote.Port, "HostB remote ip is wrong in controlA")
  178. // Check that our indexes match
  179. assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
  180. assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
  181. //TODO: Would be nice to assert this memory
  182. //checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
  183. // hBbyIndex := hmA.Indexes[hBinA.localIndexId]
  184. // assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
  185. // assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
  186. //
  187. // //TODO: remote indexes are susceptible to collision
  188. // hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
  189. // assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
  190. // assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
  191. //}
  192. //
  193. //// Check hostmap indexes too
  194. //checkIndexes("hmA", hmA, hBinA)
  195. //checkIndexes("hmB", hmB, hAinB)
  196. }
  197. func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
  198. packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
  199. v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
  200. assert.NotNil(t, v4, "No ipv4 data found")
  201. assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
  202. assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
  203. udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
  204. assert.NotNil(t, udp, "No udp data found")
  205. assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
  206. assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
  207. data := packet.ApplicationLayer()
  208. assert.NotNil(t, data)
  209. assert.Equal(t, expected, data.Payload(), "Data was incorrect")
  210. }