connection_manager.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package nebula
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/sirupsen/logrus"
  7. "github.com/slackhq/nebula/header"
  8. )
// connectionManager tracks per-tunnel traffic, keyed by local index, and drives
// the periodic checks that test quiet tunnels and delete the ones that stay
// quiet.
//
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
// and something like every 10 packets we could lock, send 10, then unlock for a moment
type connectionManager struct {
	// hostMap is queried by local index to find the HostInfo being checked,
	// and is asked to delete hosts whose tunnel is judged dead.
	hostMap *HostMap
	// in is the set of local indexes recorded by In; guarded by inLock.
	in     map[uint32]struct{}
	inLock *sync.RWMutex
	// out is the set of local indexes recorded by Out; guarded by outLock.
	out     map[uint32]struct{}
	outLock *sync.RWMutex
	// TrafficTimer holds the local indexes waiting for their traffic check;
	// entries are added by AddTrafficWatch and drained by HandleMonitorTick.
	TrafficTimer *LockingTimerWheel[uint32]
	// intf is used to send test packets and to close tunnels.
	intf *Interface
	// pendingDeletion maps a local index to a counter that starts at 0 on the
	// first AddPendingDeletion and is incremented by each later one; guarded
	// by pendingDeletionLock.
	pendingDeletion     map[uint32]int
	pendingDeletionLock *sync.RWMutex
	// pendingDeletionTimer holds the local indexes waiting for their deletion
	// check; entries are added by AddPendingDeletion and drained by
	// HandleDeletionTick.
	pendingDeletionTimer *LockingTimerWheel[uint32]
	// checkInterval is the delay, in seconds, used when Out schedules a
	// traffic watch.
	checkInterval int
	// pendingDeletionInterval is the delay, in seconds, used when
	// AddPendingDeletion schedules a deletion check.
	pendingDeletionInterval int
	// l is the logger used for all tunnel status messages.
	l *logrus.Logger
	// I wanted to call one matLock
}
  27. func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
  28. nc := &connectionManager{
  29. hostMap: intf.hostMap,
  30. in: make(map[uint32]struct{}),
  31. inLock: &sync.RWMutex{},
  32. out: make(map[uint32]struct{}),
  33. outLock: &sync.RWMutex{},
  34. TrafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
  35. intf: intf,
  36. pendingDeletion: make(map[uint32]int),
  37. pendingDeletionLock: &sync.RWMutex{},
  38. pendingDeletionTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
  39. checkInterval: checkInterval,
  40. pendingDeletionInterval: pendingDeletionInterval,
  41. l: l,
  42. }
  43. nc.Start(ctx)
  44. return nc
  45. }
  46. func (n *connectionManager) In(localIndex uint32) {
  47. n.inLock.RLock()
  48. // If this already exists, return
  49. if _, ok := n.in[localIndex]; ok {
  50. n.inLock.RUnlock()
  51. return
  52. }
  53. n.inLock.RUnlock()
  54. n.inLock.Lock()
  55. n.in[localIndex] = struct{}{}
  56. n.inLock.Unlock()
  57. }
  58. func (n *connectionManager) Out(localIndex uint32) {
  59. n.outLock.RLock()
  60. // If this already exists, return
  61. if _, ok := n.out[localIndex]; ok {
  62. n.outLock.RUnlock()
  63. return
  64. }
  65. n.outLock.RUnlock()
  66. n.outLock.Lock()
  67. // double check since we dropped the lock temporarily
  68. if _, ok := n.out[localIndex]; ok {
  69. n.outLock.Unlock()
  70. return
  71. }
  72. n.out[localIndex] = struct{}{}
  73. n.AddTrafficWatch(localIndex, n.checkInterval)
  74. n.outLock.Unlock()
  75. }
  76. func (n *connectionManager) CheckIn(localIndex uint32) bool {
  77. n.inLock.RLock()
  78. if _, ok := n.in[localIndex]; ok {
  79. n.inLock.RUnlock()
  80. return true
  81. }
  82. n.inLock.RUnlock()
  83. return false
  84. }
  85. func (n *connectionManager) ClearLocalIndex(localIndex uint32) {
  86. n.inLock.Lock()
  87. n.outLock.Lock()
  88. delete(n.in, localIndex)
  89. delete(n.out, localIndex)
  90. n.inLock.Unlock()
  91. n.outLock.Unlock()
  92. }
  93. func (n *connectionManager) ClearPendingDeletion(localIndex uint32) {
  94. n.pendingDeletionLock.Lock()
  95. delete(n.pendingDeletion, localIndex)
  96. n.pendingDeletionLock.Unlock()
  97. }
  98. func (n *connectionManager) AddPendingDeletion(localIndex uint32) {
  99. n.pendingDeletionLock.Lock()
  100. if _, ok := n.pendingDeletion[localIndex]; ok {
  101. n.pendingDeletion[localIndex] += 1
  102. } else {
  103. n.pendingDeletion[localIndex] = 0
  104. }
  105. n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval))
  106. n.pendingDeletionLock.Unlock()
  107. }
  108. func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool {
  109. n.pendingDeletionLock.RLock()
  110. if _, ok := n.pendingDeletion[localIndex]; ok {
  111. n.pendingDeletionLock.RUnlock()
  112. return true
  113. }
  114. n.pendingDeletionLock.RUnlock()
  115. return false
  116. }
  117. func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) {
  118. n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds))
  119. }
