3
0

handshake_manager.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. package nebula
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "encoding/binary"
  6. "errors"
  7. "net"
  8. "time"
  9. "github.com/sirupsen/logrus"
  10. )
  // Default handshake tuning values.
  //
  // The delay before each outbound retry is the try interval multiplied by the
  // number of attempts made so far, so the total time spent on a handshake is
  // the sum of those growing delays. With a 100ms interval and 20 retries the
  // original authors quote a total of 23.5 seconds.
  const (
  	// DefaultHandshakeTryInterval is the base unit of time between handshake attempts.
  	DefaultHandshakeTryInterval = 100 * time.Millisecond

  	// DefaultHandshakeRetries is the number of handshake attempts made before giving up.
  	DefaultHandshakeRetries = 20

  	// DefaultHandshakeWaitRotation is the number of handshake attempts to make
  	// before starting to use the host's other IP addresses.
  	DefaultHandshakeWaitRotation = 5

  	// DefaultHandshakeTriggerBuffer is the default capacity of the channel used
  	// to trigger outbound handshakes.
  	DefaultHandshakeTriggerBuffer = 64
  )
  20. var (
  21. defaultHandshakeConfig = HandshakeConfig{
  22. tryInterval: DefaultHandshakeTryInterval,
  23. retries: DefaultHandshakeRetries,
  24. waitRotation: DefaultHandshakeWaitRotation,
  25. triggerBuffer: DefaultHandshakeTriggerBuffer,
  26. }
  27. )
  28. type HandshakeConfig struct {
  29. tryInterval time.Duration
  30. retries int
  31. waitRotation int
  32. triggerBuffer int
  33. messageMetrics *MessageMetrics
  34. }
  35. type HandshakeManager struct {
  36. pendingHostMap *HostMap
  37. mainHostMap *HostMap
  38. lightHouse *LightHouse
  39. outside *udpConn
  40. config HandshakeConfig
  41. // can be used to trigger outbound handshake for the given vpnIP
  42. trigger chan uint32
  43. OutboundHandshakeTimer *SystemTimerWheel
  44. InboundHandshakeTimer *SystemTimerWheel
  45. messageMetrics *MessageMetrics
  46. l *logrus.Logger
  47. }
  48. func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
  49. return &HandshakeManager{
  50. pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
  51. mainHostMap: mainHostMap,
  52. lightHouse: lightHouse,
  53. outside: outside,
  54. config: config,
  55. trigger: make(chan uint32, config.triggerBuffer),
  56. OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
  57. InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
  58. messageMetrics: config.messageMetrics,
  59. l: l,
  60. }
  61. }
  62. func (c *HandshakeManager) Run(f EncWriter) {
  63. clockSource := time.Tick(c.config.tryInterval)
  64. for {
  65. select {
  66. case vpnIP := <-c.trigger:
  67. c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
  68. c.handleOutbound(vpnIP, f, true)
  69. case now := <-clockSource:
  70. c.NextOutboundHandshakeTimerTick(now, f)
  71. c.NextInboundHandshakeTimerTick(now)
  72. }
  73. }
  74. }
  75. func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
  76. c.OutboundHandshakeTimer.advance(now)
  77. for {
  78. ep := c.OutboundHandshakeTimer.Purge()
  79. if ep == nil {
  80. break
  81. }
  82. vpnIP := ep.(uint32)
  83. c.handleOutbound(vpnIP, f, false)
  84. }
  85. }
  86. func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
  87. hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
  88. if err != nil {
  89. return
  90. }
  91. hostinfo.Lock()
  92. defer hostinfo.Unlock()
  93. // If we haven't finished the handshake and we haven't hit max retries, query
  94. // lighthouse and then send the handshake packet again.
  95. if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
  96. if hostinfo.remote == nil {
  97. // We continue to query the lighthouse because hosts may
  98. // come online during handshake retries. If the query
  99. // succeeds (no error), add the lighthouse info to hostinfo
  100. ips := c.lightHouse.QueryCache(vpnIP)
  101. // If we have no responses yet, or only one IP (the host hadn't
  102. // finished reporting its own IPs yet), then send another query to
  103. // the LH.
  104. if len(ips) <= 1 {
  105. ips, err = c.lightHouse.Query(vpnIP, f)
  106. }
  107. if err == nil {
  108. for _, ip := range ips {
  109. hostinfo.AddRemote(ip)
  110. }
  111. hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
  112. }
  113. } else if lighthouseTriggered {
  114. // We were triggered by a lighthouse HostQueryReply packet, but
  115. // we have already picked a remote for this host (this can happen
  116. // if we are configured with multiple lighthouses). So we can skip
  117. // this trigger and let the timerwheel handle the rest of the
  118. // process
  119. return
  120. }
  121. hostinfo.HandshakeCounter++
  122. // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
  123. // all the others until we can stand up a connection.
  124. if hostinfo.HandshakeCounter > c.config.waitRotation {
  125. hostinfo.rotateRemote()
  126. }
  127. // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
  128. if hostinfo.HandshakeReady && hostinfo.remote != nil {
  129. c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
  130. err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
  131. if err != nil {
  132. hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
  133. WithField("initiatorIndex", hostinfo.localIndexId).
  134. WithField("remoteIndex", hostinfo.remoteIndexId).
  135. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  136. WithError(err).Error("Failed to send handshake message")
  137. } else {
  138. //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
  139. // keep the real packet struct around for logging purposes
  140. hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
  141. WithField("initiatorIndex", hostinfo.localIndexId).
  142. WithField("remoteIndex", hostinfo.remoteIndexId).
  143. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  144. Info("Handshake message sent")
  145. }
  146. }
  147. // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
  148. if !lighthouseTriggered {
  149. //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
  150. c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
  151. }
  152. } else {
  153. c.pendingHostMap.DeleteHostInfo(hostinfo)
  154. }
  155. }
  156. func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
  157. c.InboundHandshakeTimer.advance(now)
  158. for {
  159. ep := c.InboundHandshakeTimer.Purge()
  160. if ep == nil {
  161. break
  162. }
  163. index := ep.(uint32)
  164. c.pendingHostMap.DeleteIndex(index)
  165. }
  166. }
  167. func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
  168. hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
  169. // We lock here and use an array to insert items to prevent locking the
  170. // main receive thread for very long by waiting to add items to the pending map
  171. c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
  172. return hostinfo
  173. }
  // Sentinel errors returned by CheckAndComplete.
  var (
  	// ErrExistingHostInfo reports that the main host map already has an entry
  	// for the host and overwriting was not requested.
  	ErrExistingHostInfo = errors.New("existing hostinfo")

  	// ErrAlreadySeen reports that the main host map already has an entry that
  	// saw the exact same handshake packet.
  	ErrAlreadySeen = errors.New("already seen")

  	// ErrLocalIndexCollision reports that the local index is already in use
  	// in the main or pending host map.
  	ErrLocalIndexCollision = errors.New("local index collision")
  )
  179. // CheckAndComplete checks for any conflicts in the main and pending hostmap
  180. // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
  181. // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
  182. // exact same handshake packet
  183. //
  184. // ErrExistingHostInfo if we already have an entry in the hostmap for this
  185. // VpnIP and overwrite was false.
  186. //
  187. // ErrLocalIndexCollision if we already have an entry in the main or pending
  188. // hostmap for the hostinfo.localIndexId.
  189. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
  190. c.pendingHostMap.RLock()
  191. defer c.pendingHostMap.RUnlock()
  192. c.mainHostMap.Lock()
  193. defer c.mainHostMap.Unlock()
  194. existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
  195. if found && existingHostInfo != nil {
  196. if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
  197. return existingHostInfo, ErrAlreadySeen
  198. }
  199. if !overwrite {
  200. return existingHostInfo, ErrExistingHostInfo
  201. }
  202. }
  203. existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
  204. if found {
  205. // We have a collision, but for a different hostinfo
  206. return existingIndex, ErrLocalIndexCollision
  207. }
  208. existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
  209. if found && existingIndex != hostinfo {
  210. // We have a collision, but for a different hostinfo
  211. return existingIndex, ErrLocalIndexCollision
  212. }
  213. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  214. if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
  215. // We have a collision, but this can happen since we can't control
  216. // the remote ID. Just log about the situation as a note.
  217. hostinfo.logger(c.l).
  218. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
  219. Info("New host shadows existing host remoteIndex")
  220. }
  221. if existingHostInfo != nil {
  222. // We are going to overwrite this entry, so remove the old references
  223. delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
  224. delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
  225. delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
  226. }
  227. c.mainHostMap.addHostInfo(hostinfo, f)
  228. return existingHostInfo, nil
  229. }
  230. // Complete is a simpler version of CheckAndComplete when we already know we
  231. // won't have a localIndexId collision because we already have an entry in the
  232. // pendingHostMap
  233. func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
  234. c.mainHostMap.Lock()
  235. defer c.mainHostMap.Unlock()
  236. existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
  237. if found && existingHostInfo != nil {
  238. // We are going to overwrite this entry, so remove the old references
  239. delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
  240. delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
  241. delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
  242. }
  243. existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
  244. if found && existingRemoteIndex != nil {
  245. // We have a collision, but this can happen since we can't control
  246. // the remote ID. Just log about the situation as a note.
  247. hostinfo.logger(c.l).
  248. WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
  249. Info("New host shadows existing host remoteIndex")
  250. }
  251. c.mainHostMap.addHostInfo(hostinfo, f)
  252. }
  253. // AddIndexHostInfo generates a unique localIndexId for this HostInfo
  254. // and adds it to the pendingHostMap. Will error if we are unable to generate
  255. // a unique localIndexId
  256. func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
  257. c.pendingHostMap.Lock()
  258. defer c.pendingHostMap.Unlock()
  259. c.mainHostMap.RLock()
  260. defer c.mainHostMap.RUnlock()
  261. for i := 0; i < 32; i++ {
  262. index, err := generateIndex(c.l)
  263. if err != nil {
  264. return err
  265. }
  266. _, inPending := c.pendingHostMap.Indexes[index]
  267. _, inMain := c.mainHostMap.Indexes[index]
  268. if !inMain && !inPending {
  269. h.localIndexId = index
  270. c.pendingHostMap.Indexes[index] = h
  271. return nil
  272. }
  273. }
  274. return errors.New("failed to generate unique localIndexId")
  275. }
  276. func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
  277. c.pendingHostMap.addRemoteIndexHostInfo(index, h)
  278. }
  279. func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
  280. //l.Debugln("Deleting pending hostinfo :", hostinfo)
  281. c.pendingHostMap.DeleteHostInfo(hostinfo)
  282. }
  283. func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
  284. return c.pendingHostMap.QueryIndex(index)
  285. }
  286. func (c *HandshakeManager) EmitStats() {
  287. c.pendingHostMap.EmitStats("pending")
  288. c.mainHostMap.EmitStats("main")
  289. }
  290. // Utility functions below
  291. func generateIndex(l *logrus.Logger) (uint32, error) {
  292. b := make([]byte, 4)
  293. // Let zero mean we don't know the ID, so don't generate zero
  294. var index uint32
  295. for index == 0 {
  296. _, err := rand.Read(b)
  297. if err != nil {
  298. l.Errorln(err)
  299. return 0, err
  300. }
  301. index = binary.BigEndian.Uint32(b)
  302. }
  303. if l.Level >= logrus.DebugLevel {
  304. l.WithField("index", index).
  305. Debug("Generated index")
  306. }
  307. return index, nil
  308. }