3
0
Эх сурвалжийг харах

Lighthouse reload support (#649)

Co-authored-by: John Maguire <[email protected]>
Nate Brown 3 жил өмнө
parent
commit
312a01dc09
13 өөрчлөгдсөн 468 нэмэгдсэн , 219 устгасан
  1. 26 0
      config/config.go
  2. 3 3
      connection_manager_test.go
  3. 2 1
      control.go
  4. 1 1
      handshake.go
  5. 2 2
      handshake_ix.go
  6. 11 2
      handshake_manager_test.go
  7. 1 1
      inside.go
  8. 236 60
      lighthouse.go
  9. 65 30
      lighthouse_test.go
  10. 10 87
      main.go
  11. 1 1
      outside.go
  12. 72 17
      punchy.go
  13. 38 14
      punchy_test.go

+ 26 - 0
config/config.go

@@ -11,6 +11,7 @@ import (
 	"sort"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 
@@ -26,6 +27,7 @@ type C struct {
 	oldSettings map[interface{}]interface{}
 	callbacks   []func(*C)
 	l           *logrus.Logger
+	reloadLock  sync.Mutex
 }
 
 func NewC(l *logrus.Logger) *C {
@@ -133,6 +135,9 @@ func (c *C) CatchHUP(ctx context.Context) {
 }
 
 func (c *C) ReloadConfig() {
+	c.reloadLock.Lock()
+	defer c.reloadLock.Unlock()
+
 	c.oldSettings = make(map[interface{}]interface{})
 	for k, v := range c.Settings {
 		c.oldSettings[k] = v
@@ -149,6 +154,27 @@ func (c *C) ReloadConfig() {
 	}
 }
 
+func (c *C) ReloadConfigString(raw string) error {
+	c.reloadLock.Lock()
+	defer c.reloadLock.Unlock()
+
+	c.oldSettings = make(map[interface{}]interface{})
+	for k, v := range c.Settings {
+		c.oldSettings[k] = v
+	}
+
+	err := c.LoadString(raw)
+	if err != nil {
+		return err
+	}
+
+	for _, v := range c.callbacks {
+		v(c)
+	}
+
+	return nil
+}
+
 // GetString will get the string for k or return the default d if not found or invalid
 func (c *C) GetString(k, d string) string {
 	r := c.Get(k)

+ 3 - 3
connection_manager_test.go

@@ -35,7 +35,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
+	lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})}
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
@@ -104,7 +104,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
+	lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})}
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
@@ -213,7 +213,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
+	lh := &LightHouse{l: l, atomicStaticList: make(map[iputil.VpnIp]struct{}), atomicLighthouses: make(map[iputil.VpnIp]struct{})}
 	ifce := &Interface{
 		hostMap:           hostMap,
 		inside:            &test.NoopTun{},

+ 2 - 1
control.go

@@ -160,9 +160,10 @@ func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
 	c.f.hostMap.Lock()
+	lighthouses := c.f.lightHouse.GetLighthouses()
 	for _, h := range c.f.hostMap.Hosts {
 		if excludeLighthouses {
-			if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
+			if _, ok := lighthouses[h.vpnIp]; ok {
 				continue
 			}
 		}

+ 1 - 1
handshake.go

@@ -7,7 +7,7 @@ import (
 
 func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
 	// First remote allow list check before we know the vpnIp
-	if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
+	if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
 		f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return
 	}

+ 2 - 2
handshake_ix.go

@@ -114,7 +114,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H)
 		return
 	}
 
-	if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
+	if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
 		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return
 	}
@@ -321,7 +321,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
-	if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
+	if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
 		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return false
 	}

+ 11 - 2
handshake_manager_test.go

@@ -21,8 +21,12 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
+	lh := &LightHouse{
+		atomicStaticList:  make(map[iputil.VpnIp]struct{}),
+		atomicLighthouses: make(map[iputil.VpnIp]struct{}),
+	}
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -74,7 +78,12 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-	lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
+	lh := &LightHouse{
+		addrMap:           make(map[iputil.VpnIp]*RemoteList),
+		l:                 l,
+		atomicStaticList:  make(map[iputil.VpnIp]struct{}),
+		atomicLighthouses: make(map[iputil.VpnIp]struct{}),
+	}
 
 	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
 

