handshake_manager.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. package nebula
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "encoding/binary"
  7. "errors"
  8. "net"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/sirupsen/logrus"
  12. "github.com/slackhq/nebula/header"
  13. "github.com/slackhq/nebula/iputil"
  14. "github.com/slackhq/nebula/udp"
  15. )
  // Default values for the HandshakeConfig fields; defaultHandshakeConfig is
// assembled from exactly these constants.
const (
	// DefaultHandshakeTryInterval is the default HandshakeConfig.tryInterval:
	// the tick period of the handshake loop and the unit of the linear backoff
	// between attempts.
	DefaultHandshakeTryInterval = 100 * time.Millisecond
	// DefaultHandshakeRetries is the default HandshakeConfig.retries: how many
	// attempts are made before a pending handshake is abandoned.
	DefaultHandshakeRetries = 10
	// DefaultHandshakeTriggerBuffer is the default HandshakeConfig.triggerBuffer:
	// the capacity of the HandshakeManager trigger channel.
	DefaultHandshakeTriggerBuffer = 64
	// DefaultUseRelays is the default HandshakeConfig.useRelays.
	DefaultUseRelays = true
)
  22. var (
  23. defaultHandshakeConfig = HandshakeConfig{
  24. tryInterval: DefaultHandshakeTryInterval,
  25. retries: DefaultHandshakeRetries,
  26. triggerBuffer: DefaultHandshakeTriggerBuffer,
  27. useRelays: DefaultUseRelays,
  28. }
  29. )
  30. type HandshakeConfig struct {
  31. tryInterval time.Duration
  32. retries int
  33. triggerBuffer int
  34. useRelays bool
  35. messageMetrics *MessageMetrics
  36. }
  37. type HandshakeManager struct {
  38. pendingHostMap *HostMap
  39. mainHostMap *HostMap
  40. lightHouse *LightHouse
  41. outside *udp.Conn
  42. config HandshakeConfig
  43. OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
  44. messageMetrics *MessageMetrics
  45. metricInitiated metrics.Counter
  46. metricTimedOut metrics.Counter
  47. l *logrus.Logger
  48. // can be used to trigger outbound handshake for the given vpnIp
  49. trigger chan iputil.VpnIp
  50. }
  51. func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
  52. return &HandshakeManager{
  53. pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
  54. mainHostMap: mainHostMap,
  55. lightHouse: lightHouse,
  56. outside: outside,
  57. config: config,
  58. trigger: make(chan iputil.VpnIp, config.triggerBuffer),
  59. OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
  60. messageMetrics: config.messageMetrics,
  61. metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
  62. metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
  63. l: l,
  64. }
  65. }
  66. func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
  67. clockSource := time.NewTicker(c.config.tryInterval)
  68. defer clockSource.Stop()
  69. for {
  70. select {
  71. case <-ctx.Done():
  72. return
  73. case vpnIP := <-c.trigger:
  74. c.handleOutbound(vpnIP, f, true)
  75. case now := <-clockSource.C:
  76. c.NextOutboundHandshakeTimerTick(now, f)
  77. }
  78. }
  79. }
  80. func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
  81. c.OutboundHandshakeTimer.Advance(now)
  82. for {
  83. vpnIp, has := c.OutboundHandshakeTimer.Purge()
  84. if !has {
  85. break
  86. }
  87. c.handleOutbound(vpnIp, f, false)
  88. }
  89. }
  90. func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
  91. hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
  92. if err != nil {
  93. return
  94. }
  95. hostinfo.Lock()
  96. defer hostinfo.Unlock()
  97. // We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
  98. if hostinfo.HandshakeComplete {
  99. // Ensure we don't exist in the pending hostmap anymore since we have completed
  100. c.pendingHostMap.DeleteHostInfo(hostinfo)
  101. return
  102. }
  103. // Check if we have a handshake packet to transmit yet
  104. if !hostinfo.HandshakeReady {
  105. // There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
  106. // Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
  107. c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
  108. return
  109. }
  110. // If we are out of time, clean up
  111. if hostinfo.HandshakeCounter >= c.config.retries {
  112. hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
  113. WithField("initiatorIndex", hostinfo.localIndexId).
  114. WithField("remoteIndex", hostinfo.remoteIndexId).
  115. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  116. WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
  117. Info("Handshake timed out")
  118. c.metricTimedOut.Inc(1)
  119. c.pendingHostMap.DeleteHostInfo(hostinfo)
  120. return
  121. }
  122. // Get a remotes object if we don't already have one.
  123. // This is mainly to protect us as this should never be the case
  124. // NB ^ This comment doesn't jive. It's how the thing gets initialized.
  125. // It's the common path. Should it update every time, in case a future LH query/queries give us more info?
  126. if hostinfo.remotes == nil {
  127. hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
  128. }
  129. remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
  130. remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
  131. // We only care about a lighthouse trigger if we have new remotes to send to.
  132. // This is a very specific optimization for a fast lighthouse reply.
  133. if lighthouseTriggered && !remotesHaveChanged {
  134. // If we didn't return here a lighthouse could cause us to aggressively send handshakes
  135. return
  136. }
  137. hostinfo.HandshakeLastRemotes = remotes
  138. // TODO: this will generate a load of queries for hosts with only 1 ip
  139. // (such as ones registered to the lighthouse with only a private IP)
  140. // So we only do it one time after attempting 5 handshakes already.
  141. if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
  142. // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
  143. // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
  144. // the learned public ip for them. Query again to short circuit the promotion counter
  145. c.lightHouse.QueryServer(vpnIp, f)
  146. }
  147. // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
  148. var sentTo []*udp.Addr
  149. hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
  150. c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
  151. err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
  152. if err != nil {
  153. hostinfo.logger(c.l).WithField("udpAddr", addr).
  154. WithField("initiatorIndex", hostinfo.localIndexId).
  155. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  156. WithError(err).Error("Failed to send handshake message")
  157. } else {
  158. sentTo = append(sentTo, addr)
  159. }
  160. })
  161. // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
  162. // so only log when the list of remotes has changed
  163. if remotesHaveChanged {
  164. hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
  165. WithField("initiatorIndex", hostinfo.localIndexId).
  166. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  167. Info("Handshake message sent")
  168. } else if c.l.IsLevelEnabled(logrus.DebugLevel) {
  169. hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
  170. WithField("initiatorIndex", hostinfo.localIndexId).
  171. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  172. Debug("Handshake message sent")
  173. }
  174. if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {
  175. hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
  176. // Send a RelayRequest to all known Relay IP's
  177. for _, relay := range hostinfo.remotes.relays {
  178. // Don't relay to myself, and don't relay through the host I'm trying to connect to
  179. if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
  180. continue
  181. }
  182. relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay)
  183. if err != nil || relayHostInfo.remote == nil {
  184. hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
  185. f.Handshake(*relay)
  186. continue
  187. }
  188. // Check the relay HostInfo to see if we already established a relay through it
  189. if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
  190. switch existingRelay.State {
  191. case Established:
  192. hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
  193. f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
  194. case Requested:
  195. hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
  196. // Re-send the CreateRelay request, in case the previous one was lost.
  197. m := NebulaControl{
  198. Type: NebulaControl_CreateRelayRequest,
  199. InitiatorRelayIndex: existingRelay.LocalIndex,
  200. RelayFromIp: uint32(c.lightHouse.myVpnIp),
  201. RelayToIp: uint32(vpnIp),
  202. }
  203. msg, err := m.Marshal()
  204. if err != nil {
  205. hostinfo.logger(c.l).
  206. WithError(err).
  207. Error("Failed to marshal Control message to create relay")
  208. } else {
  209. // This must send over the hostinfo, not over hm.Hosts[ip]
  210. f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
  211. c.l.WithFields(logrus.Fields{
  212. "relayFrom": c.lightHouse.myVpnIp,
  213. "relayTo": vpnIp,
  214. "initiatorRelayIndex": existingRelay.LocalIndex,
  215. "relay": *relay}).
  216. Info("send CreateRelayRequest")
  217. }
  218. default:
  219. hostinfo.logger(c.l).
  220. WithField("vpnIp", vpnIp).
  221. WithField("state", existingRelay.State).
  222. WithField("relay", relayHostInfo.vpnIp).
  223. Errorf("Relay unexpected state")
  224. }
  225. } else {
  226. // No relays exist or requested yet.
  227. if relayHostInfo.remote != nil {
  228. idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested)
  229. if err != nil {
  230. hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
  231. }
  232. m := NebulaControl{
  233. Type: NebulaControl_CreateRelayRequest,
  234. InitiatorRelayIndex: idx,
  235. RelayFromIp: uint32(c.lightHouse.myVpnIp),
  236. RelayToIp: uint32(vpnIp),
  237. }
  238. msg, err := m.Marshal()
  239. if err != nil {
  240. hostinfo.logger(c.l).
  241. WithError(err).
  242. Error("Failed to marshal Control message to create relay")
  243. } else {
  244. f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
  245. c.l.WithFields(logrus.Fields{
  246. "relayFrom": c.lightHouse.myVpnIp,
  247. "relayTo": vpnIp,
  248. "initiatorRelayIndex": idx,
  249. "relay": *relay}).
  250. Info("send CreateRelayRequest")
  251. }
  252. }
  253. }
  254. }
  255. }
  256. // Increment the counter to increase our delay, linear backoff
  257. hostinfo.HandshakeCounter++
  258. // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
  259. if !lighthouseTriggered {
  260. c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
  261. }
  262. }
  263. func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
  264. hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
  265. if created {
  266. c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
  267. c.metricInitiated.Inc(1)
  268. }
  269. return hostinfo
  270. }
  // Sentinel errors returned by CheckAndComplete; compare with errors.Is.
