connection_manager.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package nebula
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/sirupsen/logrus"
  7. "github.com/slackhq/nebula/header"
  8. )
  9. // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
  10. // and something like every 10 packets we could lock, send 10, then unlock for a moment
  11. type connectionManager struct {
  12. hostMap *HostMap
  13. in map[uint32]struct{}
  14. inLock *sync.RWMutex
  15. out map[uint32]struct{}
  16. outLock *sync.RWMutex
  17. TrafficTimer *LockingTimerWheel[uint32]
  18. intf *Interface
  19. pendingDeletion map[uint32]int
  20. pendingDeletionLock *sync.RWMutex
  21. pendingDeletionTimer *LockingTimerWheel[uint32]
  22. checkInterval int
  23. pendingDeletionInterval int
  24. l *logrus.Logger
  25. // I wanted to call one matLock
  26. }
  27. func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
  28. nc := &connectionManager{
  29. hostMap: intf.hostMap,
  30. in: make(map[uint32]struct{}),
  31. inLock: &sync.RWMutex{},
  32. out: make(map[uint32]struct{}),
  33. outLock: &sync.RWMutex{},
  34. TrafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
  35. intf: intf,
  36. pendingDeletion: make(map[uint32]int),
  37. pendingDeletionLock: &sync.RWMutex{},
  38. pendingDeletionTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
  39. checkInterval: checkInterval,
  40. pendingDeletionInterval: pendingDeletionInterval,
  41. l: l,
  42. }
  43. nc.Start(ctx)
  44. return nc
  45. }
  46. func (n *connectionManager) In(localIndex uint32) {
  47. n.inLock.RLock()
  48. // If this already exists, return
  49. if _, ok := n.in[localIndex]; ok {
  50. n.inLock.RUnlock()
  51. return
  52. }
  53. n.inLock.RUnlock()
  54. n.inLock.Lock()
  55. n.in[localIndex] = struct{}{}
  56. n.inLock.Unlock()
  57. }
  58. func (n *connectionManager) Out(localIndex uint32) {
  59. n.outLock.RLock()
  60. // If this already exists, return
  61. if _, ok := n.out[localIndex]; ok {
  62. n.outLock.RUnlock()
  63. return
  64. }
  65. n.outLock.RUnlock()
  66. n.outLock.Lock()
  67. // double check since we dropped the lock temporarily
  68. if _, ok := n.out[localIndex]; ok {
  69. n.outLock.Unlock()
  70. return
  71. }
  72. n.out[localIndex] = struct{}{}
  73. n.AddTrafficWatch(localIndex, n.checkInterval)
  74. n.outLock.Unlock()
  75. }
  76. func (n *connectionManager) CheckIn(localIndex uint32) bool {
  77. n.inLock.RLock()
  78. if _, ok := n.in[localIndex]; ok {
  79. n.inLock.RUnlock()
  80. return true
  81. }
  82. n.inLock.RUnlock()
  83. return false
  84. }
  85. func (n *connectionManager) ClearLocalIndex(localIndex uint32) {
  86. n.inLock.Lock()
  87. n.outLock.Lock()
  88. delete(n.in, localIndex)
  89. delete(n.out, localIndex)
  90. n.inLock.Unlock()
  91. n.outLock.Unlock()
  92. }
  93. func (n *connectionManager) ClearPendingDeletion(localIndex uint32) {
  94. n.pendingDeletionLock.Lock()
  95. delete(n.pendingDeletion, localIndex)
  96. n.pendingDeletionLock.Unlock()
  97. }
  98. func (n *connectionManager) AddPendingDeletion(localIndex uint32) {
  99. n.pendingDeletionLock.Lock()
  100. if _, ok := n.pendingDeletion[localIndex]; ok {
  101. n.pendingDeletion[localIndex] += 1
  102. } else {
  103. n.pendingDeletion[localIndex] = 0
  104. }
  105. n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval))
  106. n.pendingDeletionLock.Unlock()
  107. }
  108. func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool {
  109. n.pendingDeletionLock.RLock()
  110. if _, ok := n.pendingDeletion[localIndex]; ok {
  111. n.pendingDeletionLock.RUnlock()
  112. return true
  113. }
  114. n.pendingDeletionLock.RUnlock()
  115. return false
  116. }
  117. func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) {
  118. n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds))
  119. }
  120. func (n *connectionManager) Start(ctx context.Context) {
  121. go n.Run(ctx)
  122. }
  123. func (n *connectionManager) Run(ctx context.Context) {
  124. clockSource := time.NewTicker(500 * time.Millisecond)
  125. defer clockSource.Stop()
  126. p := []byte("")
  127. nb := make([]byte, 12, 12)
  128. out := make([]byte, mtu)
  129. for {
  130. select {
  131. case <-ctx.Done():
  132. return
  133. case now := <-clockSource.C:
  134. n.HandleMonitorTick(now, p, nb, out)
  135. n.HandleDeletionTick(now)
  136. }
  137. }
  138. }
  139. func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
  140. n.TrafficTimer.Advance(now)
  141. for {
  142. localIndex, has := n.TrafficTimer.Purge()
  143. if !has {
  144. break
  145. }
  146. // Check for traffic coming back in from this host.
  147. traf := n.CheckIn(localIndex)
  148. hostinfo, err := n.hostMap.QueryIndex(localIndex)
  149. if err != nil {
  150. n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
  151. n.ClearLocalIndex(localIndex)
  152. n.ClearPendingDeletion(localIndex)
  153. continue
  154. }
  155. if n.handleInvalidCertificate(now, hostinfo) {
  156. continue
  157. }
  158. // Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
  159. // decide if this should be the main hostinfo if we are seeing traffic on it
  160. primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
  161. mainHostInfo := true
  162. if primary != nil && primary != hostinfo {
  163. mainHostInfo = false
  164. }
  165. // If we saw an incoming packets from this ip and peer's certificate is not
  166. // expired, just ignore.
  167. if traf {
  168. if n.l.Level >= logrus.DebugLevel {
  169. hostinfo.logger(n.l).
  170. WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
  171. Debug("Tunnel status")
  172. }
  173. n.ClearLocalIndex(localIndex)
  174. n.ClearPendingDeletion(localIndex)
  175. if !mainHostInfo {
  176. if hostinfo.vpnIp > n.intf.myVpnIp {
  177. // We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
  178. // This the primary and prime the old primary hostinfo for testing
  179. n.hostMap.MakePrimary(hostinfo)
  180. n.Out(primary.localIndexId)
  181. } else {
  182. // This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
  183. // Keep tracking so that we can tear it down when it goes away
  184. n.Out(hostinfo.localIndexId)
  185. }
  186. }
  187. continue
  188. }
  189. hostinfo.logger(n.l).
  190. WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
  191. Debug("Tunnel status")
  192. if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
  193. // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
  194. n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
  195. } else {
  196. hostinfo.logger(n.l).Debugf("Hostinfo sadness")
  197. }
  198. n.AddPendingDeletion(localIndex)
  199. }
  200. }
  201. func (n *connectionManager) HandleDeletionTick(now time.Time) {
  202. n.pendingDeletionTimer.Advance(now)
  203. for {
  204. localIndex, has := n.pendingDeletionTimer.Purge()
  205. if !has {
  206. break
  207. }
  208. hostinfo, err := n.hostMap.QueryIndex(localIndex)
  209. if err != nil {
  210. n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
  211. n.ClearLocalIndex(localIndex)
  212. n.ClearPendingDeletion(localIndex)
  213. continue
  214. }
  215. if n.handleInvalidCertificate(now, hostinfo) {
  216. continue
  217. }
  218. // If we saw an incoming packets from this ip and peer's certificate is not
  219. // expired, just ignore.
  220. traf := n.CheckIn(localIndex)
  221. if traf {
  222. hostinfo.logger(n.l).
  223. WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
  224. Debug("Tunnel status")
  225. n.ClearLocalIndex(localIndex)
  226. n.ClearPendingDeletion(localIndex)
  227. continue
  228. }
  229. // If it comes around on deletion wheel and hasn't resolved itself, delete
  230. if n.checkPendingDeletion(localIndex) {
  231. cn := ""
  232. if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
  233. cn = hostinfo.ConnectionState.peerCert.Details.Name
  234. }
  235. hostinfo.logger(n.l).
  236. WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
  237. WithField("certName", cn).
  238. Info("Tunnel status")
  239. n.hostMap.DeleteHostInfo(hostinfo)
  240. }
  241. n.ClearLocalIndex(localIndex)
  242. n.ClearPendingDeletion(localIndex)
  243. }
  244. }
  245. // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
  246. func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
  247. if !n.intf.disconnectInvalid {
  248. return false
  249. }
  250. remoteCert := hostinfo.GetCert()
  251. if remoteCert == nil {
  252. return false
  253. }
  254. valid, err := remoteCert.Verify(now, n.intf.caPool)
  255. if valid {
  256. return false
  257. }
  258. fingerprint, _ := remoteCert.Sha256Sum()
  259. hostinfo.logger(n.l).WithError(err).
  260. WithField("fingerprint", fingerprint).
  261. Info("Remote certificate is no longer valid, tearing down the tunnel")
  262. // Inform the remote and close the tunnel locally
  263. n.intf.sendCloseTunnel(hostinfo)
  264. n.intf.closeTunnel(hostinfo)
  265. n.ClearLocalIndex(hostinfo.localIndexId)
  266. n.ClearPendingDeletion(hostinfo.localIndexId)
  267. return true
  268. }