handshake_manager.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package nebula
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "fmt"
  6. "net"
  7. "time"
  8. "github.com/sirupsen/logrus"
  9. )
  const (
	// HandshakeTryInterval is the base unit of the outbound retry schedule.
	// Run ticks at this rate, and it is the first argument to both handshake
	// timer wheels.
	//
	// Retries back off linearly. A newly queued vpn IP waits
	// HandshakeTryInterval before its first attempt. After the k-th attempt the
	// next one is scheduled HandshakeTryInterval*k later. The entry is dropped
	// once HandshakeRetries attempts have been made.
	//
	// With 100ms and 20 retries, the ideal total is
	// 100ms + 100ms*(1+2+...+20) = 21.1s. Timer-wheel tick rounding may stretch
	// each wait, so observed totals can run somewhat longer.
	HandshakeTryInterval = time.Millisecond * 100

	// HandshakeRetries is the number of outbound handshake attempts made before a
	// pending entry is abandoned and removed.
	HandshakeRetries = 20

	// HandshakeWaitRotation is the number of handshake attempts made against the
	// best calculated remote before rotating through the other known addresses.
	HandshakeWaitRotation = 5
)
  18. type HandshakeManager struct {
  19. pendingHostMap *HostMap
  20. mainHostMap *HostMap
  21. lightHouse *LightHouse
  22. outside *udpConn
  23. OutboundHandshakeTimer *SystemTimerWheel
  24. InboundHandshakeTimer *SystemTimerWheel
  25. }
  26. func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn) *HandshakeManager {
  27. return &HandshakeManager{
  28. pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
  29. mainHostMap: mainHostMap,
  30. lightHouse: lightHouse,
  31. outside: outside,
  32. OutboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
  33. InboundHandshakeTimer: NewSystemTimerWheel(HandshakeTryInterval, HandshakeTryInterval*HandshakeRetries),
  34. }
  35. }
  36. func (c *HandshakeManager) Run(f EncWriter) {
  37. clockSource := time.Tick(HandshakeTryInterval)
  38. for now := range clockSource {
  39. c.NextOutboundHandshakeTimerTick(now, f)
  40. c.NextInboundHandshakeTimerTick(now)
  41. }
  42. }
  43. func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
  44. c.OutboundHandshakeTimer.advance(now)
  45. for {
  46. ep := c.OutboundHandshakeTimer.Purge()
  47. if ep == nil {
  48. break
  49. }
  50. vpnIP := ep.(uint32)
  51. index, err := c.pendingHostMap.GetIndexByVpnIP(vpnIP)
  52. if err != nil {
  53. continue
  54. }
  55. hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
  56. if err != nil {
  57. continue
  58. }
  59. // If we haven't finished the handshake and we haven't hit max retries, query
  60. // lighthouse and then send the handshake packet again.
  61. if hostinfo.HandshakeCounter < HandshakeRetries && !hostinfo.HandshakeComplete {
  62. if hostinfo.remote == nil {
  63. // We continue to query the lighthouse because hosts may
  64. // come online during handshake retries. If the query
  65. // succeeds (no error), add the lighthouse info to hostinfo
  66. ips, err := c.lightHouse.Query(vpnIP, f)
  67. if err == nil {
  68. for _, ip := range ips {
  69. hostinfo.AddRemote(ip)
  70. }
  71. hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
  72. }
  73. }
  74. hostinfo.HandshakeCounter++
  75. // We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
  76. // all the others until we can stand up a connection.
  77. if hostinfo.HandshakeCounter > HandshakeWaitRotation {
  78. hostinfo.rotateRemote()
  79. }
  80. // Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
  81. if hostinfo.HandshakeReady && hostinfo.remote != nil {
  82. err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
  83. if err != nil {
  84. l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
  85. WithField("initiatorIndex", hostinfo.localIndexId).
  86. WithField("remoteIndex", hostinfo.remoteIndexId).
  87. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  88. WithError(err).Error("Failed to send handshake message")
  89. } else {
  90. //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
  91. // keep the real packet struct around for logging purposes
  92. l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", hostinfo.remote).
  93. WithField("initiatorIndex", hostinfo.localIndexId).
  94. WithField("remoteIndex", hostinfo.remoteIndexId).
  95. WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
  96. Info("Handshake message sent")
  97. }
  98. }
  99. // Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
  100. //l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
  101. c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
  102. } else {
  103. c.pendingHostMap.DeleteVpnIP(vpnIP)
  104. c.pendingHostMap.DeleteIndex(index)
  105. }
  106. }
  107. }
  108. func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
  109. c.InboundHandshakeTimer.advance(now)
  110. for {
  111. ep := c.InboundHandshakeTimer.Purge()
  112. if ep == nil {
  113. break
  114. }
  115. index := ep.(uint32)
  116. vpnIP, err := c.pendingHostMap.GetVpnIPByIndex(index)
  117. if err != nil {
  118. continue
  119. }
  120. c.pendingHostMap.DeleteIndex(index)
  121. c.pendingHostMap.DeleteVpnIP(vpnIP)
  122. }
  123. }
  124. func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
  125. hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
  126. // We lock here and use an array to insert items to prevent locking the
  127. // main receive thread for very long by waiting to add items to the pending map
  128. c.OutboundHandshakeTimer.Add(vpnIP, HandshakeTryInterval)
  129. return hostinfo
  130. }
  131. func (c *HandshakeManager) DeleteVpnIP(vpnIP uint32) {
  132. //l.Debugln("Deleting pending vpn ip :", IntIp(vpnIP))
  133. c.pendingHostMap.DeleteVpnIP(vpnIP)
  134. }
  135. func (c *HandshakeManager) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
  136. hostinfo, err := c.pendingHostMap.AddIndex(index, ci)
  137. if err != nil {
  138. return nil, fmt.Errorf("Issue adding index: %d", index)
  139. }
  140. //c.mainHostMap.AddIndexHostInfo(index, hostinfo)
  141. c.InboundHandshakeTimer.Add(index, time.Second*10)
  142. return hostinfo, nil
  143. }
  144. func (c *HandshakeManager) AddIndexHostInfo(index uint32, h *HostInfo) {
  145. c.pendingHostMap.AddIndexHostInfo(index, h)
  146. }
  147. func (c *HandshakeManager) DeleteIndex(index uint32) {
  148. //l.Debugln("Deleting pending index :", index)
  149. c.pendingHostMap.DeleteIndex(index)
  150. }
  151. func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
  152. return c.pendingHostMap.QueryIndex(index)
  153. }
  154. func (c *HandshakeManager) EmitStats() {
  155. c.pendingHostMap.EmitStats("pending")
  156. c.mainHostMap.EmitStats("main")
  157. }
  158. // Utility functions below
  159. func generateIndex() (uint32, error) {
  160. b := make([]byte, 4)
  161. _, err := rand.Read(b)
  162. if err != nil {
  163. l.Errorln(err)
  164. return 0, err
  165. }
  166. index := binary.BigEndian.Uint32(b)
  167. if l.Level >= logrus.DebugLevel {
  168. l.WithField("index", index).
  169. Debug("Generated index")
  170. }
  171. return index, nil
  172. }