+ 1 - 1
inside.go

@@ -110,7 +110,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 
 		// If this is a static host, we don't need to wait for the HostQueryReply
 		// We can trigger the handshake right now
-		if _, ok := f.lightHouse.staticList[vpnIp]; ok {
+		if _, ok := f.lightHouse.GetStaticHostList()[vpnIp]; ok {
 			select {
 			case f.handshakeManager.trigger <- vpnIp:
 			default:

+ 236 - 60
lighthouse.go

@@ -7,14 +7,18 @@ import (
 	"fmt"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
+	"unsafe"
 
 	"github.com/golang/protobuf/proto"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 )
 
 //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
@@ -28,7 +32,9 @@ type LightHouse struct {
 	amLighthouse bool
 	myVpnIp      iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
+	myVpnNet     *net.IPNet
 	punchConn    *udp.Conn
+	punchy       *Punchy
 
 	// Local cache of answers from light houses
 	// map of vpn Ip to answers
@@ -39,80 +45,240 @@ type LightHouse struct {
 	// respond with.
 	// - When we are not a lighthouse, this filters which addresses we accept
 	// from lighthouses.
-	remoteAllowList *RemoteAllowList
+	atomicRemoteAllowList *RemoteAllowList
 
 	// filters local addresses that we advertise to lighthouses
-	localAllowList *LocalAllowList
+	atomicLocalAllowList *LocalAllowList
 
 	// used to trigger the HandshakeManager when we receive HostQueryReply
 	handshakeTrigger chan<- iputil.VpnIp
 
-	// staticList exists to avoid having a bool in each addrMap entry
+	// atomicStaticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
-	staticList  map[iputil.VpnIp]struct{}
-	lighthouses map[iputil.VpnIp]struct{}
-	interval    int
-	nebulaPort  uint32 // 32 bits because protobuf does not have a uint16
-	punchBack   bool
-	punchDelay  time.Duration
+	atomicStaticList  map[iputil.VpnIp]struct{}
+	atomicLighthouses map[iputil.VpnIp]struct{}
+
+	atomicInterval  int64
+	updateCancel    context.CancelFunc
+	updateParentCtx context.Context
+	updateUdp       udp.EncWriter
+	nebulaPort      uint32 // 32 bits because protobuf does not have a uint16
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
 	l                 *logrus.Logger
 }
 
-func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
-	ones, _ := myVpnIpNet.Mask.Size()
+// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
+// addrMap should be nil unless this is during a config reload
+func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
+	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
+	nebulaPort := uint32(c.GetInt("listen.port", 0))
+	if amLighthouse && nebulaPort == 0 {
+		return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
+	}
+
+	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
-		amLighthouse: amLighthouse,
-		myVpnIp:      iputil.Ip2VpnIp(myVpnIpNet.IP),
-		myVpnZeros:   iputil.VpnIp(32 - ones),
-		addrMap:      make(map[iputil.VpnIp]*RemoteList),
-		nebulaPort:   nebulaPort,
-		lighthouses:  make(map[iputil.VpnIp]struct{}),
-		staticList:   make(map[iputil.VpnIp]struct{}),
-		interval:     interval,
-		punchConn:    pc,
-		punchBack:    punchBack,
-		punchDelay:   punchDelay,
-		l:            l,
-	}
-
-	if metricsEnabled {
+		amLighthouse:      amLighthouse,
+		myVpnIp:           iputil.Ip2VpnIp(myVpnNet.IP),
+		myVpnZeros:        iputil.VpnIp(32 - ones),
+		myVpnNet:          myVpnNet,
+		addrMap:           make(map[iputil.VpnIp]*RemoteList),
+		nebulaPort:        nebulaPort,
+		atomicLighthouses: make(map[iputil.VpnIp]struct{}),
+		atomicStaticList:  make(map[iputil.VpnIp]struct{}),
+		punchConn:         pc,
+		punchy:            p,
+		l:                 l,
+	}
+
+	if c.GetBool("stats.lighthouse_metrics", false) {
 		h.metrics = newLighthouseMetrics()
-
 		h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
 	} else {
 		h.metricHolepunchTx = metrics.NilCounter{}
 	}
 
-	for _, ip := range ips {
-		h.lighthouses[ip] = struct{}{}
+	err := h.reload(c, true)
+	if err != nil {
+		return nil, err
 	}
 
-	return &h
+	c.RegisterReloadCallback(func(c *config.C) {
+		err := h.reload(c, false)
+		switch v := err.(type) {
+		case util.ContextualError:
+			v.Log(l)
+		case error:
+			l.WithError(err).Error("failed to reload lighthouse")
+		}
+	})
+
+	return &h, nil
 }
 
-func (lh *LightHouse) SetRemoteAllowList(allowList *RemoteAllowList) {
-	lh.Lock()
-	defer lh.Unlock()
+func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
+	return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList))))
+}
 
