// connection_manager.go
  1. package nebula
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/rcrowley/go-metrics"
  7. "github.com/sirupsen/logrus"
  8. "github.com/slackhq/nebula/header"
  9. "github.com/slackhq/nebula/udp"
  10. )
// connectionManager watches established tunnels for liveness: it tracks
// inbound/outbound traffic per local index, sends NAT punch packets and test
// packets for quiet tunnels, and tears down tunnels that stop responding.
type connectionManager struct {
	// in holds local indexes that saw inbound traffic since the last check tick.
	in     map[uint32]struct{}
	inLock *sync.RWMutex

	// out holds local indexes that sent outbound traffic since the last check tick.
	out     map[uint32]struct{}
	outLock *sync.RWMutex

	hostMap      *HostMap
	trafficTimer *LockingTimerWheel[uint32]
	intf         *Interface

	// pendingDeletion marks indexes that were sent a test packet and will be
	// deleted if nothing comes back before their next timer fire. In this file
	// it is only touched from the Run goroutine, which is presumably why it
	// carries no lock — confirm no other file mutates it.
	pendingDeletion map[uint32]struct{}

	punchy *Punchy

	// checkInterval is the delay between liveness checks for an active tunnel;
	// pendingDeletionInterval is how long we wait for a test-packet reply
	// before declaring a tunnel dead.
	checkInterval           time.Duration
	pendingDeletionInterval time.Duration

	// metricsTxPunchy counts punch packets sent (messages.tx.punchy).
	metricsTxPunchy metrics.Counter

	l *logrus.Logger
}
  26. func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
  27. var max time.Duration
  28. if checkInterval < pendingDeletionInterval {
  29. max = pendingDeletionInterval
  30. } else {
  31. max = checkInterval
  32. }
  33. nc := &connectionManager{
  34. hostMap: intf.hostMap,
  35. in: make(map[uint32]struct{}),
  36. inLock: &sync.RWMutex{},
  37. out: make(map[uint32]struct{}),
  38. outLock: &sync.RWMutex{},
  39. trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
  40. intf: intf,
  41. pendingDeletion: make(map[uint32]struct{}),
  42. checkInterval: checkInterval,
  43. pendingDeletionInterval: pendingDeletionInterval,
  44. punchy: punchy,
  45. metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
  46. l: l,
  47. }
  48. nc.Start(ctx)
  49. return nc
  50. }
  51. func (n *connectionManager) In(localIndex uint32) {
  52. n.inLock.RLock()
  53. // If this already exists, return
  54. if _, ok := n.in[localIndex]; ok {
  55. n.inLock.RUnlock()
  56. return
  57. }
  58. n.inLock.RUnlock()
  59. n.inLock.Lock()
  60. n.in[localIndex] = struct{}{}
  61. n.inLock.Unlock()
  62. }
  63. func (n *connectionManager) Out(localIndex uint32) {
  64. n.outLock.RLock()
  65. // If this already exists, return
  66. if _, ok := n.out[localIndex]; ok {
  67. n.outLock.RUnlock()
  68. return
  69. }
  70. n.outLock.RUnlock()
  71. n.outLock.Lock()
  72. n.out[localIndex] = struct{}{}
  73. n.outLock.Unlock()
  74. }
  75. // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
  76. // resets the state for this local index
  77. func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
  78. n.inLock.Lock()
  79. n.outLock.Lock()
  80. _, in := n.in[localIndex]
  81. _, out := n.out[localIndex]
  82. delete(n.in, localIndex)
  83. delete(n.out, localIndex)
  84. n.inLock.Unlock()
  85. n.outLock.Unlock()
  86. return in, out
  87. }
// AddTrafficWatch marks localIndex as having outbound traffic and schedules
// its first liveness check on the timer wheel after checkInterval.
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
	n.Out(localIndex)
	n.trafficTimer.Add(localIndex, n.checkInterval)
}
// Start launches the connection manager's main loop (Run) in its own
// goroutine; the loop exits when ctx is canceled.
func (n *connectionManager) Start(ctx context.Context) {
	go n.Run(ctx)
}
  95. func (n *connectionManager) Run(ctx context.Context) {
  96. //TODO: this tick should be based on the min wheel tick? Check firewall
  97. clockSource := time.NewTicker(500 * time.Millisecond)
  98. defer clockSource.Stop()
  99. p := []byte("")
  100. nb := make([]byte, 12, 12)
  101. out := make([]byte, mtu)
  102. for {
  103. select {
  104. case <-ctx.Done():
  105. return
  106. case now := <-clockSource.C:
  107. n.trafficTimer.Advance(now)
  108. for {
  109. localIndex, has := n.trafficTimer.Purge()
  110. if !has {
  111. break
  112. }
  113. n.doTrafficCheck(localIndex, p, nb, out, now)
  114. }
  115. }
  116. }
  117. }
