Przeglądaj źródła

Support reloading preferred_ranges (#1043)

Nate Brown 1 rok temu
rodzic
commit
a390125935
11 zmienionych plików z 110 dodań i 84 usunięć
  1. 1 1
      connection_manager.go
  2. 8 3
      connection_manager_test.go
  3. 2 2
      control.go
  4. 3 1
      control_test.go
  5. 1 1
      handshake_ix.go
  6. 5 5
      handshake_manager.go
  7. 3 1
      handshake_manager_test.go
  8. 52 19
      hostmap.go
  9. 33 4
      hostmap_test.go
  10. 1 46
      main.go
  11. 1 1
      ssh.go

+ 1 - 1
connection_manager.go

@@ -457,7 +457,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 	}
 
 	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
+		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
 			n.metricsTxPunchy.Inc(1)
 			n.intf.outside.WriteTo([]byte{1}, addr)
 		})

+ 8 - 3
connection_manager_test.go

@@ -43,7 +43,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, vpncidr, preferredRanges)
+	hostMap := newHostMap(l, vpncidr)
+	hostMap.preferredRanges.Store(&preferredRanges)
+
 	cs := &CertState{
 		RawCertificate:      []byte{},
 		PrivateKey:          []byte{},
@@ -123,7 +125,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, vpncidr, preferredRanges)
+	hostMap := newHostMap(l, vpncidr)
+	hostMap.preferredRanges.Store(&preferredRanges)
+
 	cs := &CertState{
 		RawCertificate:      []byte{},
 		PrivateKey:          []byte{},
@@ -210,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
-	hostMap := NewHostMap(l, vpncidr, preferredRanges)
+	hostMap := newHostMap(l, vpncidr)
+	hostMap.preferredRanges.Store(&preferredRanges)
 
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)

+ 2 - 2
control.go

@@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 		return nil
 	}
 
-	ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
+	ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
 	return &ch
 }
 
@@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
 	}
 
 	hostInfo.SetRemote(addr.Copy())
-	ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
+	ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
 	return &ch
 }
 

+ 3 - 1
control_test.go

@@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
+	hm := newHostMap(l, &net.IPNet{})
+	hm.preferredRanges.Store(&[]*net.IPNet{})
+
 	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
 	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{

+ 1 - 1
handshake_ix.go

@@ -406,7 +406,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
-				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 
 			// Swap the packet store to benefit the original intended recipient

+ 5 - 5
handshake_manager.go

@@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	hostinfo := hh.hostinfo
 	// If we are out of time, clean up
 	if hh.counter >= hm.config.retries {
-		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
+		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
 			WithField("initiatorIndex", hh.hostinfo.localIndexId).
 			WithField("remoteIndex", hh.hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
+	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
 	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
@@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
 		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
@@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		hm.mainHostMap.RUnlock()
 		// Do not attempt promotion if you are a lighthouse
 		if !hm.lightHouse.amLighthouse {
-			h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
+			h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
 		}
 		return h, true
 	}
@@ -599,7 +599,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 }
 
 func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
-	return c.mainHostMap.preferredRanges
+	return c.mainHostMap.GetPreferredRanges()
 }
 
 func (c *HandshakeManager) ForEachVpnIp(f controlEach) {

+ 3 - 1
handshake_manager_test.go

@@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap(l, vpncidr, preferredRanges)
+	mainHM := newHostMap(l, vpncidr)
+	mainHM.preferredRanges.Store(&preferredRanges)
+
 	lh := newTestLighthouse()
 
 	cs := &CertState{

+ 52 - 19
hostmap.go

@@ -11,6 +11,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
@@ -57,9 +58,8 @@ type HostMap struct {
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
 	Hosts           map[iputil.VpnIp]*HostInfo
-	preferredRanges []*net.IPNet
+	preferredRanges atomic.Pointer[[]*net.IPNet]
 	vpnCIDR         *net.IPNet
-	metricsEnabled  bool
 	l               *logrus.Logger
 }
 
@@ -254,21 +254,53 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
-	h := map[iputil.VpnIp]*HostInfo{}
-	i := map[uint32]*HostInfo{}
-	r := map[uint32]*HostInfo{}
-	relays := map[uint32]*HostInfo{}
-	m := HostMap{
-		Indexes:         i,
-		Relays:          relays,
-		RemoteIndexes:   r,
-		Hosts:           h,
-		preferredRanges: preferredRanges,
-		vpnCIDR:         vpnCIDR,
-		l:               l,
+func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
+	hm := newHostMap(l, vpnCIDR)
+
+	hm.reload(c, true)
+	c.RegisterReloadCallback(func(c *config.C) {
+		hm.reload(c, false)
+	})
+
+	l.WithField("network", hm.vpnCIDR.String()).
+		WithField("preferredRanges", hm.GetPreferredRanges()).
+		Info("Main HostMap created")
+
+	return hm
+}
+
+func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
+	return &HostMap{
+		Indexes:       map[uint32]*HostInfo{},
+		Relays:        map[uint32]*HostInfo{},
+		RemoteIndexes: map[uint32]*HostInfo{},
+		Hosts:         map[iputil.VpnIp]*HostInfo{},
+		vpnCIDR:       vpnCIDR,
+		l:             l,
+	}
+}
+
+func (hm *HostMap) reload(c *config.C, initial bool) {
+	if initial || c.HasChanged("preferred_ranges") {
+		var preferredRanges []*net.IPNet
+		rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
+
+		for _, rawPreferredRange := range rawPreferredRanges {
+			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
+
+			if err != nil {
+				hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
+				continue
+			}
+
+			preferredRanges = append(preferredRanges, preferredRange)
+		}
+
+		oldRanges := hm.preferredRanges.Swap(&preferredRanges)
+		if !initial {
+			hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
+		}
 	}
-	return &m
 }
 
 // EmitStats reports host, index, and relay counts to the stats collection system
@@ -457,7 +489,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
 		hm.RUnlock()
 		// Do not attempt promotion if you are a lighthouse
 		if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
-			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
+			h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
 		}
 		return h
 