var (
	// ErrExistingHostInfo: the main hostmap already holds a hostinfo for the
	// vpnIp that is preferred over the new one.
	ErrExistingHostInfo = errors.New("existing hostinfo")
	// ErrAlreadySeen: an identical handshake packet was already processed.
	ErrAlreadySeen = errors.New("already seen")
	// ErrLocalIndexCollision: the new hostinfo's localIndexId is already taken.
	ErrLocalIndexCollision = errors.New("local index collision")
)
  276. // CheckAndComplete checks for any conflicts in the main and pending hostmap
  277. // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
  278. //
  279. // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
  280. // exact same handshake packet
  281. //
  282. // ErrExistingHostInfo if we already have an entry in the hostmap for this
  283. // VpnIp and the new handshake was older than the one we currently have
  284. //
  285. // ErrLocalIndexCollision if we already have an entry in the main or pending
  286. // hostmap for the hostinfo.localIndexId.
  287. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
  288. c.pendingHostMap.Lock()
  289. defer c.pendingHostMap.Unlock()
  290. c.mainHostMap.Lock()
  291. defer c.mainHostMap.Unlock()
  292. // Check if we already have a tunnel with this vpn ip
  293. existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
  294. if found && existingHostInfo != nil {
  295. testHostInfo := existingHostInfo
  296. for testHostInfo != nil {
  297. // Is it just a delayed handshake packet?
  298. if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) {
  299. return testHostInfo, ErrAlreadySeen
  300. }
  301. testHostInfo = testHostInfo.next
  302. }
  303. // Is this a newer handshake?
  304. if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime && !existingHostInfo.ConnectionState.initiator {
  305. return existingHostInfo, ErrExistingHostInfo
  306. }
  307. existingHostInfo.logger(c.l).Info("Taking new handshake")
  308. }
  309. existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
  310. if found {
  311. // We have a collision, but for a different hostinfo
  312. return existingIndex, ErrLocalIndexCollision
  313. }
  314. existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
  315. if found && existingIndex != hostinfo {
  316. // We have a collision, but for a different hostinfo
  317. return existingIndex, ErrLocalIndexCollision
  318. }
  319. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  320. if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
  321. // We have a collision, but this can happen since we can't control
  322. // the remote ID. Just log about the situation as a note.
  323. hostinfo.logger(c.l).
  324. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
  325. Info("New host shadows existing host remoteIndex")
  326. }
  327. c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
  328. return existingHostInfo, nil
  329. }
  330. // Complete is a simpler version of CheckAndComplete when we already know we
  331. // won't have a localIndexId collision because we already have an entry in the
  332. // pendingHostMap. An existing hostinfo is returned if there was one.
  333. func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
  334. c.pendingHostMap.Lock()
  335. defer c.pendingHostMap.Unlock()
  336. c.mainHostMap.Lock()
  337. defer c.mainHostMap.Unlock()
  338. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  339. if found && existingRemoteIndex != nil {
  340. // We have a collision, but this can happen since we can't control
  341. // the remote ID. Just log about the situation as a note.
  342. hostinfo.logger(c.l).
  343. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
  344. Info("New host shadows existing host remoteIndex")
  345. }
  346. // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
  347. c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
  348. c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
  349. }
  350. // AddIndexHostInfo generates a unique localIndexId for this HostInfo
  351. // and adds it to the pendingHostMap. Will error if we are unable to generate
  352. // a unique localIndexId
  353. func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
  354. c.pendingHostMap.Lock()
  355. defer c.pendingHostMap.Unlock()
  356. c.mainHostMap.RLock()
  357. defer c.mainHostMap.RUnlock()
  358. for i := 0; i < 32; i++ {
  359. index, err := generateIndex(c.l)
  360. if err != nil {
  361. return err
  362. }
  363. _, inPending := c.pendingHostMap.Indexes[index]
  364. _, inMain := c.mainHostMap.Indexes[index]
  365. if !inMain && !inPending {
  366. h.localIndexId = index
  367. c.pendingHostMap.Indexes[index] = h
  368. return nil
  369. }
  370. }
  371. return errors.New("failed to generate unique localIndexId")
  372. }
  373. func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
  374. c.pendingHostMap.addRemoteIndexHostInfo(index, h)
  375. }
  376. func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
  377. //l.Debugln("Deleting pending hostinfo :", hostinfo)
  378. c.pendingHostMap.DeleteHostInfo(hostinfo)
  379. }
  380. func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
  381. return c.pendingHostMap.QueryIndex(index)
  382. }
  383. func (c *HandshakeManager) EmitStats() {
  384. c.pendingHostMap.EmitStats("pending")
  385. c.mainHostMap.EmitStats("main")
  386. }
  387. // Utility functions below
  388. func generateIndex(l *logrus.Logger) (uint32, error) {
  389. b := make([]byte, 4)
  390. // Let zero mean we don't know the ID, so don't generate zero
  391. var index uint32
  392. for index == 0 {
  393. _, err := rand.Read(b)
  394. if err != nil {
  395. l.Errorln(err)
  396. return 0, err
  397. }
  398. index = binary.BigEndian.Uint32(b)
  399. }
  400. if l.Level >= logrus.DebugLevel {
  401. l.WithField("index", index).
  402. Debug("Generated index")
  403. }
  404. return index, nil
  405. }
  // hsTimeout returns the total wait across `tries` attempts under linear
// backoff: interval + 2*interval + ... + tries*interval, which equals
// tries*(tries+1)/2 * interval. NewHandshakeManager passes it as the second
// argument to NewLockingTimerWheel.
//
// The earlier form computed tries/2 first, and integer division truncated it
// for odd tries: 3 tries gave 4*interval instead of 6*interval, and 1 try gave
// 0. tries*(tries+1) is always even, so dividing it by 2 is exact.
func hsTimeout(tries int, interval time.Duration) time.Duration {
	return time.Duration(tries*(tries+1)/2) * interval
}