handshake_manager.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. package nebula
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "encoding/binary"
  6. "errors"
  7. "net"
  8. "time"
  9. "github.com/sirupsen/logrus"
  10. )
  // Default tuning values for outbound handshakes.
  //
  // handleOutbound schedules every retry tryInterval * attempt-number after the
  // previous one, so the time spent on a single peer is a growing sequence and
  // not simply interval * retries.
  // NOTE(review): the earlier comment quoted 23.5 seconds for 100ms / 20
  // retries; 100ms * (1 + 2 + ... + 20) works out to 21 seconds — confirm which
  // figure is intended.
  const (
  	// DefaultHandshakeTryInterval is the base delay between handshake attempts
  	// and the tick period of the handshake timer wheels.
  	DefaultHandshakeTryInterval = 100 * time.Millisecond

  	// DefaultHandshakeRetries is the number of attempts made before a pending
  	// handshake is abandoned.
  	DefaultHandshakeRetries = 20

  	// DefaultHandshakeWaitRotation is the number of handshake attempts made
  	// against the preferred address before rotating through the other known
  	// addresses.
  	DefaultHandshakeWaitRotation = 5

  	// DefaultHandshakeTriggerBuffer is the capacity of the channel used to
  	// trigger outbound handshakes.
  	DefaultHandshakeTriggerBuffer = 64
  )
  20. var (
  21. defaultHandshakeConfig = HandshakeConfig{
  22. tryInterval: DefaultHandshakeTryInterval,
  23. retries: DefaultHandshakeRetries,
  24. waitRotation: DefaultHandshakeWaitRotation,
  25. triggerBuffer: DefaultHandshakeTriggerBuffer,
  26. }
  27. )
  28. type HandshakeConfig struct {
  29. tryInterval time.Duration
  30. retries int
  31. waitRotation int
  32. triggerBuffer int
  33. messageMetrics *MessageMetrics
  34. }
  35. type HandshakeManager struct {
  36. pendingHostMap *HostMap
  37. mainHostMap *HostMap
  38. lightHouse *LightHouse
  39. outside *udpConn
  40. config HandshakeConfig
  41. // can be used to trigger outbound handshake for the given vpnIP
  42. trigger chan uint32
  43. OutboundHandshakeTimer *SystemTimerWheel
  44. InboundHandshakeTimer *SystemTimerWheel
  45. messageMetrics *MessageMetrics
  46. }
  47. func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
  48. return &HandshakeManager{
  49. pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
  50. mainHostMap: mainHostMap,
  51. lightHouse: lightHouse,
  52. outside: outside,
  53. config: config,
  54. trigger: make(chan uint32, config.triggerBuffer),
  55. OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
  56. InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
  57. messageMetrics: config.messageMetrics,
  58. }
  59. }
  60. func (c *HandshakeManager) Run(f EncWriter) {
  61. clockSource := time.Tick(c.config.tryInterval)
  62. for {
  63. select {
  64. case vpnIP := <-c.trigger:
  65. l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
  66. c.handleOutbound(vpnIP, f, true)
  67. case now := <-clockSource:
  68. c.NextOutboundHandshakeTimerTick(now, f)
  69. c.NextInboundHandshakeTimerTick(now)
  70. }
  71. }
  72. }
  73. func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
  74. c.OutboundHandshakeTimer.advance(now)
  75. for {
  76. ep := c.OutboundHandshakeTimer.Purge()
  77. if ep == nil {
  78. break
  79. }
  80. vpnIP := ep.(uint32)
  81. c.handleOutbound(vpnIP, f, false)
  82. }
  83. }
  84. func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
  85. hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
  86. if err != nil {
  87. return
  88. }
  89. hostinfo.Lock()
  90. defer hostinfo.Unlock()
  91. // If we haven't finished the handshake and we haven't hit max retries, query
  92. // lighthouse and then send the handshake packet again.
  93. if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
  94. if hostinfo.remote == nil {
  95. // We continue to query the lighthouse because hosts may
  96. // come online during handshake retries. If the query
  97. // succeeds (no error), add the lighthouse info to hostinfo
  98. ips := c.lightHouse.QueryCache(vpnIP)
  99. // If we have no responses yet, or only one IP (the host hadn't
  100. // finished reporting its own IPs yet), then send another query to
  101. // the LH.
  102. if len(ips) <= 1 {
  103. ips, err = c.lightHouse.Query(vpnIP, f)
  104. }
  105. if err == nil {
  106. for _, ip := range ips {
  107. hostinfo.AddRemote(ip)
  108. }
  109. hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
  110. }
  111. } else if lighthouseTriggered {
  112. // We were triggered by a lighthouse HostQueryReply packet, but
  113. // we have already picked a remote for this host (this can happen
  114. // if we are configured with multiple lighthouses). So we can skip
  115. // this trigger and let the timerwheel handle the rest of the
  116. // process
  117. return
  118. }
  119. hostinfo.HandshakeCounter++
  120. // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
  121. // all the others until we can stand up a connection.
  122. if hostinfo.HandshakeCounter > c.config.waitRotation {
  123. hostinfo.rotateRemote()
  124. }
  125. // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
  126. if hostinfo.HandshakeReady && hostinfo.remote != nil {
  127. c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
  128. err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
  129. if err != nil {
  130. hostinfo.logger().WithField("udpAddr", hostinfo.remote).
  131. WithField("initiatorIndex", hostinfo.localIndexId).
  132. WithField("remoteIndex", hostinfo.remoteIndexId).
  133. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  134. WithError(err).Error("Failed to send handshake message")
  135. } else {
  136. //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
  137. // keep the real packet struct around for logging purposes
  138. hostinfo.logger().WithField("udpAddr", hostinfo.remote).
  139. WithField("initiatorIndex", hostinfo.localIndexId).
  140. WithField("remoteIndex", hostinfo.remoteIndexId).
  141. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  142. Info("Handshake message sent")
  143. }
  144. }
  145. // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
  146. if !lighthouseTriggered {
  147. //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
  148. c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
  149. }
  150. } else {
  151. c.pendingHostMap.DeleteHostInfo(hostinfo)
  152. }
  153. }
  154. func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
  155. c.InboundHandshakeTimer.advance(now)
  156. for {
  157. ep := c.InboundHandshakeTimer.Purge()
  158. if ep == nil {
  159. break
  160. }
  161. index := ep.(uint32)
  162. c.pendingHostMap.DeleteIndex(index)
  163. }
  164. }
  165. func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
  166. hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
  167. // We lock here and use an array to insert items to prevent locking the
  168. // main receive thread for very long by waiting to add items to the pending map
  169. c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
  170. return hostinfo
  171. }
  // Sentinel errors returned by CheckAndComplete.
  var (
  	// ErrAlreadySeen means the main host map already has an entry that saw
  	// the exact same handshake packet.
  	ErrAlreadySeen = errors.New("already seen")

  	// ErrExistingHostInfo means the main host map already has an entry for
  	// the VpnIP and overwriting was not requested.
  	ErrExistingHostInfo = errors.New("existing hostinfo")

  	// ErrLocalIndexCollision means the localIndexId is already in use in the
  	// main or pending host map.
  	ErrLocalIndexCollision = errors.New("local index collision")
  )
  177. // CheckAndComplete checks for any conflicts in the main and pending hostmap
  178. // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
  179. // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
  180. // exact same handshake packet
  181. //
  182. // ErrExistingHostInfo if we already have an entry in the hostmap for this
  183. // VpnIP and overwrite was false.
  184. //
  185. // ErrLocalIndexCollision if we already have an entry in the main or pending
  186. // hostmap for the hostinfo.localIndexId.
  187. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
  188. c.pendingHostMap.RLock()
  189. defer c.pendingHostMap.RUnlock()
  190. c.mainHostMap.Lock()
  191. defer c.mainHostMap.Unlock()
  192. existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
  193. if found && existingHostInfo != nil {
  194. if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
  195. return existingHostInfo, ErrAlreadySeen
  196. }
  197. if !overwrite {
  198. return existingHostInfo, ErrExistingHostInfo
  199. }
  200. }
  201. existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
  202. if found {
  203. // We have a collision, but for a different hostinfo
  204. return existingIndex, ErrLocalIndexCollision
  205. }
  206. existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
  207. if found && existingIndex != hostinfo {
  208. // We have a collision, but for a different hostinfo
  209. return existingIndex, ErrLocalIndexCollision
  210. }
  211. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  212. if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
  213. // We have a collision, but this can happen since we can't control
  214. // the remote ID. Just log about the situation as a note.
  215. hostinfo.logger().
  216. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
  217. Info("New host shadows existing host remoteIndex")
  218. }
  219. if existingHostInfo != nil {
  220. // We are going to overwrite this entry, so remove the old references
  221. delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
  222. delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
  223. delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
  224. }
  225. c.mainHostMap.addHostInfo(hostinfo, f)
  226. return existingHostInfo, nil
  227. }
  228. // Complete is a simpler version of CheckAndComplete when we already know we
  229. // won't have a localIndexId collision because we already have an entry in the
  230. // pendingHostMap
  231. func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
  232. c.mainHostMap.Lock()
  233. defer c.mainHostMap.Unlock()
  234. existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
  235. if found && existingHostInfo != nil {
  236. // We are going to overwrite this entry, so remove the old references
  237. delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
  238. delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
  239. delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
  240. }
  241. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  242. if found && existingRemoteIndex != nil {
  243. // We have a collision, but this can happen since we can't control
  244. // the remote ID. Just log about the situation as a note.
  245. hostinfo.logger().
  246. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
  247. Info("New host shadows existing host remoteIndex")
  248. }
  249. c.mainHostMap.addHostInfo(hostinfo, f)
  250. }
  251. // AddIndexHostInfo generates a unique localIndexId for this HostInfo
  252. // and adds it to the pendingHostMap. Will error if we are unable to generate
  253. // a unique localIndexId
  254. func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
  255. c.pendingHostMap.Lock()
  256. defer c.pendingHostMap.Unlock()
  257. c.mainHostMap.RLock()
  258. defer c.mainHostMap.RUnlock()
  259. for i := 0; i < 32; i++ {
  260. index, err := generateIndex()
  261. if err != nil {
  262. return err
  263. }
  264. _, inPending := c.pendingHostMap.Indexes[index]
  265. _, inMain := c.mainHostMap.Indexes[index]
  266. if !inMain && !inPending {
  267. h.localIndexId = index
  268. c.pendingHostMap.Indexes[index] = h
  269. return nil
  270. }
  271. }
  272. return errors.New("failed to generate unique localIndexId")
  273. }
  274. func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
  275. c.pendingHostMap.addRemoteIndexHostInfo(index, h)
  276. }
  277. func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
  278. //l.Debugln("Deleting pending hostinfo :", hostinfo)
  279. c.pendingHostMap.DeleteHostInfo(hostinfo)
  280. }
  281. func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
  282. return c.pendingHostMap.QueryIndex(index)
  283. }
  284. func (c *HandshakeManager) EmitStats() {
  285. c.pendingHostMap.EmitStats("pending")
  286. c.mainHostMap.EmitStats("main")
  287. }
  288. // Utility functions below
  289. func generateIndex() (uint32, error) {
  290. b := make([]byte, 4)
  291. // Let zero mean we don't know the ID, so don't generate zero
  292. var index uint32
  293. for index == 0 {
  294. _, err := rand.Read(b)
  295. if err != nil {
  296. l.Errorln(err)
  297. return 0, err
  298. }
  299. index = binary.BigEndian.Uint32(b)
  300. }
  301. if l.Level >= logrus.DebugLevel {
  302. l.WithField("index", index).
  303. Debug("Generated index")
  304. }
  305. return index, nil
  306. }