-	lh.remoteAllowList = allowList
+func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
+	return *(*map[iputil.VpnIp]struct{})(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses))))
 }
 
-func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
-	lh.Lock()
-	defer lh.Unlock()
+func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList {
+	return (*RemoteAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList))))
+}
+
+func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
+	return (*LocalAllowList)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList))))
+}
+
+func (lh *LightHouse) GetUpdateInterval() int64 {
+	return atomic.LoadInt64(&lh.atomicInterval)
+}
+
+func (lh *LightHouse) reload(c *config.C, initial bool) error {
+	if initial || c.HasChanged("lighthouse.interval") {
+		atomic.StoreInt64(&lh.atomicInterval, int64(c.GetInt("lighthouse.interval", 10)))
+
+		if !initial {
+			lh.l.Infof("lighthouse.interval changed to %v", lh.atomicInterval)
+
+			if lh.updateCancel != nil {
+				// May not always have a running routine
+				lh.updateCancel()
+			}
+
+			lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp)
+		}
+	}
+
+	if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") {
+		ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
+		if err != nil {
+			return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
+		}
+
+		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicRemoteAllowList)), unsafe.Pointer(ral))
+		if !initial {
+			//TODO: a diff will be annoyingly difficult
+			lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed")
+		}
+	}
+
+	if initial || c.HasChanged("lighthouse.local_allow_list") {
+		lal, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
+		if err != nil {
+			return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
+		}
+
+		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLocalAllowList)), unsafe.Pointer(lal))
+		if !initial {
+			//TODO: a diff will be annoyingly difficult
+			lh.l.Info("lighthouse.local_allow_list has changed")
+		}
+	}
+
+	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
+	if initial || c.HasChanged("static_host_map") {
+		staticList := make(map[iputil.VpnIp]struct{})
+		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
+		if err != nil {
+			return err
+		}
+
+		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicStaticList)), unsafe.Pointer(&staticList))
+		if !initial {
+			//TODO: we should remove any remote list entries for static hosts that were removed/modified?
+			lh.l.Info("static_host_map has changed")
+		}
+
+	}
+
+	if initial || c.HasChanged("lighthouse.hosts") {
+		lhMap := make(map[iputil.VpnIp]struct{})
+		err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
+		if err != nil {
+			return err
+		}
+
+		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&lh.atomicLighthouses)), unsafe.Pointer(&lhMap))
+		if !initial {
+			//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
+			lh.l.Info("lighthouse.hosts has changed")
+		}
+	}
+
+	return nil
+}
+
+func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
+	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
+	if lh.amLighthouse && len(lhs) != 0 {
+		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
+	}
+
+	for i, host := range lhs {
+		ip := net.ParseIP(host)
+		if ip == nil {
+			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+		}
+		if !tunCidr.Contains(ip) {
+			return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+		}
+		lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
+	}
+
+	if !lh.amLighthouse && len(lhMap) == 0 {
+		lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
+	}
 
-	lh.localAllowList = allowList
+	staticList := lh.GetStaticHostList()
+	for lhIP, _ := range lhMap {
+		if _, ok := staticList[lhIP]; !ok {
+			return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP)
+		}
+	}
+
+	return nil
 }
 
