Browse Source

Dns static lookerupper (#796)

* Support lighthouse DNS names, and regularly resolve the name in a background goroutine to discover DNS updates.
brad-defined 2 years ago
parent
commit
bd9cc01d62
7 changed files with 324 additions and 48 deletions
  1. 1 1
      control_test.go
  2. 1 1
      handshake_manager_test.go
  3. 144 31
      lighthouse.go
  4. 8 7
      lighthouse_test.go
  5. 1 1
      main.go
  6. 166 4
      remote_list.go
  7. 3 3
      remote_list_test.go

+ 1 - 1
control_test.go

@@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 	}
 
-	remotes := NewRemoteList()
+	remotes := NewRemoteList(nil)
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{

+ 1 - 1
handshake_manager_test.go

@@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.False(t, initCalled)
 	assert.Same(t, i, i2)
 
-	i.remotes = NewRemoteList()
+	i.remotes = NewRemoteList(nil)
 	i.HandshakeReady = true
 
 	// Adding something to pending should not affect the main hostmap

+ 144 - 31
lighthouse.go

@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"net/netip"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -33,6 +34,7 @@ type netIpAndPort struct {
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
+	ctx          context.Context
 	amLighthouse bool
 	myVpnIp      iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
@@ -82,7 +84,7 @@ type LightHouse struct {
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
@@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
 
 	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
+		ctx:          ctx,
 		amLighthouse: amLighthouse,
 		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
 		myVpnZeros:   iputil.VpnIp(32 - ones),
@@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
-	if initial || c.HasChanged("static_host_map") {
+	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
 		staticList := make(map[iputil.VpnIp]struct{})
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		if err != nil {
@@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 		lh.staticList.Store(&staticList)
 		if !initial {
 			//TODO: we should remove any remote list entries for static hosts that were removed/modified?
-			lh.l.Info("static_host_map has changed")
+			if c.HasChanged("static_host_map") {
+				lh.l.Info("static_host_map has changed")
+			}
+			if c.HasChanged("static_map.cadence") {
+				lh.l.Info("static_map.cadence has changed")
+			}
+			if c.HasChanged("static_map.network") {
+				lh.l.Info("static_map.network has changed")
+			}
+			if c.HasChanged("static_map.lookup_timeout") {
+				lh.l.Info("static_map.lookup_timeout has changed")
+			}
 		}
-
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
@@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
 	return nil
 }
 
+func getStaticMapCadence(c *config.C) (time.Duration, error) {
+	cadence := c.GetString("static_map.cadence", "30s")
+	d, err := time.ParseDuration(cadence)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
+	lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
+	d, err := time.ParseDuration(lookupTimeout)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapNetwork(c *config.C) (string, error) {
+	network := c.GetString("static_map.network", "ip4")
+	if network != "ip" && network != "ip4" && network != "ip6" {
+		return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
+	}
+	return network, nil
+}
+
 func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+	d, err := getStaticMapCadence(c)
+	if err != nil {
+		return err
+	}
+
+	network, err := getStaticMapNetwork(c)
+	if err != nil {
+		return err
+	}
+
+	lookup_timeout, err := getStaticMapLookupTimeout(c)
+	if err != nil {
+		return err
+	}
+
 	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
 	i := 0
 
@@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 
 		vpnIp := iputil.Ip2VpnIp(rip)
 		vals, ok := v.([]interface{})
-		if ok {
-			for _, v := range vals {
-				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-				if err != nil {
-					return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-				}
-				lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
-			}
+		if !ok {
+			vals = []interface{}{v}
+		}
+		remoteAddrs := []string{}
+		for _, v := range vals {
+			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
+		}
 
-		} else {
-			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-			if err != nil {
-				return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-			}
-			lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
+		err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+		if err != nil {
+			return err
 		}
 		i++
 	}
@@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
 	defer am.Unlock()
+	ctx := lh.ctx
 	lh.Unlock()
 
-	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
-		to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV4(vpnIp, to) {
-			return
-		}
-		am.unlockedPrependV4(lh.myVpnIp, to)
+	hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
+		// This callback runs whenever the DNS hostname resolver finds a different set of IP's
+		// in its resolution for hostnames.
+		am.Lock()
+		defer am.Unlock()
+		am.shouldRebuild = true
+	})
+	if err != nil {
+		return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
+	}
+	am.unlockedSetHostnamesResults(hr)
 
-	} else {
-		to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV6(vpnIp, to) {
-			return
+	for _, addrPort := range hr.GetIPs() {
+
+		switch {
+		case addrPort.Addr().Is4():
+			to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV4(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV4(lh.myVpnIp, to)
+		case addrPort.Addr().Is6():
+			to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV6(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV6(lh.myVpnIp, to)
 		}
-		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
 
 	// Mark it as static in the caller provided map
 	staticList[vpnIp] = struct{}{}
+	return nil
 }
 
 // addCalculatedRemotes adds any calculated remotes based on the
@@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 	am, ok := lh.addrMap[vpnIp]
 	if !ok {
-		am = NewRemoteList()
+		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
 		lh.addrMap[vpnIp] = am
 	}
 	return am
 }
 
+func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
+	switch {
+	case to.Is4():
+		ipBytes := to.As4()
+		ip := iputil.Ip2VpnIp(ipBytes[:])
+		allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+		if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
+			return false
+		}
+	case to.Is6():
+		ipBytes := to.As16()
+
+		hi := binary.BigEndian.Uint64(ipBytes[:8])
+		lo := binary.BigEndian.Uint64(ipBytes[8:])
+		allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+
+		// We don't check our vpn network here because nebula does not support ipv6 on the inside
+		if !allow {
+			return false
+		}
+	}
+	return true
+}
+
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
 	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
@@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 	return &ipp
 }
 
+func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
+	v4Addr := ip.As4()
+	return &Ip4AndPort{
+		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+		Port: uint32(port),
+	}
+}
+
 func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	return &Ip6AndPort{
 		Hi:   binary.BigEndian.Uint64(ip[:8]),
@@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	}
 }
 
