connection_manager.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. package nebula
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"net/netip"
	"sync"
	"time"

	"github.com/rcrowley/go-metrics"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/header"
)
  14. type trafficDecision int
  15. const (
  16. doNothing trafficDecision = 0
  17. deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
  18. closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote
  19. swapPrimary trafficDecision = 3
  20. migrateRelays trafficDecision = 4
  21. tryRehandshake trafficDecision = 5
  22. sendTestPacket trafficDecision = 6
  23. )
// connectionManager tracks per-tunnel traffic, keyed by the tunnel's local index, and
// periodically decides whether a tunnel should be kept, tested, re-handshaked or torn down.
type connectionManager struct {
	// in is the set of local indexes that have been marked via In since their last traffic check.
	in map[uint32]struct{}
	// inLock guards in.
	inLock *sync.RWMutex

	// out is the set of local indexes that have been marked via Out (or AddTrafficWatch)
	// since their last traffic check.
	out map[uint32]struct{}
	// outLock guards out.
	outLock *sync.RWMutex

	// relayUsed holds which relay local indexes are in use
	relayUsed map[uint32]struct{}
	// relayUsedLock guards relayUsed.
	relayUsedLock *sync.RWMutex

	// hostMap is the interface's host map; hostinfos are looked up in it by local index and vpn ip.
	hostMap *HostMap
	// trafficTimer schedules the next traffic check for each watched local index.
	trafficTimer *LockingTimerWheel[uint32]
	intf *Interface
	// pendingDeletion holds local indexes that are scheduled for deletion on their next check
	// unless inbound traffic is seen first.
	// NOTE(review): no dedicated lock is visible here; it appears to be touched only from
	// makeTrafficDecision — confirm it stays confined to the Run goroutine.
	pendingDeletion map[uint32]struct{}
	// punchy supplies the punch settings consulted by sendPunch and makeTrafficDecision.
	punchy *Punchy
	// checkInterval is the delay between regular traffic checks of a tunnel.
	checkInterval time.Duration
	// pendingDeletionInterval is the delay before a tunnel placed in pendingDeletion is checked again.
	pendingDeletionInterval time.Duration
	// metricsTxPunchy counts punch packets sent (registered as "messages.tx.punchy").
	metricsTxPunchy metrics.Counter
	l *logrus.Logger
}
  42. func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
  43. var max time.Duration
  44. if checkInterval < pendingDeletionInterval {
  45. max = pendingDeletionInterval
  46. } else {
  47. max = checkInterval
  48. }
  49. nc := &connectionManager{
  50. hostMap: intf.hostMap,
  51. in: make(map[uint32]struct{}),
  52. inLock: &sync.RWMutex{},
  53. out: make(map[uint32]struct{}),
  54. outLock: &sync.RWMutex{},
  55. relayUsed: make(map[uint32]struct{}),
  56. relayUsedLock: &sync.RWMutex{},
  57. trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
  58. intf: intf,
  59. pendingDeletion: make(map[uint32]struct{}),
  60. checkInterval: checkInterval,
  61. pendingDeletionInterval: pendingDeletionInterval,
  62. punchy: punchy,
  63. metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
  64. l: l,
  65. }
  66. nc.Start(ctx)
  67. return nc
  68. }
  69. func (n *connectionManager) In(localIndex uint32) {
  70. n.inLock.RLock()
  71. // If this already exists, return
  72. if _, ok := n.in[localIndex]; ok {
  73. n.inLock.RUnlock()
  74. return
  75. }
  76. n.inLock.RUnlock()
  77. n.inLock.Lock()
  78. n.in[localIndex] = struct{}{}
  79. n.inLock.Unlock()
  80. }
  81. func (n *connectionManager) Out(localIndex uint32) {
  82. n.outLock.RLock()
  83. // If this already exists, return
  84. if _, ok := n.out[localIndex]; ok {
  85. n.outLock.RUnlock()
  86. return
  87. }
  88. n.outLock.RUnlock()
  89. n.outLock.Lock()
  90. n.out[localIndex] = struct{}{}
  91. n.outLock.Unlock()
  92. }
  93. func (n *connectionManager) RelayUsed(localIndex uint32) {
  94. n.relayUsedLock.RLock()
  95. // If this already exists, return
  96. if _, ok := n.relayUsed[localIndex]; ok {
  97. n.relayUsedLock.RUnlock()
  98. return
  99. }
  100. n.relayUsedLock.RUnlock()
  101. n.relayUsedLock.Lock()
  102. n.relayUsed[localIndex] = struct{}{}
  103. n.relayUsedLock.Unlock()
  104. }
  105. // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
  106. // resets the state for this local index
  107. func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
  108. n.inLock.Lock()
  109. n.outLock.Lock()
  110. _, in := n.in[localIndex]
  111. _, out := n.out[localIndex]
  112. delete(n.in, localIndex)
  113. delete(n.out, localIndex)
  114. n.inLock.Unlock()
  115. n.outLock.Unlock()
  116. return in, out
  117. }
  118. func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
  119. // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
  120. n.outLock.Lock()
  121. if _, ok := n.out[localIndex]; ok {
  122. n.outLock.Unlock()
  123. return
  124. }
  125. n.out[localIndex] = struct{}{}
  126. n.trafficTimer.Add(localIndex, n.checkInterval)
  127. n.outLock.Unlock()
  128. }