-func (lh *LightHouse) ValidateLHStaticEntries() error {
-	for lhIP, _ := range lh.lighthouses {
-		if _, ok := lh.staticList[lhIP]; !ok {
-			return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
+func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
+	i := 0
+
+	for k, v := range shm {
+		rip := net.ParseIP(fmt.Sprintf("%v", k))
+		if rip == nil {
+			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
+		}
+
+		if !tunCidr.Contains(rip) {
+			return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
 		}
+
+		vpnIp := iputil.Ip2VpnIp(rip)
+		vals, ok := v.([]interface{})
+		if ok {
+			for _, v := range vals {
+				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
+				if err != nil {
+					return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
+				}
+				lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
+			}
+
+		} else {
+			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
+			if err != nil {
+				return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
+			}
+			lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
+		}
+		i++
 	}
+
 	return nil
 }
 
@@ -146,10 +312,11 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
 		return
 	}
 
-	lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
+	lighthouses := lh.GetLighthouses()
+	lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
-	for n := range lh.lighthouses {
+	for n := range lighthouses {
 		f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
 	}
 }
@@ -197,7 +364,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
 func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	// First we check the static mapping
 	// and do nothing if it is there
-	if _, ok := lh.staticList[vpnIp]; ok {
+	if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
 		return
 	}
 	lh.Lock()
@@ -211,10 +378,11 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	lh.Unlock()
 }
 
-// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
+// addStaticRemote adds a static host entry for vpnIp as ourselves as the owner
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
-func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
+//NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
+func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
@@ -236,8 +404,8 @@ func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
 		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
 
-	// Mark it as static
-	lh.staticList[vpnIp] = struct{}{}
+	// Mark it as static in the caller provided map
+	staticList[vpnIp] = struct{}{}
 }
 
 // unlockedGetRemoteList assumes you have the lh lock
@@ -252,7 +420,7 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
-	allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
+	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
@@ -266,7 +434,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
 
 // unlockedShouldAddV6 checks if to is allowed by our allow list
 func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
-	allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
+	allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
@@ -287,7 +455,7 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 }
 
 func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
-	if _, ok := lh.lighthouses[vpnIp]; ok {
+	if _, ok := lh.GetLighthouses()[vpnIp]; ok {
 		return true
 	}
 	return false
@@ -329,18 +497,24 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 }
 
 func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
-	if lh.amLighthouse || lh.interval == 0 {
+	lh.updateParentCtx = ctx
+	lh.updateUdp = f
+
+	interval := lh.GetUpdateInterval()
+	if lh.amLighthouse || interval == 0 {
 		return
 	}
 
-	clockSource := time.NewTicker(time.Second * time.Duration(lh.interval))
+	clockSource := time.NewTicker(time.Second * time.Duration(interval))
+	updateCtx, cancel := context.WithCancel(ctx)
+	lh.updateCancel = cancel
 	defer clockSource.Stop()
 
 	for {
 		lh.SendUpdate(f)
 
 		select {
-		case <-ctx.Done():
+		case <-updateCtx.Done():
 			return
 		case <-clockSource.C:
 			continue
@@ -352,7 +526,8 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 
-	for _, e := range *localIps(lh.l, lh.localAllowList) {
+	lal := lh.GetLocalAllowList()
+	for _, e := range *localIps(lh.l, lal) {
 		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
 			continue
 		}
@@ -373,7 +548,8 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
 		},
 	}
 
-	lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lh.lighthouses)))
+	lighthouses := lh.GetLighthouses()
+	lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses)))
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
 
@@ -383,7 +559,7 @@ func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
 		return
 	}
 
-	for vpnIp := range lh.lighthouses {
+	for vpnIp := range lighthouses {
 		f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
 	}
 }
@@ -609,7 +785,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 		}
 
 		go func() {
-			time.Sleep(lhh.lh.punchDelay)
+			time.Sleep(lhh.lh.punchy.GetDelay())
 			lhh.lh.metricHolepunchTx.Inc(1)
 			lhh.lh.punchConn.WriteTo(empty, vpnPeer)
 		}()