// Start launches Run on its own goroutine and returns immediately. The
// goroutine ends when Run returns, which happens once ctx is done.
func (n *connectionManager) Start(ctx context.Context) {
	go n.Run(ctx)
}
  123. func (n *connectionManager) Run(ctx context.Context) {
  124. clockSource := time.NewTicker(500 * time.Millisecond)
  125. defer clockSource.Stop()
  126. p := []byte("")
  127. nb := make([]byte, 12, 12)
  128. out := make([]byte, mtu)
  129. for {
  130. select {
  131. case <-ctx.Done():
  132. return
  133. case now := <-clockSource.C:
  134. n.HandleMonitorTick(now, p, nb, out)
  135. n.HandleDeletionTick(now)
  136. }
  137. }
  138. }
  139. func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
  140. n.TrafficTimer.Advance(now)
  141. for {
  142. localIndex, has := n.TrafficTimer.Purge()
  143. if !has {
  144. break
  145. }
  146. // Check for traffic coming back in from this host.
  147. traf := n.CheckIn(localIndex)
  148. hostinfo, err := n.hostMap.QueryIndex(localIndex)
  149. if err != nil {
  150. n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
  151. n.ClearLocalIndex(localIndex)
  152. n.ClearPendingDeletion(localIndex)
  153. continue
  154. }
  155. if n.handleInvalidCertificate(now, hostinfo) {
  156. continue
  157. }
  158. // If we saw an incoming packets from this ip and peer's certificate is not
  159. // expired, just ignore.
  160. if traf {
  161. if n.l.Level >= logrus.DebugLevel {
  162. hostinfo.logger(n.l).
  163. WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
  164. Debug("Tunnel status")
  165. }
  166. n.ClearLocalIndex(localIndex)
  167. n.ClearPendingDeletion(localIndex)
  168. continue
  169. }
  170. hostinfo.logger(n.l).
  171. WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
  172. Debug("Tunnel status")
  173. if hostinfo != nil && hostinfo.ConnectionState != nil {
  174. // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
  175. n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
  176. } else {
  177. hostinfo.logger(n.l).Debugf("Hostinfo sadness")
  178. }
  179. n.AddPendingDeletion(localIndex)
  180. }
  181. }
  182. func (n *connectionManager) HandleDeletionTick(now time.Time) {
  183. n.pendingDeletionTimer.Advance(now)
  184. for {
  185. localIndex, has := n.pendingDeletionTimer.Purge()
  186. if !has {
  187. break
  188. }
  189. hostinfo, err := n.hostMap.QueryIndex(localIndex)
  190. if err != nil {
  191. n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
  192. n.ClearLocalIndex(localIndex)
  193. n.ClearPendingDeletion(localIndex)
  194. continue
  195. }
  196. if n.handleInvalidCertificate(now, hostinfo) {
  197. continue
  198. }
  199. // If we saw an incoming packets from this ip and peer's certificate is not
  200. // expired, just ignore.
  201. traf := n.CheckIn(localIndex)
  202. if traf {
  203. hostinfo.logger(n.l).
  204. WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
  205. Debug("Tunnel status")
  206. n.ClearLocalIndex(localIndex)
  207. n.ClearPendingDeletion(localIndex)
  208. continue
  209. }
  210. // If it comes around on deletion wheel and hasn't resolved itself, delete
  211. if n.checkPendingDeletion(localIndex) {
  212. cn := ""
  213. if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
  214. cn = hostinfo.ConnectionState.peerCert.Details.Name
  215. }
  216. hostinfo.logger(n.l).
  217. WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
  218. WithField("certName", cn).
  219. Info("Tunnel status")
  220. n.hostMap.DeleteHostInfo(hostinfo)
  221. }
  222. n.ClearLocalIndex(localIndex)
  223. n.ClearPendingDeletion(localIndex)
  224. }
  225. }
  226. // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
  227. func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
  228. if !n.intf.disconnectInvalid {
  229. return false
  230. }
  231. remoteCert := hostinfo.GetCert()
  232. if remoteCert == nil {
  233. return false
  234. }
  235. valid, err := remoteCert.Verify(now, n.intf.caPool)
  236. if valid {
  237. return false
  238. }
  239. fingerprint, _ := remoteCert.Sha256Sum()
  240. hostinfo.logger(n.l).WithError(err).
  241. WithField("fingerprint", fingerprint).
  242. Info("Remote certificate is no longer valid, tearing down the tunnel")
  243. // Inform the remote and close the tunnel locally
  244. n.intf.sendCloseTunnel(hostinfo)
  245. n.intf.closeTunnel(hostinfo)
  246. n.ClearLocalIndex(hostinfo.localIndexId)
  247. n.ClearPendingDeletion(hostinfo.localIndexId)
  248. return true
  249. }