3
0

connection_manager_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. package nebula
  2. import (
  3. "context"
  4. "crypto/ed25519"
  5. "crypto/rand"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/flynn/noise"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/slackhq/nebula/udp"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. func newTestLighthouse() *LightHouse {
  17. lh := &LightHouse{
  18. l: test.NewLogger(),
  19. addrMap: map[netip.Addr]*RemoteList{},
  20. queryChan: make(chan netip.Addr, 10),
  21. }
  22. lighthouses := map[netip.Addr]struct{}{}
  23. staticList := map[netip.Addr]struct{}{}
  24. lh.lighthouses.Store(&lighthouses)
  25. lh.staticList.Store(&staticList)
  26. return lh
  27. }
  28. func Test_NewConnectionManagerTest(t *testing.T) {
  29. l := test.NewLogger()
  30. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  31. localrange := netip.MustParsePrefix("10.1.1.1/24")
  32. vpnIp := netip.MustParseAddr("172.1.1.2")
  33. preferredRanges := []netip.Prefix{localrange}
  34. // Very incomplete mock objects
  35. hostMap := newHostMap(l)
  36. hostMap.preferredRanges.Store(&preferredRanges)
  37. cs := &CertState{
  38. defaultVersion: cert.Version1,
  39. privateKey: []byte{},
  40. v1Cert: &dummyCert{version: cert.Version1},
  41. v1HandshakeBytes: []byte{},
  42. }
  43. lh := newTestLighthouse()
  44. ifce := &Interface{
  45. hostMap: hostMap,
  46. inside: &test.NoopTun{},
  47. outside: &udp.NoopConn{},
  48. firewall: &Firewall{},
  49. lightHouse: lh,
  50. pki: &PKI{},
  51. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  52. l: l,
  53. }
  54. ifce.pki.cs.Store(cs)
  55. // Create manager
  56. ctx, cancel := context.WithCancel(context.Background())
  57. defer cancel()
  58. punchy := NewPunchyFromConfig(l, config.NewC(l))
  59. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  60. p := []byte("")
  61. nb := make([]byte, 12, 12)
  62. out := make([]byte, mtu)
  63. // Add an ip we have established a connection w/ to hostmap
  64. hostinfo := &HostInfo{
  65. vpnAddrs: []netip.Addr{vpnIp},
  66. localIndexId: 1099,
  67. remoteIndexId: 9901,
  68. }
  69. hostinfo.ConnectionState = &ConnectionState{
  70. myCert: &dummyCert{version: cert.Version1},
  71. H: &noise.HandshakeState{},
  72. }
  73. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  74. // We saw traffic out to vpnIp
  75. nc.Out(hostinfo.localIndexId)
  76. nc.In(hostinfo.localIndexId)
  77. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  78. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  79. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  80. assert.Contains(t, nc.out, hostinfo.localIndexId)
  81. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  82. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  83. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  84. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  85. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  86. // Do another traffic check tick, this host should be pending deletion now
  87. nc.Out(hostinfo.localIndexId)
  88. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  89. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  90. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  91. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  92. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  93. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  94. // Do a final traffic check tick, the host should now be removed
  95. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  96. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  97. assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  98. assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  99. }
  100. func Test_NewConnectionManagerTest2(t *testing.T) {
  101. l := test.NewLogger()
  102. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  103. localrange := netip.MustParsePrefix("10.1.1.1/24")
  104. vpnIp := netip.MustParseAddr("172.1.1.2")
  105. preferredRanges := []netip.Prefix{localrange}
  106. // Very incomplete mock objects
  107. hostMap := newHostMap(l)
  108. hostMap.preferredRanges.Store(&preferredRanges)
  109. cs := &CertState{
  110. defaultVersion: cert.Version1,
  111. privateKey: []byte{},
  112. v1Cert: &dummyCert{version: cert.Version1},
  113. v1HandshakeBytes: []byte{},
  114. }
  115. lh := newTestLighthouse()
  116. ifce := &Interface{
  117. hostMap: hostMap,
  118. inside: &test.NoopTun{},
  119. outside: &udp.NoopConn{},
  120. firewall: &Firewall{},
  121. lightHouse: lh,
  122. pki: &PKI{},
  123. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  124. l: l,
  125. }
  126. ifce.pki.cs.Store(cs)
  127. // Create manager
  128. ctx, cancel := context.WithCancel(context.Background())
  129. defer cancel()
  130. punchy := NewPunchyFromConfig(l, config.NewC(l))
  131. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  132. p := []byte("")
  133. nb := make([]byte, 12, 12)
  134. out := make([]byte, mtu)
  135. // Add an ip we have established a connection w/ to hostmap
  136. hostinfo := &HostInfo{
  137. vpnAddrs: []netip.Addr{vpnIp},
  138. localIndexId: 1099,
  139. remoteIndexId: 9901,
  140. }
  141. hostinfo.ConnectionState = &ConnectionState{
  142. myCert: &dummyCert{version: cert.Version1},
  143. H: &noise.HandshakeState{},
  144. }
  145. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  146. // We saw traffic out to vpnIp
  147. nc.Out(hostinfo.localIndexId)
  148. nc.In(hostinfo.localIndexId)
  149. assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
  150. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  151. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  152. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  153. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  154. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  155. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  156. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  157. // Do another traffic check tick, this host should be pending deletion now
  158. nc.Out(hostinfo.localIndexId)
  159. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  160. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  161. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  162. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  163. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  164. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  165. // We saw traffic, should no longer be pending deletion
  166. nc.In(hostinfo.localIndexId)
  167. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  168. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  169. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  170. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  171. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  172. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  173. }
  174. // Check if we can disconnect the peer.
  175. // Validate if the peer's certificate is invalid (expired, etc.)
  176. // Disconnect only if disconnectInvalid: true is set.
  177. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
  178. now := time.Now()
  179. l := test.NewLogger()
  180. vpncidr := netip.MustParsePrefix("172.1.1.1/24")
  181. localrange := netip.MustParsePrefix("10.1.1.1/24")
  182. vpnIp := netip.MustParseAddr("172.1.1.2")
  183. preferredRanges := []netip.Prefix{localrange}
  184. hostMap := newHostMap(l)
  185. hostMap.preferredRanges.Store(&preferredRanges)
  186. // Generate keys for CA and peer's cert.
  187. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
  188. tbs := &cert.TBSCertificate{
  189. Version: 1,
  190. Name: "ca",
  191. IsCA: true,
  192. NotBefore: now,
  193. NotAfter: now.Add(1 * time.Hour),
  194. PublicKey: pubCA,
  195. }
  196. caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
  197. assert.NoError(t, err)
  198. ncp := cert.NewCAPool()
  199. assert.NoError(t, ncp.AddCA(caCert))
  200. pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
  201. tbs = &cert.TBSCertificate{
  202. Version: 1,
  203. Name: "host",
  204. Networks: []netip.Prefix{vpncidr},
  205. NotBefore: now,
  206. NotAfter: now.Add(60 * time.Second),
  207. PublicKey: pubCrt,
  208. }
  209. peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
  210. assert.NoError(t, err)
  211. cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
  212. cs := &CertState{
  213. privateKey: []byte{},
  214. v1Cert: &dummyCert{},
  215. v1HandshakeBytes: []byte{},
  216. }
  217. lh := newTestLighthouse()
  218. ifce := &Interface{
  219. hostMap: hostMap,
  220. inside: &test.NoopTun{},
  221. outside: &udp.NoopConn{},
  222. firewall: &Firewall{},
  223. lightHouse: lh,
  224. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  225. l: l,
  226. pki: &PKI{},
  227. }
  228. ifce.pki.cs.Store(cs)
  229. ifce.pki.caPool.Store(ncp)
  230. ifce.disconnectInvalid.Store(true)
  231. // Create manager
  232. ctx, cancel := context.WithCancel(context.Background())
  233. defer cancel()
  234. punchy := NewPunchyFromConfig(l, config.NewC(l))
  235. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  236. ifce.connectionManager = nc
  237. hostinfo := &HostInfo{
  238. vpnAddrs: []netip.Addr{vpnIp},
  239. ConnectionState: &ConnectionState{
  240. myCert: &dummyCert{},
  241. peerCert: cachedPeerCert,
  242. H: &noise.HandshakeState{},
  243. },
  244. }
  245. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  246. // Move ahead 45s.
  247. // Check if to disconnect with invalid certificate.
  248. // Should be alive.
  249. nextTick := now.Add(45 * time.Second)
  250. invalid := nc.isInvalidCertificate(nextTick, hostinfo)
  251. assert.False(t, invalid)
  252. // Move ahead 61s.
  253. // Check if to disconnect with invalid certificate.
  254. // Should be disconnected.
  255. nextTick = now.Add(61 * time.Second)
  256. invalid = nc.isInvalidCertificate(nextTick, hostinfo)
  257. assert.True(t, invalid)
  258. }
  259. type dummyCert struct {
  260. version cert.Version
  261. curve cert.Curve
  262. groups []string
  263. isCa bool
  264. issuer string
  265. name string
  266. networks []netip.Prefix
  267. notAfter time.Time
  268. notBefore time.Time
  269. publicKey []byte
  270. signature []byte
  271. unsafeNetworks []netip.Prefix
  272. }
  273. func (d *dummyCert) Version() cert.Version {
  274. return d.version
  275. }
  276. func (d *dummyCert) Curve() cert.Curve {
  277. return d.curve
  278. }
  279. func (d *dummyCert) Groups() []string {
  280. return d.groups
  281. }
  282. func (d *dummyCert) IsCA() bool {
  283. return d.isCa
  284. }
  285. func (d *dummyCert) Issuer() string {
  286. return d.issuer
  287. }
  288. func (d *dummyCert) Name() string {
  289. return d.name
  290. }
  291. func (d *dummyCert) Networks() []netip.Prefix {
  292. return d.networks
  293. }
  294. func (d *dummyCert) NotAfter() time.Time {
  295. return d.notAfter
  296. }
  297. func (d *dummyCert) NotBefore() time.Time {
  298. return d.notBefore
  299. }
  300. func (d *dummyCert) PublicKey() []byte {
  301. return d.publicKey
  302. }
  303. func (d *dummyCert) Signature() []byte {
  304. return d.signature
  305. }
  306. func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
  307. return d.unsafeNetworks
  308. }
  309. func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
  310. return nil, nil
  311. }
  312. func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
  313. return nil
  314. }
  315. func (d *dummyCert) CheckSignature(key []byte) bool {
  316. return true
  317. }
  318. func (d *dummyCert) Expired(t time.Time) bool {
  319. return false
  320. }
  321. func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
  322. return nil
  323. }
  324. func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
  325. return nil
  326. }
  327. func (d *dummyCert) String() string {
  328. return ""
  329. }
  330. func (d *dummyCert) Marshal() ([]byte, error) {
  331. return nil, nil
  332. }
  333. func (d *dummyCert) MarshalPEM() ([]byte, error) {
  334. return nil, nil
  335. }
  336. func (d *dummyCert) Fingerprint() (string, error) {
  337. return "", nil
  338. }
  339. func (d *dummyCert) MarshalJSON() ([]byte, error) {
  340. return nil, nil
  341. }
  342. func (d *dummyCert) Copy() cert.Certificate {
  343. return d
  344. }