@@ -631,7 +807,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// of a double nat or other difficult scenario, this may help establish
 	// a tunnel.
-	if lhh.lh.punchBack {
+	if lhh.lh.punchy.GetRespond() {
 		queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
 		go func() {
 			time.Sleep(time.Second * 5)

+ 65 - 30
lighthouse_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
@@ -47,33 +48,32 @@ func TestNewLhQuery(t *testing.T) {
 
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
+	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
 	lh1 := "10.128.0.2"
-	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
-
-	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
-	err := meh.ValidateLHStaticEntries()
+	c := config.NewC(l)
+	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
+	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
+	_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
-	lh2IP := net.ParseIP(lh2)
-
-	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
-	err = meh.ValidateLHStaticEntries()
-	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
+	c = config.NewC(l)
+	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
+	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
+	_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
-	lh1 := "10.128.0.2"
-	lh1IP := net.ParseIP(lh1)
+	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
 
-	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
-
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	c := config.NewC(l)
+	lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	if !assert.NoError(b, err) {
+		b.Fatal()
+	}
 
 	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
 	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
@@ -160,8 +160,11 @@ func TestLighthouse_Memory(t *testing.T) {
 	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
 	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
 
-	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
+	c := config.NewC(l)
+	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
+	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
+	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
@@ -179,9 +182,16 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
-	// Update a different host
+	// Update a different host and ask about it
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
+	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+
+	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
+	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+
+	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Make sure we didn't get changed
@@ -224,6 +234,18 @@ func TestLighthouse_Memory(t *testing.T) {
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 }
 
+func TestLighthouse_reload(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
+	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
+	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	assert.NoError(t, err)
+
+	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
+	lh.reload(c, false)
+}
+
 func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
@@ -237,7 +259,10 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
 		panic(err)
 	}
 
-	w := &testEncWriter{}
+	filter := NebulaMeta_HostQueryReply
+	w := &testEncWriter{
+		metaFilter: &filter,
+	}
 	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
 	return w.lastReply
 }
@@ -344,18 +369,22 @@ type testLhReply struct {
 }
 
 type testEncWriter struct {
-	lastReply testLhReply
+	lastReply  testLhReply
+	metaFilter *NebulaMeta_MessageType
 }
 
 func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
-	tw.lastReply = testLhReply{
-		nebType:    t,
-		nebSubType: st,
-		vpnIp:      vpnIp,
-		msg:        &NebulaMeta{},
+	msg := &NebulaMeta{}
+	err := proto.Unmarshal(p, msg)
+	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
+		tw.lastReply = testLhReply{
+			nebType:    t,
+			nebSubType: st,
+			vpnIp:      vpnIp,
+			msg:        msg,
+		}
 	}
 
-	err := proto.Unmarshal(p, tw.lastReply.msg)
 	if err != nil {
 		panic(err)
 	}
@@ -363,7 +392,10 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
 func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
-	assert.Len(t, have, len(want))
+	if !assert.Len(t, have, len(want)) {
+		return
+	}
+
 	for k, w := range want {
 		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
@@ -373,7 +405,10 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
 
 // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
 func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
-	assert.Len(t, have, len(want))
+	if !assert.Len(t, have, len(want)) {
+		return
+	}
+
 	for k, w := range want {
 		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))

+ 10 - 87
main.go

@@ -3,13 +3,13 @@ package nebula
 import (
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"time"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
@@ -218,95 +218,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		go hostMap.Promoter(config.GetInt("promoter.interval"))
 	*/
 
-	punchy := NewPunchyFromConfig(c)
-	if punchy.Punch && !configTest {
+	punchy := NewPunchyFromConfig(l, c)
+	if punchy.GetPunch() && !configTest {
 		l.Info("UDP hole punching enabled")
 		go hostMap.Punchy(ctx, udpConns[0])
 	}
 
-	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
-
-	// fatal if am_lighthouse is enabled but we are using an ephemeral port
-	if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
-		return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
-	}
-
-	// warn if am_lighthouse is enabled but upstream lighthouses exists
-	rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
-	if amLighthouse && len(rawLighthouseHosts) != 0 {
-		l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
-	}
-
-	lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
-	for i, host := range rawLighthouseHosts {
-		ip := net.ParseIP(host)
-		if ip == nil {
-			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
-		}
-		if !tunCidr.Contains(ip) {
-			return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
-		}
-		lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
-	}
-
-	if !amLighthouse && len(lighthouseHosts) == 0 {
-		l.Warn("No lighthouses.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
-	}
-
-	lightHouse := NewLightHouse(
-		l,
-		amLighthouse,
-		tunCidr,
-		lighthouseHosts,
-		//TODO: change to a duration
-		c.GetInt("lighthouse.interval", 10),
-		uint32(port),
-		udpConns[0],
-		punchy.Respond,
-		punchy.Delay,
-		c.GetBool("stats.lighthouse_metrics", false),
-	)
-
-	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
-	if err != nil {
-		return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
-	}
-	lightHouse.SetRemoteAllowList(remoteAllowList)
-
-	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
-	if err != nil {
-		return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
-	}
-	lightHouse.SetLocalAllowList(localAllowList)
-
-	//TODO: Move all of this inside functions in lighthouse.go
-	for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
-		ip := net.ParseIP(fmt.Sprintf("%v", k))
-		vpnIp := iputil.Ip2VpnIp(ip)
-		if !tunCidr.Contains(ip) {
-			return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
-		}
-		vals, ok := v.([]interface{})
-		if ok {
-			for _, v := range vals {
-				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-				if err != nil {
-					return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
-				}
-				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
-			}
-		} else {
-			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-			if err != nil {
-				return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
-			}
-			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
-		}
-	}
-
-	err = lightHouse.ValidateLHStaticEntries()
-	if err != nil {
-		l.WithError(err).Error("Lighthouse unreachable")
+	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
+	switch {
+	case errors.As(err, &util.ContextualError{}):
+		return nil, err
+	case err != nil:
+		return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err)
 	}
 
 	var messageMetrics *MessageMetrics
@@ -411,7 +334,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
 	var dnsStart func()
-	if amLighthouse && serveDns {
+	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
 		dnsStart = dnsMain(l, hostMap, c)
 	}

+ 1 - 1
outside.go

@@ -157,7 +157,7 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 
 func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
 	if !hostinfo.remote.Equals(addr) {
-		if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
 			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}

+ 72 - 17
punchy.go

@@ -1,34 +1,89 @@
 package nebula
 
 import (
+	"sync/atomic"
 	"time"
 
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 )
 
 type Punchy struct {
-	Punch   bool
-	Respond bool
-	Delay   time.Duration
+	atomicPunch   int32
+	atomicRespond int32
+	atomicDelay   time.Duration
+	l             *logrus.Logger
 }
 
-func NewPunchyFromConfig(c *config.C) *Punchy {
-	p := &Punchy{}
+func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
+	p := &Punchy{l: l}
 
-	if c.IsSet("punchy.punch") {
-		p.Punch = c.GetBool("punchy.punch", false)
-	} else {
-		// Deprecated fallback
-		p.Punch = c.GetBool("punchy", false)
+	p.reload(c, true)
+	c.RegisterReloadCallback(func(c *config.C) {
+		p.reload(c, false)
+	})
+
+	return p
+}
+
+func (p *Punchy) reload(c *config.C, initial bool) {
+	if initial {
+		var yes bool
+		if c.IsSet("punchy.punch") {
+			yes = c.GetBool("punchy.punch", false)
+		} else {
+			// Deprecated fallback
+			yes = c.GetBool("punchy", false)
+		}
+
+		if yes {
+			atomic.StoreInt32(&p.atomicPunch, 1)
+		} else {
+			atomic.StoreInt32(&p.atomicPunch, 0)
+		}
+
+	} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
+		//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
+		p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
+	}
+
+	if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
+		var yes bool
+		if c.IsSet("punchy.respond") {
+			yes = c.GetBool("punchy.respond", false)
+		} else {
+			// Deprecated fallback
+			yes = c.GetBool("punch_back", false)
+		}
+
+		if yes {
+			atomic.StoreInt32(&p.atomicRespond, 1)
+		} else {
+			atomic.StoreInt32(&p.atomicRespond, 0)
+		}
+
+		if !initial {
+			p.l.Infof("punchy.respond changed to %v", p.GetRespond())
+		}
 	}
 
-	if c.IsSet("punchy.respond") {
-		p.Respond = c.GetBool("punchy.respond", false)
-	} else {
-		// Deprecated fallback
-		p.Respond = c.GetBool("punch_back", false)
+	//NOTE: this will not apply to any in progress operations, only the next one
+	if initial || c.HasChanged("punchy.delay") {
+		atomic.StoreInt64((*int64)(&p.atomicDelay), (int64)(c.GetDuration("punchy.delay", time.Second)))
+		if !initial {
+			p.l.Infof("punchy.delay changed to %s", p.GetDelay())
+		}
 	}
+}
 
-	p.Delay = c.GetDuration("punchy.delay", time.Second)
-	return p
+func (p *Punchy) GetPunch() bool {
+	return atomic.LoadInt32(&p.atomicPunch) == 1
+}
+
+func (p *Punchy) GetRespond() bool {
+	return atomic.LoadInt32(&p.atomicRespond) == 1
+}
+
+func (p *Punchy) GetDelay() time.Duration {
+	return (time.Duration)(atomic.LoadInt64((*int64)(&p.atomicDelay)))
 }

+ 38 - 14
punchy_test.go

@@ -14,34 +14,58 @@ func TestNewPunchyFromConfig(t *testing.T) {
 	c := config.NewC(l)
 
 	// Test defaults
-	p := NewPunchyFromConfig(c)
-	assert.Equal(t, false, p.Punch)
-	assert.Equal(t, false, p.Respond)
-	assert.Equal(t, time.Second, p.Delay)
+	p := NewPunchyFromConfig(l, c)
+	assert.Equal(t, false, p.GetPunch())
+	assert.Equal(t, false, p.GetRespond())
+	assert.Equal(t, time.Second, p.GetDelay())
 
 	// punchy deprecation
 	c.Settings["punchy"] = true
-	p = NewPunchyFromConfig(c)
-	assert.Equal(t, true, p.Punch)
+	p = NewPunchyFromConfig(l, c)
+	assert.Equal(t, true, p.GetPunch())
 
 	// punchy.punch
 	c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
-	p = NewPunchyFromConfig(c)
-	assert.Equal(t, true, p.Punch)
+	p = NewPunchyFromConfig(l, c)
+	assert.Equal(t, true, p.GetPunch())
 
 	// punch_back deprecation
 	c.Settings["punch_back"] = true
-	p = NewPunchyFromConfig(c)
-	assert.Equal(t, true, p.Respond)
+	p = NewPunchyFromConfig(l, c)
+	assert.Equal(t, true, p.GetRespond())
 
 	// punchy.respond
 	c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
 	c.Settings["punch_back"] = false
-	p = NewPunchyFromConfig(c)
-	assert.Equal(t, true, p.Respond)
+	p = NewPunchyFromConfig(l, c)
+	assert.Equal(t, true, p.GetRespond())
 
 	// punchy.delay
 	c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
-	p = NewPunchyFromConfig(c)
-	assert.Equal(t, time.Minute, p.Delay)
+	p = NewPunchyFromConfig(l, c)
+	assert.Equal(t, time.Minute, p.GetDelay())
+}
+
+func TestPunchy_reload(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+	delay, _ := time.ParseDuration("1m")
+	assert.NoError(t, c.LoadString(`
+punchy:
+  delay: 1m
+  respond: false
+`))
+	p := NewPunchyFromConfig(l, c)
+	assert.Equal(t, delay, p.GetDelay())
+	assert.Equal(t, false, p.GetRespond())
+
+	newDelay, _ := time.ParseDuration("10m")
+	assert.NoError(t, c.ReloadConfigString(`
+punchy:
+  delay: 10m
+  respond: true
+`))
+	p.reload(c, false)
+	assert.Equal(t, newDelay, p.GetDelay())
+	assert.Equal(t, true, p.GetRespond())
 }