// doTrafficCheck drives the liveness state machine for the tunnel identified
// by localIndex: tunnels with inbound traffic are rescheduled as alive, quiet
// tunnels are sent a test packet, and tunnels that already failed a test are
// torn down. p, nb and out are scratch buffers owned by the Run loop.
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
	hostinfo, err := n.hostMap.QueryIndex(localIndex)
	if err != nil {
		// Tunnel is already gone from the hostmap; drop any pending-deletion
		// bookkeeping for it and stop tracking.
		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
		delete(n.pendingDeletion, localIndex)
		return
	}

	// Tears the tunnel down (and returns true) when disconnect_invalid is set
	// and the remote certificate no longer verifies.
	if n.handleInvalidCertificate(now, hostinfo) {
		return
	}

	// Determine whether this hostinfo is the primary tunnel for its vpn ip.
	primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
	mainHostInfo := true
	if primary != nil && primary != hostinfo {
		mainHostInfo = false
	}

	// Check for traffic on this hostinfo
	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

	// A hostinfo is determined alive if there is incoming traffic
	if inTraffic {
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
				Debug("Tunnel status")
		}
		// Any earlier test packet was implicitly answered; clear the mark.
		delete(n.pendingDeletion, hostinfo.localIndexId)

		if !mainHostInfo {
			if hostinfo.vpnIp > n.intf.myVpnIp {
				// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
				// This the primary and prime the old primary hostinfo for testing
				n.hostMap.MakePrimary(hostinfo)
			}
		}

		// Alive: schedule the next regular check.
		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

		if !outTraffic {
			// Send a punch packet to keep the NAT state alive
			n.sendPunch(hostinfo)
		}

		return
	}

	// No inbound traffic this tick.
	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
		// We have already sent a test packet and nothing was returned, this hostinfo is dead
		hostinfo.logger(n.l).
			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
			Info("Tunnel status")

		n.hostMap.DeleteHostInfo(hostinfo)
		delete(n.pendingDeletion, hostinfo.localIndexId)
		return
	}

	hostinfo.logger(n.l).
		WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
		Debug("Tunnel status")

	// NOTE(review): hostinfo appears to be non-nil here (the QueryIndex error
	// path returned above), so the nil check below looks redundant — confirm
	// QueryIndex can never return (nil, nil).
	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
		if !outTraffic {
			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
			// Just maintain NAT state if configured to do so.
			n.sendPunch(hostinfo)
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			return
		}

		if n.punchy.GetTargetEverything() {
			// This is similar to the old punchy behavior with a slight optimization.
			// We aren't receiving traffic but we are sending it, punch on all known
			// ips in case we need to re-prime NAT state
			n.sendPunch(hostinfo)
		}

		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
			// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			return
		}

		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
	} else {
		hostinfo.logger(n.l).Debugf("Hostinfo sadness")
	}

	// Mark for deletion and re-check after pendingDeletionInterval; a test
	// reply in the meantime will clear the mark via the inTraffic path above.
	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
}
  196. // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
  197. func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
  198. if !n.intf.disconnectInvalid {
  199. return false
  200. }
  201. remoteCert := hostinfo.GetCert()
  202. if remoteCert == nil {
  203. return false
  204. }
  205. valid, err := remoteCert.Verify(now, n.intf.caPool)
  206. if valid {
  207. return false
  208. }
  209. fingerprint, _ := remoteCert.Sha256Sum()
  210. hostinfo.logger(n.l).WithError(err).
  211. WithField("fingerprint", fingerprint).
  212. Info("Remote certificate is no longer valid, tearing down the tunnel")
  213. // Inform the remote and close the tunnel locally
  214. n.intf.sendCloseTunnel(hostinfo)
  215. n.intf.closeTunnel(hostinfo)
  216. delete(n.pendingDeletion, hostinfo.localIndexId)
  217. return true
  218. }
  219. func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
  220. if !n.punchy.GetPunch() {
  221. // Punching is disabled
  222. return
  223. }
  224. if n.punchy.GetTargetEverything() {
  225. hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
  226. n.metricsTxPunchy.Inc(1)
  227. n.intf.outside.WriteTo([]byte{1}, addr)
  228. })
  229. } else if hostinfo.remote != nil {
  230. n.metricsTxPunchy.Inc(1)
  231. n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
  232. }
  233. }