// Start launches Run on its own goroutine and returns immediately; Run exits once ctx is done.
func (n *connectionManager) Start(ctx context.Context) {
	go n.Run(ctx)
}
  132. func (n *connectionManager) Run(ctx context.Context) {
  133. //TODO: this tick should be based on the min wheel tick? Check firewall
  134. clockSource := time.NewTicker(500 * time.Millisecond)
  135. defer clockSource.Stop()
  136. p := []byte("")
  137. nb := make([]byte, 12, 12)
  138. out := make([]byte, mtu)
  139. for {
  140. select {
  141. case <-ctx.Done():
  142. return
  143. case now := <-clockSource.C:
  144. n.trafficTimer.Advance(now)
  145. for {
  146. localIndex, has := n.trafficTimer.Purge()
  147. if !has {
  148. break
  149. }
  150. n.doTrafficCheck(localIndex, p, nb, out, now)
  151. }
  152. }
  153. }
  154. }
  155. func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
  156. decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
  157. switch decision {
  158. case deleteTunnel:
  159. if n.hostMap.DeleteHostInfo(hostinfo) {
  160. // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
  161. n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
  162. }
  163. case closeTunnel:
  164. n.intf.sendCloseTunnel(hostinfo)
  165. n.intf.closeTunnel(hostinfo)
  166. case swapPrimary:
  167. n.swapPrimary(hostinfo, primary)
  168. case migrateRelays:
  169. n.migrateRelayUsed(hostinfo, primary)
  170. case tryRehandshake:
  171. n.tryRehandshake(hostinfo)
  172. case sendTestPacket:
  173. n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
  174. }
  175. n.resetRelayTrafficCheck(hostinfo)
  176. }
  177. func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
  178. if hostinfo != nil {
  179. n.relayUsedLock.Lock()
  180. defer n.relayUsedLock.Unlock()
  181. // No need to migrate any relays, delete usage info now.
  182. for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
  183. delete(n.relayUsed, idx)
  184. }
  185. }
  186. }
  187. func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
  188. relayFor := oldhostinfo.relayState.CopyAllRelayFor()
  189. for _, r := range relayFor {
  190. existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
  191. var index uint32
  192. var relayFrom netip.Addr
  193. var relayTo netip.Addr
  194. switch {
  195. case ok && existing.State == Established:
  196. // This relay already exists in newhostinfo, then do nothing.
  197. continue
  198. case ok && existing.State == Requested:
  199. // The relay exists in a Requested state; re-send the request
  200. index = existing.LocalIndex
  201. switch r.Type {
  202. case TerminalType:
  203. relayFrom = n.intf.myVpnNet.Addr()
  204. relayTo = existing.PeerIp
  205. case ForwardingType:
  206. relayFrom = existing.PeerIp
  207. relayTo = newhostinfo.vpnIp
  208. default:
  209. // should never happen
  210. }
  211. case !ok:
  212. n.relayUsedLock.RLock()
  213. if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
  214. // The relay hasn't been used; don't migrate it.
  215. n.relayUsedLock.RUnlock()
  216. continue
  217. }
  218. n.relayUsedLock.RUnlock()
  219. // The relay doesn't exist at all; create some relay state and send the request.
  220. var err error
  221. index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
  222. if err != nil {
  223. n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
  224. continue
  225. }
  226. switch r.Type {
  227. case TerminalType:
  228. relayFrom = n.intf.myVpnNet.Addr()
  229. relayTo = r.PeerIp
  230. case ForwardingType:
  231. relayFrom = r.PeerIp
  232. relayTo = newhostinfo.vpnIp
  233. default:
  234. // should never happen
  235. }
  236. }
  237. //TODO: IPV6-WORK
  238. relayFromB := relayFrom.As4()
  239. relayToB := relayTo.As4()
  240. // Send a CreateRelayRequest to the peer.
  241. req := NebulaControl{
  242. Type: NebulaControl_CreateRelayRequest,
  243. InitiatorRelayIndex: index,
  244. RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
  245. RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
  246. }
  247. msg, err := req.Marshal()
  248. if err != nil {
  249. n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
  250. } else {
  251. n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
  252. n.l.WithFields(logrus.Fields{
  253. "relayFrom": req.RelayFromIp,
  254. "relayTo": req.RelayToIp,
  255. "initiatorRelayIndex": req.InitiatorRelayIndex,
  256. "responderRelayIndex": req.ResponderRelayIndex,
  257. "vpnIp": newhostinfo.vpnIp}).
  258. Info("send CreateRelayRequest")
  259. }
  260. }
  261. }