@@ -504,7 +536,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 }
 
 func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
-	return hm.preferredRanges
+	//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
+	return *hm.preferredRanges.Load()
 }
 
 func (hm *HostMap) ForEachVpnIp(f controlEach) {
@@ -596,7 +629,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	// NOTE: We do this loop here instead of calling `isPreferred` in
 	// remote_list.go so that we only have to loop over preferredRanges once.
 	newIsPreferred := false
-	for _, l := range hm.preferredRanges {
+	for _, l := range hm.GetPreferredRanges() {
 		// return early if we are already on a preferred remote
 		if l.Contains(currentRemote.IP) {
 			return false

+ 33 - 4
hostmap_test.go

@@ -4,19 +4,19 @@ import (
 	"net"
 	"testing"
 
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
-	hm := NewHostMap(
+	hm := newHostMap(
 		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
 		},
-		[]*net.IPNet{},
 	)
 
 	f := &Interface{}
@@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
-	hm := NewHostMap(
+	hm := newHostMap(
 		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
 		},
-		[]*net.IPNet{},
 	)
 
 	f := &Interface{}
@@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	prim = hm.QueryVpnIp(1)
 	assert.Nil(t, prim)
 }
+
+func TestHostMap_reload(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+
+	hm := NewHostMapFromConfig(
+		l,
+		&net.IPNet{
+			IP:   net.IP{10, 0, 0, 1},
+			Mask: net.IPMask{255, 255, 255, 0},
+		},
+		c,
+	)
+
+	toS := func(ipn []*net.IPNet) []string {
+		var s []string
+		for _, n := range ipn {
+			s = append(s, n.String())
+		}
+		return s
+	}
+
+	assert.Empty(t, hm.GetPreferredRanges())
+
+	c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
+	assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
+
+	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
+	assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
+}

+ 1 - 46
main.go

@@ -183,52 +183,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	// Set up my internal host map
-	var preferredRanges []*net.IPNet
-	rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
-	// First, check if 'preferred_ranges' is set and fallback to 'local_range'
-	if len(rawPreferredRanges) > 0 {
-		for _, rawPreferredRange := range rawPreferredRanges {
-			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
-			if err != nil {
-				return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
-			}
-			preferredRanges = append(preferredRanges, preferredRange)
-		}
-	}
-
-	// local_range was superseded by preferred_ranges. If it is still present,
-	// merge the local_range setting into preferred_ranges. We will probably
-	// deprecate local_range and remove in the future.
-	rawLocalRange := c.GetString("local_range", "")
-	if rawLocalRange != "" {
-		_, localRange, err := net.ParseCIDR(rawLocalRange)
-		if err != nil {
-			return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
-		}
-
-		// Check if the entry for local_range was already specified in
-		// preferred_ranges. Don't put it into the slice twice if so.
-		var found bool
-		for _, r := range preferredRanges {
-			if r.String() == localRange.String() {
-				found = true
-				break
-			}
-		}
-		if !found {
-			preferredRanges = append(preferredRanges, localRange)
-		}
-	}
-
-	hostMap := NewHostMap(l, tunCidr, preferredRanges)
-	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
-
-	l.
-		WithField("network", hostMap.vpnCIDR.String()).
-		WithField("preferredRanges", hostMap.preferredRanges).
-		Info("Main HostMap created")
-
+	hostMap := NewHostMapFromConfig(l, tunCidr, c)
 	punchy := NewPunchyFromConfig(l, c)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
 	if err != nil {

+ 1 - 1
ssh.go

@@ -939,7 +939,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		enc.SetIndent("", "    ")
 	}
 
-	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
+	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
 }
 
 func sshReload(c *config.C, w sshd.StringWriter) error {