+func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
+	ip6Addr := ip.As16()
+	return &Ip6AndPort{
+		Hi:   binary.BigEndian.Uint64(ip6Addr[:8]),
+		Lo:   binary.BigEndian.Uint64(ip6Addr[8:]),
+		Port: uint32(port),
+	}
+}
 func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 	ip := ipp.Ip
 	return udp.NewAddr(

+ 8 - 7
lighthouse_test.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"testing"
@@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
@@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
 
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	if !assert.NoError(b, err) {
 		b.Fatal()
 	}
 
 	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
 	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList()
+	lh.addrMap[3] = NewRemoteList(nil)
 	lh.addrMap[3].unlockedSetV4(
 		3,
 		3,
@@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
 	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList()
+	lh.addrMap[2] = NewRemoteList(nil)
 	lh.addrMap[2].unlockedSetV4(
 		3,
 		3,
@@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
@@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}

+ 1 - 1
main.go

@@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	*/
 
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
 	switch {
 	case errors.As(err, &util.ContextualError{}):
 		return nil, err

+ 166 - 4
remote_list.go

@@ -2,10 +2,16 @@ package nebula
 
 import (
 	"bytes"
+	"context"
 	"net"
+	"net/netip"
 	"sort"
+	"strconv"
 	"sync"
+	"sync/atomic"
+	"time"
 
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 )
@@ -55,6 +61,132 @@ type cacheV6 struct {
 	reported []*Ip6AndPort
 }
 
+type hostnamePort struct {
+	name string
+	port uint16
+}
+
+type hostnamesResults struct {
+	hostnames     []hostnamePort
+	network       string
+	lookupTimeout time.Duration
+	stop          chan struct{}
+	l             *logrus.Logger
+	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
+}
+
+func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
+	r := &hostnamesResults{
+		hostnames:     make([]hostnamePort, len(hostPorts)),
+		network:       network,
+		lookupTimeout: timeout,
+		stop:          make(chan (struct{})),
+		l:             l,
+	}
+
+	// Fastrack IP addresses to ensure they're immediately available for use.
+	// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
+	performBackgroundLookup := false
+	ips := map[netip.AddrPort]struct{}{}
+	for idx, hostPort := range hostPorts {
+
+		rIp, sPort, err := net.SplitHostPort(hostPort)
+		if err != nil {
+			return nil, err
+		}
+
+		iPort, err := strconv.Atoi(sPort)
+		if err != nil {
+			return nil, err
+		}
+
+		r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
+		addr, err := netip.ParseAddr(rIp)
+		if err != nil {
+			// This address is a hostname, not an IP address
+			performBackgroundLookup = true
+			continue
+		}
+
+		// Save the IP address immediately
+		ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
+	}
+	r.ips.Store(&ips)
+
+	// Time for the DNS lookup goroutine
+	if performBackgroundLookup {
+		ticker := time.NewTicker(d)
+		go func() {
+			defer ticker.Stop()
+			for {
+				netipAddrs := map[netip.AddrPort]struct{}{}
+				for _, hostPort := range r.hostnames {
+					timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
+					addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
+					timeoutCancel()
+					if err != nil {
+						l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
+						continue
+					}
+					for _, a := range addrs {
+						netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+					}
+				}
+				origSet := r.ips.Load()
+				different := false
+				for a := range *origSet {
+					if _, ok := netipAddrs[a]; !ok {
+						different = true
+						break
+					}
+				}
+				if !different {
+					for a := range netipAddrs {
+						if _, ok := (*origSet)[a]; !ok {
+							different = true
+							break
+						}
+					}
+				}
+				if different {
+					l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
+					r.ips.Store(&netipAddrs)
+					onUpdate()
+				}
+				select {
+				case <-ctx.Done():
+					return
+				case <-r.stop:
+					return
+				case <-ticker.C:
+					continue
+				}
+			}
+		}()
+	}
+
+	return r, nil
+}
+
+func (hr *hostnamesResults) Cancel() {
+	if hr != nil {
+		hr.stop <- struct{}{}
+	}
+}
+
+func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
+	var retSlice []netip.AddrPort
+	if hr != nil {
+		p := hr.ips.Load()
+		if p != nil {
+			for k := range *p {
+				retSlice = append(retSlice, k)
+			}
+		}
+	}
+	return retSlice
+}
+
 // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
 // It serves as a local cache of query replies, host update notifications, and locally learned addresses
 type RemoteList struct {
@@ -72,6 +204,9 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[iputil.VpnIp]*cache
 
+	hr        *hostnamesResults
+	shouldAdd func(netip.Addr) bool
+
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
 	badRemotes []*udp.Addr
@@ -81,14 +216,21 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList() *RemoteList {
+func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 	return &RemoteList{
-		addrs:  make([]*udp.Addr, 0),
-		relays: make([]*iputil.VpnIp, 0),
-		cache:  make(map[iputil.VpnIp]*cache),
+		addrs:     make([]*udp.Addr, 0),
+		relays:    make([]*iputil.VpnIp, 0),
+		cache:     make(map[iputil.VpnIp]*cache),
+		shouldAdd: shouldAdd,
 	}
 }
 
+func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
+	// Cancel any existing hostnamesResults DNS goroutine to release resources
+	r.hr.Cancel()
+	r.hr = hr
+}
+
 // Len locks and reports the size of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
@@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
 		}
 	}
 
+	dnsAddrs := r.hr.GetIPs()
+	for _, addr := range dnsAddrs {
+		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+			switch {
+			case addr.Addr().Is4():
+				v4 := addr.Addr().As4()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v4[:],
+					Port: addr.Port(),
+				})
+			case addr.Addr().Is6():
+				v6 := addr.Addr().As16()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v6[:],
+					Port: addr.Port(),
+				})
+			}
+		}
+	}
+
 	r.addrs = addrs
 	r.relays = relays
 

+ 3 - 3
remote_list_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestRemoteList_Rebuild(t *testing.T) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,
@@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 }
 
 func BenchmarkFullRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,
@@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
 }
 
 func BenchmarkSortRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,