connection_manager_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. package nebula
  2. import (
  3. "context"
  4. "crypto/ed25519"
  5. "crypto/rand"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/flynn/noise"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/slackhq/nebula/config"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/slackhq/nebula/udp"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. )
  17. func newTestLighthouse() *LightHouse {
  18. lh := &LightHouse{
  19. l: test.NewLogger(),
  20. addrMap: map[netip.Addr]*RemoteList{},
  21. queryChan: make(chan netip.Addr, 10),
  22. }
  23. lighthouses := map[netip.Addr]struct{}{}
  24. staticList := map[netip.Addr]struct{}{}
  25. lh.lighthouses.Store(&lighthouses)
  26. lh.staticList.Store(&staticList)
  27. return lh
  28. }
  29. func Test_NewConnectionManagerTest(t *testing.T) {
  30. l := test.NewLogger()
  31. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  32. localrange := netip.MustParsePrefix("10.1.1.1/24")
  33. vpnIp := netip.MustParseAddr("172.1.1.2")
  34. preferredRanges := []netip.Prefix{localrange}
  35. // Very incomplete mock objects
  36. hostMap := newHostMap(l)
  37. hostMap.preferredRanges.Store(&preferredRanges)
  38. cs := &CertState{
  39. initiatingVersion: cert.Version1,
  40. privateKey: []byte{},
  41. v1Cert: &dummyCert{version: cert.Version1},
  42. v1HandshakeBytes: []byte{},
  43. }
  44. lh := newTestLighthouse()
  45. ifce := &Interface{
  46. hostMap: hostMap,
  47. inside: &test.NoopTun{},
  48. outside: &udp.NoopConn{},
  49. firewall: &Firewall{},
  50. lightHouse: lh,
  51. pki: &PKI{},
  52. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  53. l: l,
  54. }
  55. ifce.pki.cs.Store(cs)
  56. // Create manager
  57. ctx, cancel := context.WithCancel(context.Background())
  58. defer cancel()
  59. punchy := NewPunchyFromConfig(l, config.NewC(l))
  60. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  61. p := []byte("")
  62. nb := make([]byte, 12)
  63. out := make([]byte, mtu)
  64. // Add an ip we have established a connection w/ to hostmap
  65. hostinfo := &HostInfo{
  66. vpnAddrs: []netip.Addr{vpnIp},
  67. localIndexId: 1099,
  68. remoteIndexId: 9901,
  69. }
  70. hostinfo.ConnectionState = &ConnectionState{
  71. myCert: &dummyCert{version: cert.Version1},
  72. H: &noise.HandshakeState{},
  73. }
  74. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  75. // We saw traffic out to vpnIp
  76. nc.Out(hostinfo.localIndexId)
  77. nc.In(hostinfo.localIndexId)
  78. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  79. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  80. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  81. assert.Contains(t, nc.out, hostinfo.localIndexId)
  82. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  83. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  84. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  85. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  86. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  87. // Do another traffic check tick, this host should be pending deletion now
  88. nc.Out(hostinfo.localIndexId)
  89. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  90. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  91. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  92. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  93. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  94. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  95. // Do a final traffic check tick, the host should now be removed
  96. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  97. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  98. assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  99. assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  100. }
  101. func Test_NewConnectionManagerTest2(t *testing.T) {
  102. l := test.NewLogger()
  103. //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
  104. localrange := netip.MustParsePrefix("10.1.1.1/24")
  105. vpnIp := netip.MustParseAddr("172.1.1.2")
  106. preferredRanges := []netip.Prefix{localrange}
  107. // Very incomplete mock objects
  108. hostMap := newHostMap(l)
  109. hostMap.preferredRanges.Store(&preferredRanges)
  110. cs := &CertState{
  111. initiatingVersion: cert.Version1,
  112. privateKey: []byte{},
  113. v1Cert: &dummyCert{version: cert.Version1},
  114. v1HandshakeBytes: []byte{},
  115. }
  116. lh := newTestLighthouse()
  117. ifce := &Interface{
  118. hostMap: hostMap,
  119. inside: &test.NoopTun{},
  120. outside: &udp.NoopConn{},
  121. firewall: &Firewall{},
  122. lightHouse: lh,
  123. pki: &PKI{},
  124. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  125. l: l,
  126. }
  127. ifce.pki.cs.Store(cs)
  128. // Create manager
  129. ctx, cancel := context.WithCancel(context.Background())
  130. defer cancel()
  131. punchy := NewPunchyFromConfig(l, config.NewC(l))
  132. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  133. p := []byte("")
  134. nb := make([]byte, 12)
  135. out := make([]byte, mtu)
  136. // Add an ip we have established a connection w/ to hostmap
  137. hostinfo := &HostInfo{
  138. vpnAddrs: []netip.Addr{vpnIp},
  139. localIndexId: 1099,
  140. remoteIndexId: 9901,
  141. }
  142. hostinfo.ConnectionState = &ConnectionState{
  143. myCert: &dummyCert{version: cert.Version1},
  144. H: &noise.HandshakeState{},
  145. }
  146. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  147. // We saw traffic out to vpnIp
  148. nc.Out(hostinfo.localIndexId)
  149. nc.In(hostinfo.localIndexId)
  150. assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
  151. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  152. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  153. // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
  154. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  155. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  156. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  157. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  158. // Do another traffic check tick, this host should be pending deletion now
  159. nc.Out(hostinfo.localIndexId)
  160. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  161. assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
  162. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  163. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  164. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  165. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  166. // We saw traffic, should no longer be pending deletion
  167. nc.In(hostinfo.localIndexId)
  168. nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
  169. assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
  170. assert.NotContains(t, nc.out, hostinfo.localIndexId)
  171. assert.NotContains(t, nc.in, hostinfo.localIndexId)
  172. assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
  173. assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
  174. }
  175. // Check if we can disconnect the peer.
  176. // Validate if the peer's certificate is invalid (expired, etc.)
  177. // Disconnect only if disconnectInvalid: true is set.
  178. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
  179. now := time.Now()
  180. l := test.NewLogger()
  181. vpncidr := netip.MustParsePrefix("172.1.1.1/24")
  182. localrange := netip.MustParsePrefix("10.1.1.1/24")
  183. vpnIp := netip.MustParseAddr("172.1.1.2")
  184. preferredRanges := []netip.Prefix{localrange}
  185. hostMap := newHostMap(l)
  186. hostMap.preferredRanges.Store(&preferredRanges)
  187. // Generate keys for CA and peer's cert.
  188. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
  189. tbs := &cert.TBSCertificate{
  190. Version: 1,
  191. Name: "ca",
  192. IsCA: true,
  193. NotBefore: now,
  194. NotAfter: now.Add(1 * time.Hour),
  195. PublicKey: pubCA,
  196. }
  197. caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
  198. require.NoError(t, err)
  199. ncp := cert.NewCAPool()
  200. require.NoError(t, ncp.AddCA(caCert))
  201. pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
  202. tbs = &cert.TBSCertificate{
  203. Version: 1,
  204. Name: "host",
  205. Networks: []netip.Prefix{vpncidr},
  206. NotBefore: now,
  207. NotAfter: now.Add(60 * time.Second),
  208. PublicKey: pubCrt,
  209. }
  210. peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
  211. require.NoError(t, err)
  212. cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
  213. require.NoError(t, err)
  214. cs := &CertState{
  215. privateKey: []byte{},
  216. v1Cert: &dummyCert{},
  217. v1HandshakeBytes: []byte{},
  218. }
  219. lh := newTestLighthouse()
  220. ifce := &Interface{
  221. hostMap: hostMap,
  222. inside: &test.NoopTun{},
  223. outside: &udp.NoopConn{},
  224. firewall: &Firewall{},
  225. lightHouse: lh,
  226. handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
  227. l: l,
  228. pki: &PKI{},
  229. }
  230. ifce.pki.cs.Store(cs)
  231. ifce.pki.caPool.Store(ncp)
  232. ifce.disconnectInvalid.Store(true)
  233. // Create manager
  234. ctx, cancel := context.WithCancel(context.Background())
  235. defer cancel()
  236. punchy := NewPunchyFromConfig(l, config.NewC(l))
  237. nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
  238. ifce.connectionManager = nc
  239. hostinfo := &HostInfo{
  240. vpnAddrs: []netip.Addr{vpnIp},
  241. ConnectionState: &ConnectionState{
  242. myCert: &dummyCert{},
  243. peerCert: cachedPeerCert,
  244. H: &noise.HandshakeState{},
  245. },
  246. }
  247. nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
  248. // Move ahead 45s.
  249. // Check if to disconnect with invalid certificate.
  250. // Should be alive.
  251. nextTick := now.Add(45 * time.Second)
  252. invalid := nc.isInvalidCertificate(nextTick, hostinfo)
  253. assert.False(t, invalid)
  254. // Move ahead 61s.
  255. // Check if to disconnect with invalid certificate.
  256. // Should be disconnected.
  257. nextTick = now.Add(61 * time.Second)
  258. invalid = nc.isInvalidCertificate(nextTick, hostinfo)
  259. assert.True(t, invalid)
  260. }
  261. type dummyCert struct {
  262. version cert.Version
  263. curve cert.Curve
  264. groups []string
  265. isCa bool
  266. issuer string
  267. name string
  268. networks []netip.Prefix
  269. notAfter time.Time
  270. notBefore time.Time
  271. publicKey []byte
  272. signature []byte
  273. unsafeNetworks []netip.Prefix
  274. }
  275. func (d *dummyCert) Version() cert.Version {
  276. return d.version
  277. }
  278. func (d *dummyCert) Curve() cert.Curve {
  279. return d.curve
  280. }
  281. func (d *dummyCert) Groups() []string {
  282. return d.groups
  283. }
  284. func (d *dummyCert) IsCA() bool {
  285. return d.isCa
  286. }
  287. func (d *dummyCert) Issuer() string {
  288. return d.issuer
  289. }
  290. func (d *dummyCert) Name() string {
  291. return d.name
  292. }
  293. func (d *dummyCert) Networks() []netip.Prefix {
  294. return d.networks
  295. }
  296. func (d *dummyCert) NotAfter() time.Time {
  297. return d.notAfter
  298. }
  299. func (d *dummyCert) NotBefore() time.Time {
  300. return d.notBefore
  301. }
  302. func (d *dummyCert) PublicKey() []byte {
  303. return d.publicKey
  304. }
  305. func (d *dummyCert) Signature() []byte {
  306. return d.signature
  307. }
  308. func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
  309. return d.unsafeNetworks
  310. }
  311. func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
  312. return nil, nil
  313. }
  314. func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
  315. return nil
  316. }
  317. func (d *dummyCert) CheckSignature(key []byte) bool {
  318. return true
  319. }
  320. func (d *dummyCert) Expired(t time.Time) bool {
  321. return false
  322. }
  323. func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
  324. return nil
  325. }
  326. func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
  327. return nil
  328. }
  329. func (d *dummyCert) String() string {
  330. return ""
  331. }
  332. func (d *dummyCert) Marshal() ([]byte, error) {
  333. return nil, nil
  334. }
  335. func (d *dummyCert) MarshalPEM() ([]byte, error) {
  336. return nil, nil
  337. }
  338. func (d *dummyCert) Fingerprint() (string, error) {
  339. return "", nil
  340. }
  341. func (d *dummyCert) MarshalJSON() ([]byte, error) {
  342. return nil, nil
  343. }
  344. func (d *dummyCert) Copy() cert.Certificate {
  345. return d
  346. }