// makeTrafficDecision inspects the tunnel identified by localIndex and decides what
// doTrafficCheck should do with it. It returns the decision, the hostinfo the decision applies
// to, and (only on the "alive" path) the hostinfo currently registered as primary for the same
// vpn ip. The host map read lock is held for the whole call.
//
// Besides deciding, it also re-arms the traffic timer for the tunnel, maintains
// pendingDeletion, and may send punch packets.
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
	n.hostMap.RLock()
	defer n.hostMap.RUnlock()

	hostinfo := n.hostMap.Indexes[localIndex]
	if hostinfo == nil {
		// The tunnel is gone; the timer is not re-armed so the index stops being watched.
		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
		delete(n.pendingDeletion, localIndex)
		return doNothing, nil, nil
	}

	if n.isInvalidCertificate(now, hostinfo) {
		delete(n.pendingDeletion, hostinfo.localIndexId)
		return closeTunnel, hostinfo, nil
	}

	// mainHostInfo is false only when another hostinfo is registered as primary for this vpn ip.
	primary := n.hostMap.Hosts[hostinfo.vpnIp]
	mainHostInfo := true
	if primary != nil && primary != hostinfo {
		mainHostInfo = false
	}

	// Check for traffic on this hostinfo
	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

	// A hostinfo is determined alive if there is incoming traffic
	if inTraffic {
		decision := doNothing
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
				Debug("Tunnel status")
		}
		delete(n.pendingDeletion, hostinfo.localIndexId)

		if mainHostInfo {
			decision = tryRehandshake

		} else {
			if n.shouldSwapPrimary(hostinfo, primary) {
				decision = swapPrimary
			} else {
				// migrate the relays to the primary, if in use.
				decision = migrateRelays
			}
		}

		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

		if !outTraffic {
			// Send a punch packet to keep the NAT state alive
			n.sendPunch(hostinfo)
		}

		return decision, hostinfo, primary
	}

	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
		// We have already sent a test packet and nothing was returned, this hostinfo is dead
		hostinfo.logger(n.l).
			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
			Info("Tunnel status")

		delete(n.pendingDeletion, hostinfo.localIndexId)
		return deleteTunnel, hostinfo, nil
	}

	decision := doNothing
	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
		if !outTraffic {
			// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
			// Just maintain NAT state if configured to do so.
			n.sendPunch(hostinfo)
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			// No hostinfo is returned here, so the caller does not reset relay usage for it.
			return doNothing, nil, nil

		}

		if n.punchy.GetTargetEverything() {
			// This is similar to the old punchy behavior with a slight optimization.
			// We aren't receiving traffic but we are sending it, punch on all known
			// ips in case we need to re-prime NAT state
			n.sendPunch(hostinfo)
		}

		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
				Debug("Tunnel status")
		}

		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
		decision = sendTestPacket

	} else {
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
		}
	}

	// Queue the tunnel for deletion; it is deleted at the next check unless inbound traffic shows up first.
	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
	return decision, hostinfo, nil
}
  347. func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
  348. // The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
  349. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
  350. // Let's sort this out.
  351. if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
  352. // Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
  353. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
  354. // The remotes vpn ip is lower than mine. I will not flip.
  355. return false
  356. }
  357. certState := n.intf.pki.GetCertState()
  358. return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
  359. }
  360. func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
  361. n.hostMap.Lock()
  362. // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
  363. if n.hostMap.Hosts[current.vpnIp] == primary {
  364. n.hostMap.unlockedMakePrimary(current)
  365. }
  366. n.hostMap.Unlock()
  367. }
  368. // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
  369. // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
  370. // check and return true.
  371. func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
  372. remoteCert := hostinfo.GetCert()
  373. if remoteCert == nil {
  374. return false
  375. }
  376. valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
  377. if valid {
  378. return false
  379. }
  380. if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
  381. // Block listed certificates should always be disconnected
  382. return false
  383. }
  384. fingerprint, _ := remoteCert.Sha256Sum()
  385. hostinfo.logger(n.l).WithError(err).
  386. WithField("fingerprint", fingerprint).
  387. Info("Remote certificate is no longer valid, tearing down the tunnel")
  388. return true
  389. }
  390. func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
  391. if !n.punchy.GetPunch() {
  392. // Punching is disabled
  393. return
  394. }
  395. if n.punchy.GetTargetEverything() {
  396. hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
  397. n.metricsTxPunchy.Inc(1)
  398. n.intf.outside.WriteTo([]byte{1}, addr)
  399. })
  400. } else if hostinfo.remote.IsValid() {
  401. n.metricsTxPunchy.Inc(1)
  402. n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
  403. }
  404. }
  405. func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
  406. certState := n.intf.pki.GetCertState()
  407. if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
  408. return
  409. }
  410. n.l.WithField("vpnIp", hostinfo.vpnIp).
  411. WithField("reason", "local certificate is not current").
  412. Info("Re-handshaking with remote")
  413. n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
  414. }