Quellcode durchsuchen

store lighthouses as a slice (#1473)

* store lighthouses as a slice. If you have fewer than 16 lighthouses (and fewer than 16 vpnaddrs on a host, I guess), it's faster
Jack Doan vor 1 Tag
Ursprung
Commit
8196c22b5a
2 geänderte Dateien mit 31 neuen und 25 gelöschten Zeilen
  1. 1 1
      connection_manager_test.go
  2. 30 24
      lighthouse.go

+ 1 - 1
connection_manager_test.go

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
 		addrMap:   map[netip.Addr]*RemoteList{},
 		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := map[netip.Addr]struct{}{}
+	lighthouses := []netip.Addr{}
 	staticList := map[netip.Addr]struct{}{}
 
 	lh.lighthouses.Store(&lighthouses)

+ 30 - 24
lighthouse.go

@@ -57,7 +57,7 @@ type LightHouse struct {
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
 	staticList  atomic.Pointer[map[netip.Addr]struct{}]
-	lighthouses atomic.Pointer[map[netip.Addr]struct{}]
+	lighthouses atomic.Pointer[[]netip.Addr]
 
 	interval     atomic.Int64
 	updateCancel context.CancelFunc
@@ -108,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		queryChan:          make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
 		l:                  l,
 	}
-	lighthouses := make(map[netip.Addr]struct{})
+	lighthouses := make([]netip.Addr, 0)
 	h.lighthouses.Store(&lighthouses)
 	staticList := make(map[netip.Addr]struct{})
 	h.staticList.Store(&staticList)
@@ -144,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
 	return *lh.staticList.Load()
 }
 
-func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
+func (lh *LightHouse) GetLighthouses() []netip.Addr {
 	return *lh.lighthouses.Load()
 }
 
@@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
-		lhMap := make(map[netip.Addr]struct{})
-		err := lh.parseLighthouses(c, lhMap)
+		lhList, err := lh.parseLighthouses(c)
 		if err != nil {
 			return err
 		}
 
-		lh.lighthouses.Store(&lhMap)
+		lh.lighthouses.Store(&lhList)
 		if !initial {
 			//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
 			lh.l.Info("lighthouse.hosts has changed")
@@ -347,36 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	if lh.amLighthouse && len(lhs) != 0 {
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
+	out := make([]netip.Addr, len(lhs))
 
 	for i, host := range lhs {
 		addr, err := netip.ParseAddr(host)
 		if err != nil {
-			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
+			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
 
 		if !lh.myVpnNetworksTable.Contains(addr) {
-			return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
+			return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
 		}
-		lhMap[addr] = struct{}{}
+		out[i] = addr
 	}
 
-	if !lh.amLighthouse && len(lhMap) == 0 {
+	if !lh.amLighthouse && len(out) == 0 {
 		lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
 	}
 
 	staticList := lh.GetStaticHostList()
-	for lhAddr, _ := range lhMap {
-		if _, ok := staticList[lhAddr]; !ok {
-			return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
+	for i := range out {
+		if _, ok := staticList[out[i]]; !ok {
+			return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
 		}
 	}
 
-	return nil
+	return out, nil
 }
 
 func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -711,15 +711,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 }
 
 func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
-	_, ok := lh.GetLighthouses()[vpnAddr]
-	return ok
+	l := lh.GetLighthouses()
+	for i := range l {
+		if l[i] == vpnAddr {
+			return true
+		}
+	}
+	return false
 }
 
-func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
+func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
 	l := lh.GetLighthouses()
-	for _, a := range vpnAddr {
-		if _, ok := l[a]; ok {
-			return true
+	for i := range vpnAddrs {
+		for j := range l {
+			if l[j] == vpnAddrs[i] {
+				return true
+			}
 		}
 	}
 	return false
@@ -761,7 +768,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
 	queried := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
 			v = hi.ConnectionState.myCert.Version()
@@ -879,7 +886,7 @@ func (lh *LightHouse) SendUpdate() {
 	updated := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		var v cert.Version
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
@@ -1289,7 +1296,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
 	//It's possible the lighthouse is communicating with us using a non primary vpn addr,
 	//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
-	//maybe one day we'll have a better idea, if it matters.
 	if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
 		return
 	}