浏览代码

Fix reconfig freeze attempting to send to an unbuffered, unread channel (#886)

* Fixes a reconfig freeze where the reconfig attempts to send to an unbuffered channel with no readers.
Only create stop channel when a DNS goroutine is created, and only send when the channel exists.
Buffer to size 1 so that the stop message can be immediately sent even if the goroutine is busy doing DNS lookups.
brad-defined 2 年之前
父节点
当前提交
96f4dcaab8
共有 3 个文件被更改,包括 30 次插入和 9 次删除
  1. 12 0
      lighthouse.go
  2. 12 2
      lighthouse_test.go
  3. 6 7
      remote_list.go

+ 12 - 0
lighthouse.go

@@ -262,6 +262,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 
 
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
 	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
 	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
+		// Clean up. Entries still in the static_host_map will be re-built.
+		// Entries no longer present must have their (possible) background DNS goroutines stopped.
+		if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
+			lh.RLock()
+			for staticVpnIp := range *existingStaticList {
+				if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil {
+					am.hr.Cancel()
+				}
+			}
+			lh.RUnlock()
+		}
+		// Build a new list based on current config.
 		staticList := make(map[iputil.VpnIp]struct{})
 		staticList := make(map[iputil.VpnIp]struct{})
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		if err != nil {
 		if err != nil {

+ 12 - 2
lighthouse_test.go

@@ -12,6 +12,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v2"
 )
 )
 
 
 //TODO: Add a test to ensure udpAddr is copied and not reused
 //TODO: Add a test to ensure udpAddr is copied and not reused
@@ -242,8 +243,17 @@ func TestLighthouse_reload(t *testing.T) {
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
-	lh.reload(c, false)
+	nc := map[interface{}]interface{}{
+		"static_host_map": map[interface{}]interface{}{
+			"10.128.0.2": []interface{}{"1.1.1.1:4242"},
+		},
+	}
+	rc, err := yaml.Marshal(nc)
+	assert.NoError(t, err)
+	c.ReloadConfigString(string(rc))
+
+	err = lh.reload(c, false)
+	assert.NoError(t, err)
 }
 }
 
 
 func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
 func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {

+ 6 - 7
remote_list.go

@@ -70,7 +70,7 @@ type hostnamesResults struct {
 	hostnames     []hostnamePort
 	hostnames     []hostnamePort
 	network       string
 	network       string
 	lookupTimeout time.Duration
 	lookupTimeout time.Duration
-	stop          chan struct{}
+	cancelFn      func()
 	l             *logrus.Logger
 	l             *logrus.Logger
 	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
 	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
 }
 }
@@ -80,7 +80,6 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 		hostnames:     make([]hostnamePort, len(hostPorts)),
 		hostnames:     make([]hostnamePort, len(hostPorts)),
 		network:       network,
 		network:       network,
 		lookupTimeout: timeout,
 		lookupTimeout: timeout,
-		stop:          make(chan (struct{})),
 		l:             l,
 		l:             l,
 	}
 	}
 
 
@@ -115,6 +114,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 
 
 	// Time for the DNS lookup goroutine
 	// Time for the DNS lookup goroutine
 	if performBackgroundLookup {
 	if performBackgroundLookup {
+		newCtx, cancel := context.WithCancel(ctx)
+		r.cancelFn = cancel
 		ticker := time.NewTicker(d)
 		ticker := time.NewTicker(d)
 		go func() {
 		go func() {
 			defer ticker.Stop()
 			defer ticker.Stop()
@@ -154,9 +155,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 					onUpdate()
 					onUpdate()
 				}
 				}
 				select {
 				select {
-				case <-ctx.Done():
-					return
-				case <-r.stop:
+				case <-newCtx.Done():
 					return
 					return
 				case <-ticker.C:
 				case <-ticker.C:
 					continue
 					continue
@@ -169,8 +168,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 }
 }
 
 
 func (hr *hostnamesResults) Cancel() {
 func (hr *hostnamesResults) Cancel() {
-	if hr != nil {
-		hr.stop <- struct{}{}
+	if hr != nil && hr.cancelFn != nil {
+		hr.cancelFn()
 	}
 	}
 }
 }