connection_manager.go

package nebula

import (
	"bytes"
	"context"
	"sync"
	"time"

	"github.com/rcrowley/go-metrics"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/header"
	"github.com/slackhq/nebula/iputil"
	"github.com/slackhq/nebula/udp"
)

type trafficDecision int

const (
	doNothing     trafficDecision = 0
	deleteTunnel  trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
	closeTunnel   trafficDecision = 2 // delete the hostinfo and notify the remote
	swapPrimary   trafficDecision = 3
	migrateRelays trafficDecision = 4
)

type connectionManager struct {
	in     map[uint32]struct{}
	inLock *sync.RWMutex

	out     map[uint32]struct{}
	outLock *sync.RWMutex

	// relayUsed holds which relay localIndexes are in use
	relayUsed     map[uint32]struct{}
	relayUsedLock *sync.RWMutex

	hostMap                 *HostMap
	trafficTimer            *LockingTimerWheel[uint32]
	intf                    *Interface
	pendingDeletion         map[uint32]struct{}
	punchy                  *Punchy
	checkInterval           time.Duration
	pendingDeletionInterval time.Duration
	metricsTxPunchy         metrics.Counter

	l *logrus.Logger
}

func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
	var max time.Duration
	if checkInterval < pendingDeletionInterval {
		max = pendingDeletionInterval
	} else {
		max = checkInterval
	}

	nc := &connectionManager{
		hostMap:                 intf.hostMap,
		in:                      make(map[uint32]struct{}),
		inLock:                  &sync.RWMutex{},
		out:                     make(map[uint32]struct{}),
		outLock:                 &sync.RWMutex{},
		relayUsed:               make(map[uint32]struct{}),
		relayUsedLock:           &sync.RWMutex{},
		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
		intf:                    intf,
		pendingDeletion:         make(map[uint32]struct{}),
		checkInterval:           checkInterval,
		pendingDeletionInterval: pendingDeletionInterval,
		punchy:                  punchy,
		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
		l:                       l,
	}

	nc.Start(ctx)
	return nc
}

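// In records that inbound traffic was seen for the given local index since the last traffic check.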
func (n *connectionManager) In(localIndex uint32) {
	n.inLock.RLock()
	// If this already exists, return
	if _, ok := n.in[localIndex]; ok {
		n.inLock.RUnlock()
		return
	}
	n.inLock.RUnlock()

	n.inLock.Lock()
	n.in[localIndex] = struct{}{}
	n.inLock.Unlock()
}

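// Out records that outbound traffic was seen for the given local index since the last traffic check.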
func (n *connectionManager) Out(localIndex uint32) {
	n.outLock.RLock()
	// If this already exists, return
	if _, ok := n.out[localIndex]; ok {
		n.outLock.RUnlock()
		return
	}
	n.outLock.RUnlock()

	n.outLock.Lock()
	n.out[localIndex] = struct{}{}
	n.outLock.Unlock()
}

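// RelayUsed records that the relay with the given local index has carried traffic,
// so it can be migrated if its underlying tunnel is replaced.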
func (n *connectionManager) RelayUsed(localIndex uint32) {
	n.relayUsedLock.RLock()
	// If this already exists, return
	if _, ok := n.relayUsed[localIndex]; ok {
		n.relayUsedLock.RUnlock()
		return
	}
	n.relayUsedLock.RUnlock()

	n.relayUsedLock.Lock()
	n.relayUsed[localIndex] = struct{}{}
	n.relayUsedLock.Unlock()
}

// getAndResetTrafficCheck reports whether there was any inbound or outbound traffic
// within the last tick and resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
	n.inLock.Lock()
	n.outLock.Lock()
	_, in := n.in[localIndex]
	_, out := n.out[localIndex]
	delete(n.in, localIndex)
	delete(n.out, localIndex)
	n.inLock.Unlock()
	n.outLock.Unlock()
	return in, out
}

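// AddTrafficWatch starts tracking the given local index, marking it as having
// outbound traffic and scheduling its first traffic check.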
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
	n.outLock.Lock()
	if _, ok := n.out[localIndex]; ok {
		n.outLock.Unlock()
		return
	}
	n.out[localIndex] = struct{}{}
	n.trafficTimer.Add(localIndex, n.checkInterval)
	n.outLock.Unlock()
}

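// Start runs the connection manager's main loop in a new goroutine.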
func (n *connectionManager) Start(ctx context.Context) {
	go n.Run(ctx)
}

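// Run advances the traffic timer wheel on a 500ms tick and performs a traffic check
// for every local index whose timer fires, until ctx is cancelled.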
func (n *connectionManager) Run(ctx context.Context) {
	//TODO: this tick should be based on the min wheel tick? Check firewall
	clockSource := time.NewTicker(500 * time.Millisecond)
	defer clockSource.Stop()

	p := []byte("")
	nb := make([]byte, 12)
	out := make([]byte, mtu)

	for {
		select {
		case <-ctx.Done():
			return

		case now := <-clockSource.C:
			n.trafficTimer.Advance(now)
			for {
				localIndex, has := n.trafficTimer.Purge()
				if !has {
					break
				}

				n.doTrafficCheck(localIndex, p, nb, out, now)
			}
		}
	}
}

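// doTrafficCheck evaluates a single tunnel and acts on the resulting decision:
// deleting it, closing it, promoting it to primary, or migrating its relays.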
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)

	switch decision {
	case deleteTunnel:
		if n.hostMap.DeleteHostInfo(hostinfo) {
			// Only clear the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
		}

	case closeTunnel:
		n.intf.sendCloseTunnel(hostinfo)
		n.intf.closeTunnel(hostinfo)

	case swapPrimary:
		n.swapPrimary(hostinfo, primary)

	case migrateRelays:
		n.migrateRelayUsed(hostinfo, primary)
	}

	n.resetRelayTrafficCheck(hostinfo)
}

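// resetRelayTrafficCheck clears the usage records of any relays carried by this hostinfo.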
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
	if hostinfo != nil {
		n.relayUsedLock.Lock()
		defer n.relayUsedLock.Unlock()
		// No need to migrate any relays, delete usage info now.
		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
			delete(n.relayUsed, idx)
		}
	}
}

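// migrateRelayUsed re-establishes the relays in use on the old tunnel over the new one,
// creating or re-requesting the matching relay state and sending a CreateRelayRequest to the peer.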
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
	relayFor := oldhostinfo.relayState.CopyAllRelayFor()

	for _, r := range relayFor {
		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)

		var index uint32
		var relayFrom iputil.VpnIp
		var relayTo iputil.VpnIp
		switch {
		case ok && existing.State == Established:
			// This relay already exists in newhostinfo; do nothing.
			continue
		case ok && existing.State == Requested:
			// The relay exists in a Requested state; re-send the request
			index = existing.LocalIndex
			switch r.Type {
			case TerminalType:
				relayFrom = newhostinfo.vpnIp
				relayTo = existing.PeerIp
			case ForwardingType:
				relayFrom = existing.PeerIp
				relayTo = newhostinfo.vpnIp
			default:
				// should never happen
			}
		case !ok:
			n.relayUsedLock.RLock()
			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
				// The relay hasn't been used; don't migrate it.
				n.relayUsedLock.RUnlock()
				continue
			}
			n.relayUsedLock.RUnlock()
			// The relay doesn't exist at all; create some relay state and send the request.
			var err error
			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
			if err != nil {
				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
				continue
			}
			switch r.Type {
			case TerminalType:
				relayFrom = newhostinfo.vpnIp
				relayTo = r.PeerIp
			case ForwardingType:
				relayFrom = r.PeerIp
				relayTo = newhostinfo.vpnIp
			default:
				// should never happen
			}
		}

		// Send a CreateRelayRequest to the peer.
		req := NebulaControl{
			Type:                NebulaControl_CreateRelayRequest,
			InitiatorRelayIndex: index,
			RelayFromIp:         uint32(relayFrom),
			RelayToIp:           uint32(relayTo),
		}
		msg, err := req.Marshal()
		if err != nil {
			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
		} else {
			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
			n.l.WithFields(logrus.Fields{
				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
				"relayTo":             iputil.VpnIp(req.RelayToIp),
				"initiatorRelayIndex": req.InitiatorRelayIndex,
				"responderRelayIndex": req.ResponderRelayIndex,
				"vpnIp":               newhostinfo.vpnIp}).
				Info("send CreateRelayRequest")
		}
	}
}

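// makeTrafficDecision inspects the traffic seen for a tunnel since its last check and decides its fate:
// inbound traffic keeps it alive, outbound-only traffic triggers an authenticated test packet, and
// continued silence after a test marks it dead. It returns the decision along with the hostinfo it
// applies to and the current primary hostinfo for that vpn ip.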
func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
	n.hostMap.RLock()
	defer n.hostMap.RUnlock()

	hostinfo := n.hostMap.Indexes[localIndex]
	if hostinfo == nil {
		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
		delete(n.pendingDeletion, localIndex)
		return doNothing, nil, nil
	}

	if n.isInvalidCertificate(now, hostinfo) {
		delete(n.pendingDeletion, hostinfo.localIndexId)
		return closeTunnel, hostinfo, nil
	}

	primary := n.hostMap.Hosts[hostinfo.vpnIp]
	mainHostInfo := true
	if primary != nil && primary != hostinfo {
		mainHostInfo = false
	}

	// Check for traffic on this hostinfo
	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

	// A hostinfo is determined alive if there is incoming traffic
	if inTraffic {
		decision := doNothing
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
				Debug("Tunnel status")
		}
		delete(n.pendingDeletion, hostinfo.localIndexId)

		if mainHostInfo {
			n.tryRehandshake(hostinfo)
		} else {
			if n.shouldSwapPrimary(hostinfo, primary) {
				decision = swapPrimary
			} else {
				// migrate the relays to the primary, if in use.
				decision = migrateRelays
			}
		}

		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

		if !outTraffic {
			// Send a punch packet to keep the NAT state alive
			n.sendPunch(hostinfo)
		}

		return decision, hostinfo, primary
	}

	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
		// We have already sent a test packet and nothing was returned, this hostinfo is dead
		hostinfo.logger(n.l).
			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
			Info("Tunnel status")

		delete(n.pendingDeletion, hostinfo.localIndexId)
		return deleteTunnel, hostinfo, nil
	}

	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
		if !outTraffic {
			// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
			// Just maintain NAT state if configured to do so.
			n.sendPunch(hostinfo)
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			return doNothing, nil, nil
		}

		if n.punchy.GetTargetEverything() {
			// This is similar to the old punchy behavior with a slight optimization.
			// We aren't receiving traffic but we are sending it, punch on all known
			// ips in case we need to re-prime NAT state
			n.sendPunch(hostinfo)
		}

		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
			// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			return doNothing, nil, nil
		}

		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
				Debug("Tunnel status")
		}

		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
	} else {
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
		}
	}

	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
	return doNothing, nil, nil
}

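// shouldSwapPrimary reports whether a non-primary tunnel that just saw traffic should be promoted;
// only the side with the higher vpn ip flips, and only if the tunnel was built with our current certificate.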
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
	// Let's sort this out.
	if current.vpnIp < n.intf.myVpnIp {
		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
		// vpn ip is static across all tunnels for this host pair so let's use that to determine who is flipping.
		// The remote's vpn ip is lower than mine. I will not flip.
		return false
	}

	certState := n.intf.certState.Load()
	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
}

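// swapPrimary makes current the primary hostinfo for its vpn ip.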
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
	n.hostMap.Lock()
	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
	if n.hostMap.Hosts[current.vpnIp] == primary {
		n.hostMap.unlockedMakePrimary(current)
	}
	n.hostMap.Unlock()
}

// isInvalidCertificate reports whether this tunnel should be torn down because
// pki.disconnect_invalid is enabled and the remote certificate is no longer valid
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
	if !n.intf.disconnectInvalid {
		return false
	}

	remoteCert := hostinfo.GetCert()
	if remoteCert == nil {
		return false
	}

	valid, err := remoteCert.Verify(now, n.intf.caPool)
	if valid {
		return false
	}

	fingerprint, _ := remoteCert.Sha256Sum()
	hostinfo.logger(n.l).WithError(err).
		WithField("fingerprint", fingerprint).
		Info("Remote certificate is no longer valid, tearing down the tunnel")

	return true
}

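// sendPunch sends a single-byte punch packet to maintain NAT state, to all known addresses
// or just the current remote depending on configuration; it is a no-op when punching is disabled.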
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
	if !n.punchy.GetPunch() {
		// Punching is disabled
		return
	}

	if n.punchy.GetTargetEverything() {
		hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
			n.metricsTxPunchy.Inc(1)
			n.intf.outside.WriteTo([]byte{1}, addr)
		})
	} else if hostinfo.remote != nil {
		n.metricsTxPunchy.Inc(1)
		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
	}
}

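// tryRehandshake starts a new handshake with the remote if the local certificate
// has changed since this tunnel was established.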
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
	certState := n.intf.certState.Load()
	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
		return
	}

	n.l.WithField("vpnIp", hostinfo.vpnIp).
		WithField("reason", "local certificate is not current").
		Info("Re-handshaking with remote")

	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
	if !newHostinfo.HandshakeReady {
		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
	}

	//If this is a static host, we don't need to wait for the HostQueryReply
	//We can trigger the handshake right now
	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
		select {
		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
		default:
		}
	}
}