Procházet zdrojové kódy

Limit how often a busy tunnel can requery the lighthouse (#940)

Co-authored-by: Wade Simmons <[email protected]>
Nate Brown před 1 rokem
rodič
revize
223cc6e660
4 změnil soubory, kde provedl 62 přidání a 4 odebrání
  1. 10 0
      config/config.go
  2. 15 4
      hostmap.go
  3. 33 0
      interface.go
  4. 4 0
      main.go

+ 10 - 0
config/config.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"math"
 	"os"
 	"os/signal"
 	"path/filepath"
@@ -236,6 +237,15 @@ func (c *C) GetInt(k string, d int) int {
 	return v
 }
 
+// GetUint32 reads k through GetInt and returns it as a uint32. The default d is
+// returned when k is not found or invalid, or when the configured value does
+// not fit in a uint32.
+func (c *C) GetUint32(k string, d uint32) uint32 {
+	v := c.GetInt(k, int(d))
+	// Negative values and values above math.MaxUint32 are not representable
+	if v < 0 || int64(v) > math.MaxUint32 {
+		return d
+	}
+	return uint32(v)
+}
+
 // GetBool will get the bool for k or return the default d if not found or invalid
 func (c *C) GetBool(k string, d bool) bool {
 	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))

+ 15 - 4
hostmap.go

@@ -17,8 +17,9 @@ import (
 )
 
 // const ProbeLen = 100
-const PromoteEvery = 1000
-const ReQueryEvery = 5000
+const defaultPromoteEvery = 1000       // Count of packets sent on a tunnel between attempts to move it to a preferred underlay ip address
+const defaultReQueryEvery = 5000       // Count of packets sent on a tunnel between re-queries of its hostinfo to the lighthouse
+const defaultReQueryWait = time.Minute // Minimum duration to wait between lighthouse re-queries for a hostinfo. Only evaluated every ReQueryEvery packets
 const MaxRemotes = 10
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
@@ -215,6 +216,10 @@ type HostInfo struct {
 	remoteCidr           *cidr.Tree4
 	relayState           RelayState
 
+	// nextLHQuery is the earliest we can ask the lighthouse for new information.
+	// This is used to limit lighthouse re-queries in chatty clients
+	nextLHQuery atomic.Int64
+
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// with a handshake
@@ -535,7 +540,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
-	if c%PromoteEvery == 0 {
+	if c%ifce.tryPromoteEvery.Load() == 0 {
 		// The lock here is currently protecting i.remote access
 		i.RLock()
 		remote := i.remote
@@ -563,7 +568,13 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 	}
 
 	// Re query our lighthouses for new remotes occasionally
-	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
+	if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil {
+		now := time.Now().UnixNano()
+		if now < i.nextLHQuery.Load() {
+			return
+		}
+
+		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
 		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
 	}
 }

+ 33 - 0
interface.go

@@ -46,6 +46,10 @@ type InterfaceConfig struct {
 	relayManager            *relayManager
 	punchy                  *Punchy
 
+	tryPromoteEvery uint32
+	reQueryEvery    uint32
+	reQueryWait     time.Duration
+
 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
 }
@@ -72,6 +76,10 @@ type Interface struct {
 	closed             atomic.Bool
 	relayManager       *relayManager
 
+	tryPromoteEvery atomic.Uint32
+	reQueryEvery    atomic.Uint32
+	reQueryWait     atomic.Int64
+
 	sendRecvErrorConfig sendRecvErrorConfig
 
 	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
@@ -186,6 +194,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
+	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
+	ifce.reQueryEvery.Store(c.reQueryEvery)
+	ifce.reQueryWait.Store(int64(c.reQueryWait))
+
 	ifce.certState.Store(c.certState)
 	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
 
@@ -287,6 +299,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
+	c.RegisterReloadCallback(f.reloadMisc)
 	for _, udpConn := range f.writers {
 		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
@@ -389,6 +402,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
 	}
 }
 
+// reloadMisc is a config reload callback that applies changes to the tunnel
+// promotion counter, the lighthouse re-query counter, and the lighthouse
+// re-query wait duration.
+//
+// Both counters are used as modulo divisors in HostInfo.TryPromoteBest, so a
+// value of 0 would cause an integer divide by zero panic there. A 0 is treated
+// as invalid and replaced with the built in default.
+func (f *Interface) reloadMisc(c *config.C) {
+	if c.HasChanged("counters.try_promote") {
+		n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
+		if n == 0 {
+			// Never store a zero divisor
+			f.l.Warn("counters.try_promote must be greater than 0, using the default")
+			n = defaultPromoteEvery
+		}
+		f.tryPromoteEvery.Store(n)
+		f.l.Info("counters.try_promote has changed")
+	}
+
+	if c.HasChanged("counters.requery_every_packets") {
+		n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
+		if n == 0 {
+			// Never store a zero divisor
+			f.l.Warn("counters.requery_every_packets must be greater than 0, using the default")
+			n = defaultReQueryEvery
+		}
+		f.reQueryEvery.Store(n)
+		f.l.Info("counters.requery_every_packets has changed")
+	}
+
+	if c.HasChanged("timers.requery_wait_duration") {
+		n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
+		// Stored as nanoseconds so it can be added to a UnixNano timestamp
+		f.reQueryWait.Store(int64(n))
+		f.l.Info("timers.requery_wait_duration has changed")
+	}
+}
+
 func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	ticker := time.NewTicker(i)
 	defer ticker.Stop()

+ 4 - 0
main.go

@@ -261,6 +261,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
 	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
+
 	ifConfig := &InterfaceConfig{
 		HostMap:                 hostMap,
 		Inside:                  tun,
@@ -273,6 +274,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		lightHouse:              lightHouse,
 		checkInterval:           time.Second * time.Duration(checkInterval),
 		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
+		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
+		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
+		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,