浏览代码

Merge remote-tracking branch 'origin/master' into mutex-debug

Wade Simmons 2 年之前
父节点
当前提交
a83f0ca470

+ 11 - 0
README.md

@@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows):
 
 
 See the [Makefile](Makefile) for more details on build targets
 See the [Makefile](Makefile) for more details on build targets
 
 
+## Curve P256 and BoringCrypto
+
+The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes.
+
+In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
+
+    make bin-boringcrypto
+    make release-boringcrypto
+
+This is not the recommended default deployment, but may be useful based on your compliance requirements.
+
 ## Credits
 ## Credits
 
 
 Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
 Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

+ 12 - 0
SECURITY.md

@@ -0,0 +1,12 @@
+Security Policy
+===============
+
+Reporting a Vulnerability
+-------------------------
+
+If you believe you have found a security vulnerability with Nebula, please let
+us know right away. We will investigate all reports and do our best to quickly
+fix valid issues.
+
+You can submit your report on [HackerOne](https://hackerone.com/slack) and our
+security team will respond as soon as possible.

+ 25 - 3
cidr/tree4.go

@@ -13,8 +13,14 @@ type Node struct {
 	value  interface{}
 	value  interface{}
 }
 }
 
 
+type entry struct {
+	CIDR  *net.IPNet
+	Value *interface{}
+}
+
 type Tree4 struct {
 type Tree4 struct {
 	root *Node
 	root *Node
+	list []entry
 }
 }
 
 
 const (
 const (
@@ -24,6 +30,7 @@ const (
 func NewTree4() *Tree4 {
 func NewTree4() *Tree4 {
 	tree := new(Tree4)
 	tree := new(Tree4)
 	tree.root = &Node{}
 	tree.root = &Node{}
+	tree.list = []entry{}
 	return tree
 	return tree
 }
 }
 
 
@@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 
 	// We already have this range so update the value
 	// We already have this range so update the value
 	if next != nil {
 	if next != nil {
+		addCIDR := cidr.String()
+		for i, v := range tree.list {
+			if addCIDR == v.CIDR.String() {
+				tree.list = append(tree.list[:i], tree.list[i+1:]...)
+				break
+			}
+		}
+
+		tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
 		node.value = val
 		node.value = val
 		return
 		return
 	}
 	}
@@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 
 	// Final node marks our cidr, set the value
 	// Final node marks our cidr, set the value
 	node.value = val
 	node.value = val
+	tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
 }
 }
 
 
-// Finds the first match, which may be the least specific
+// Contains finds the first match, which may be the least specific
 func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	bit := startbit
 	node := tree.root
 	node := tree.root
@@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 	return value
 	return value
 }
 }
 
 
-// Finds the most specific match
+// MostSpecificContains finds the most specific match
 func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	bit := startbit
 	node := tree.root
 	node := tree.root
@@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 	return value
 	return value
 }
 }
 
 
-// Finds the most specific match
+// Match finds the most specific match
 func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	bit := startbit
 	node := tree.root
 	node := tree.root
@@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 	}
 	}
 	return value
 	return value
 }
 }
+
+// List will return all CIDRs and their current values. Do not modify the contents!
+func (tree *Tree4) List() []entry {
+	return tree.list
+}

+ 14 - 0
cidr/tree4_test.go

@@ -8,6 +8,20 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
+func TestCIDRTree_List(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
+	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
+	tree.AddCIDR(Parse("1.0.0.0/16"), "4")
+	list := tree.List()
+	assert.Len(t, list, 2)
+	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
+	assert.Equal(t, "2", *list[0].Value)
+	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
+	assert.Equal(t, "4", *list[1].Value)
+}
+
 func TestCIDRTree_Contains(t *testing.T) {
 func TestCIDRTree_Contains(t *testing.T) {
 	tree := NewTree4()
 	tree := NewTree4()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")

+ 1 - 1
control_test.go

@@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 	}
 	}
 
 
-	remotes := NewRemoteList()
+	remotes := NewRemoteList(nil)
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
 	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{

+ 6 - 1
examples/config.yml

@@ -223,6 +223,10 @@ tun:
     #  metric: 100
     #  metric: 100
     #  install: true
     #  install: true
 
 
+  # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
+  # in nebula configuration files. Default false, not reloadable.
+  #use_system_route_table: false
+
 # TODO
 # TODO
 # Configure logging level
 # Configure logging level
 logging:
 logging:
@@ -301,7 +305,8 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a CIDR, `0.0.0.0/0` is any.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
   #   ca_name: An issuing CA name
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
   #   ca_sha: An issuing CA shasum
 
 

+ 65 - 36
firewall.go

@@ -25,7 +25,7 @@ const tcpACK = 0x10
 const tcpFIN = 0x01
 const tcpFIN = 0x01
 
 
 type FirewallInterface interface {
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
 }
 }
 
 
 type conn struct {
 type conn struct {
@@ -106,11 +106,12 @@ type FirewallCA struct {
 }
 }
 
 
 type FirewallRule struct {
 type FirewallRule struct {
-	// Any makes Hosts, Groups, and CIDR irrelevant
-	Any    bool
-	Hosts  map[string]struct{}
-	Groups [][]string
-	CIDR   *cidr.Tree4
+	// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
+	Any       bool
+	Hosts     map[string]struct{}
+	Groups    [][]string
+	CIDR      *cidr.Tree4
+	LocalCIDR *cidr.Tree4
 }
 }
 
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
 }
 }
 
 
 // AddRule properly creates the in memory rule structure for a firewall table.
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// https://github.com/golang/go/issues/14131
 	// https://github.com/golang/go/issues/14131
 	sIp := ""
 	sIp := ""
 	if ip != nil {
 	if ip != nil {
 		sIp = ip.String()
 		sIp = ip.String()
 	}
 	}
+	lIp := ""
+	if localIp != nil {
+		lIp = localIp.String()
+	}
 
 
 	// We need this rule string because we generate a hash. Removing this will break firewall reload.
 	// We need this rule string because we generate a hash. Removing this will break firewall reload.
 	ruleString := fmt.Sprintf(
 	ruleString := fmt.Sprintf(
-		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
-		incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
+		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
+		incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
 	)
 	)
 	f.rules += ruleString + "\n"
 	f.rules += ruleString + "\n"
 
 
@@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	if !incoming {
 	if !incoming {
 		direction = "outgoing"
 		direction = "outgoing"
 	}
 	}
-	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
+	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
 		Info("Firewall rule added")
 		Info("Firewall rule added")
 
 
 	var (
 	var (
@@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		return fmt.Errorf("unknown protocol %v", proto)
 		return fmt.Errorf("unknown protocol %v", proto)
 	}
 	}
 
 
-	return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
+	return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
 }
 }
 
 
 // GetRuleHash returns a hash representation of all inbound and outbound rules
 // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
 			return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
 		}
 		}
 
 
-		if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
-			return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
+		if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
+			return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
 		}
 		}
 
 
 		if len(r.Groups) > 0 {
 		if len(r.Groups) > 0 {
@@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			}
 			}
 		}
 		}
 
 
-		err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
+		var localCidr *net.IPNet
+		if r.LocalCidr != "" {
+			_, localCidr, err = net.ParseCIDR(r.LocalCidr)
+			if err != nil {
+				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
+			}
+		}
+
+		err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
 			return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
 		}
 		}
@@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 	return false
 }
 }
 
 
-func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	if startPort > endPort {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 		return fmt.Errorf("start port was lower than end port")
 	}
 	}
@@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 			}
 			}
 		}
 		}
 
 
-		if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
+		if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 }
 
 
-func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
 	fr := func() *FirewallRule {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 		return &FirewallRule{
-			Hosts:  make(map[string]struct{}),
-			Groups: make([][]string, 0),
-			CIDR:   cidr.NewTree4(),
+			Hosts:     make(map[string]struct{}),
+			Groups:    make([][]string, 0),
+			CIDR:      cidr.NewTree4(),
+			LocalCIDR: cidr.NewTree4(),
 		}
 		}
 	}
 	}
 
 
@@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 			fc.Any = fr()
 			fc.Any = fr()
 		}
 		}
 
 
-		return fc.Any.addRule(groups, host, ip)
+		return fc.Any.addRule(groups, host, ip, localIp)
 	}
 	}
 
 
 	if caSha != "" {
 	if caSha != "" {
 		if _, ok := fc.CAShas[caSha]; !ok {
 		if _, ok := fc.CAShas[caSha]; !ok {
 			fc.CAShas[caSha] = fr()
 			fc.CAShas[caSha] = fr()
 		}
 		}
-		err := fc.CAShas[caSha].addRule(groups, host, ip)
+		err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 		if _, ok := fc.CANames[caName]; !ok {
 		if _, ok := fc.CANames[caName]; !ok {
 			fc.CANames[caName] = fr()
 			fc.CANames[caName] = fr()
 		}
 		}
-		err := fc.CANames[caName].addRule(groups, host, ip)
+		err := fc.CANames[caName].addRule(groups, host, ip, localIp)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 }
 
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
+func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
 	if fr.Any {
 	if fr.Any {
 		return nil
 		return nil
 	}
 	}
 
 
-	if fr.isAny(groups, host, ip) {
+	if fr.isAny(groups, host, ip, localIp) {
 		fr.Any = true
 		fr.Any = true
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
 		fr.Hosts = make(map[string]struct{})
 		fr.CIDR = cidr.NewTree4()
 		fr.CIDR = cidr.NewTree4()
+		fr.LocalCIDR = cidr.NewTree4()
 	} else {
 	} else {
 		if len(groups) > 0 {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
 			fr.Groups = append(fr.Groups, groups)
@@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
 		if ip != nil {
 		if ip != nil {
 			fr.CIDR.AddCIDR(ip, struct{}{})
 			fr.CIDR.AddCIDR(ip, struct{}{})
 		}
 		}
+
+		if localIp != nil {
+			fr.LocalCIDR.AddCIDR(localIp, struct{}{})
+		}
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
+	if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
 		return true
 		return true
 	}
 	}
 
 
@@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 		return true
 		return true
 	}
 	}
 
 
+	if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+		return true
+	}
+
 	return false
 	return false
 }
 }
 
 
@@ -790,20 +813,25 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		return true
 		return true
 	}
 	}
 
 
+	if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
+		return true
+	}
+
 	// No host, group, or cidr matched, bye bye
 	// No host, group, or cidr matched, bye bye
 	return false
 	return false
 }
 }
 
 
 type rule struct {
 type rule struct {
-	Port   string
-	Code   string
-	Proto  string
-	Host   string
-	Group  string
-	Groups []string
-	Cidr   string
-	CAName string
-	CASha  string
+	Port      string
+	Code      string
+	Proto     string
+	Host      string
+	Group     string
+	Groups    []string
+	Cidr      string
+	LocalCidr string
+	CAName    string
+	CASha     string
 }
 }
 
 
 func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
 func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
@@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 	r.Proto = toString("proto", m)
 	r.Proto = toString("proto", m)
 	r.Host = toString("host", m)
 	r.Host = toString("host", m)
 	r.Cidr = toString("cidr", m)
 	r.Cidr = toString("cidr", m)
+	r.LocalCidr = toString("local_cidr", m)
 	r.CAName = toString("ca_name", m)
 	r.CAName = toString("ca_name", m)
 	r.CASha = toString("ca_sha", m)
 	r.CASha = toString("ca_sha", m)
 
 

+ 102 - 45
firewall_test.go

@@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) {
 
 
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
 
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
 	// An empty rule is any
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 
 	// Set any and clear fields
 	// Set any and clear fields
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
 	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 
 	// run twice just to make sure
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion
 	//TODO: these ANY rules should clear the CA firewall portion
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 
 	// Test error conditions
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
-	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
+	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
 }
 }
 
 
 func TestFirewall_Drop(t *testing.T) {
 func TestFirewall_Drop(t *testing.T) {
@@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
@@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
 
 
 	// ensure signer doesn't get in the way of group checks
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caSha doesn't drop on match
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 
 	// ensure ca name doesn't get in the way of group checks
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caName doesn't drop on match
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 }
 
 
@@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 	}
 
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	b.Run("fail on proto", func(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
@@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 		}
 	})
 	})
 
 
-	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
+	b.Run("pass on local ip", func(b *testing.B) {
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+		c := &cert.NebulaCertificate{
+			Details: cert.NebulaCertificateDetails{
+				InvertedGroups: map[string]struct{}{"nope": {}},
+				Name:           "good-host",
+			},
+		}
+		for n := 0; n < b.N; n++ {
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
+		}
+	})
+
+	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
 
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
 	b.Run("pass on ip with any port", func(b *testing.B) {
 		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
 		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
@@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
 		}
 		}
 	})
 	})
+
+	b.Run("pass on local ip with any port", func(b *testing.B) {
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+		c := &cert.NebulaCertificate{
+			Details: cert.NebulaCertificateDetails{
+				InvertedGroups: map[string]struct{}{"nope": {}},
+				Name:           "good-host",
+			},
+		}
+		for n := 0; n < b.N; n++ {
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
+		}
+	})
 }
 }
 
 
 func TestFirewall_Drop2(t *testing.T) {
 func TestFirewall_Drop2(t *testing.T) {
@@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 	h1.CreateRemoteCIDR(&c1)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// h1/c1 lacks the proper groups
 	// h1/c1 lacks the proper groups
@@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	h3.CreateRemoteCIDR(&c3)
 	h3.CreateRemoteCIDR(&c3)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// c1 should pass because host match
 	// c1 should pass because host match
@@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
@@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw := fw
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw = fw
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
+	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 
 	// Test code/port error
 	// Test code/port error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
@@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 
 
+	// Test local_cidr parse error
+	conf = config.NewC(l)
+	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
+	_, err = NewFirewallFromConfig(l, c, conf)
+	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+
 	// Test both group and groups
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
@@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf := &mockFirewall{}
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test adding udp rule
 	// Test adding udp rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test adding icmp rule
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test adding any rule
 	// Test adding any rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+
+	// Test adding rule with cidr
+	cidr := &net.IPNet{net.ParseIP("10.0.0.0").To4(), net.IPv4Mask(255, 0, 0, 0)}
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+
+	// Test adding rule with local_cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
 
 
 	// Test adding rule with ca_sha
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
 
 
 	// Test adding rule with ca_name
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
 
 
 	// Test single group
 	// Test single group
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test single groups
 	// Test single groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test multiple AND groups
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
 
 
 	// Test Add error
 	// Test Add error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
@@ -892,6 +947,7 @@ type addRuleCall struct {
 	groups    []string
 	groups    []string
 	host      string
 	host      string
 	ip        *net.IPNet
 	ip        *net.IPNet
+	localIp   *net.IPNet
 	caName    string
 	caName    string
 	caSha     string
 	caSha     string
 }
 }
@@ -901,7 +957,7 @@ type mockFirewall struct {
 	nextCallReturn error
 	nextCallReturn error
 }
 }
 
 
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	mf.lastCall = addRuleCall{
 	mf.lastCall = addRuleCall{
 		incoming:  incoming,
 		incoming:  incoming,
 		proto:     proto,
 		proto:     proto,
@@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 		groups:    groups,
 		groups:    groups,
 		host:      host,
 		host:      host,
 		ip:        ip,
 		ip:        ip,
+		localIp:   localIp,
 		caName:    caName,
 		caName:    caName,
 		caSha:     caSha,
 		caSha:     caSha,
 	}
 	}

+ 1 - 1
handshake_manager_test.go

@@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.False(t, initCalled)
 	assert.False(t, initCalled)
 	assert.Same(t, i, i2)
 	assert.Same(t, i, i2)
 
 
-	i.remotes = NewRemoteList()
+	i.remotes = NewRemoteList(nil)
 	i.HandshakeReady = true
 	i.HandshakeReady = true
 
 
 	// Adding something to pending should not affect the main hostmap
 	// Adding something to pending should not affect the main hostmap

+ 144 - 31
lighthouse.go

@@ -6,6 +6,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
@@ -33,6 +34,7 @@ type netIpAndPort struct {
 type LightHouse struct {
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	sync.RWMutex //Because we concurrently read and write to our maps
+	ctx          context.Context
 	amLighthouse bool
 	amLighthouse bool
 	myVpnIp      iputil.VpnIp
 	myVpnIp      iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
@@ -82,7 +84,7 @@ type LightHouse struct {
 
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
 	if amLighthouse && nebulaPort == 0 {
@@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
 
 
 	ones, _ := myVpnNet.Mask.Size()
 	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
 	h := LightHouse{
+		ctx:          ctx,
 		amLighthouse: amLighthouse,
 		amLighthouse: amLighthouse,
 		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
 		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
 		myVpnZeros:   iputil.VpnIp(32 - ones),
 		myVpnZeros:   iputil.VpnIp(32 - ones),
@@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 	}
 
 
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
-	if initial || c.HasChanged("static_host_map") {
+	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
 		staticList := make(map[iputil.VpnIp]struct{})
 		staticList := make(map[iputil.VpnIp]struct{})
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		if err != nil {
 		if err != nil {
@@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 		lh.staticList.Store(&staticList)
 		lh.staticList.Store(&staticList)
 		if !initial {
 		if !initial {
 			//TODO: we should remove any remote list entries for static hosts that were removed/modified?
 			//TODO: we should remove any remote list entries for static hosts that were removed/modified?
-			lh.l.Info("static_host_map has changed")
+			if c.HasChanged("static_host_map") {
+				lh.l.Info("static_host_map has changed")
+			}
+			if c.HasChanged("static_map.cadence") {
+				lh.l.Info("static_map.cadence has changed")
+			}
+			if c.HasChanged("static_map.network") {
+				lh.l.Info("static_map.network has changed")
+			}
+			if c.HasChanged("static_map.lookup_timeout") {
+				lh.l.Info("static_map.lookup_timeout has changed")
+			}
 		}
 		}
-
 	}
 	}
 
 
 	if initial || c.HasChanged("lighthouse.hosts") {
 	if initial || c.HasChanged("lighthouse.hosts") {
@@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
 	return nil
 	return nil
 }
 }
 
 
+func getStaticMapCadence(c *config.C) (time.Duration, error) {
+	cadence := c.GetString("static_map.cadence", "30s")
+	d, err := time.ParseDuration(cadence)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
+	lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
+	d, err := time.ParseDuration(lookupTimeout)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapNetwork(c *config.C) (string, error) {
+	network := c.GetString("static_map.network", "ip4")
+	if network != "ip" && network != "ip4" && network != "ip6" {
+		return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
+	}
+	return network, nil
+}
+
 func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
 func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+	d, err := getStaticMapCadence(c)
+	if err != nil {
+		return err
+	}
+
+	network, err := getStaticMapNetwork(c)
+	if err != nil {
+		return err
+	}
+
+	lookup_timeout, err := getStaticMapLookupTimeout(c)
+	if err != nil {
+		return err
+	}
+
 	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
 	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
 	i := 0
 	i := 0
 
 
@@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 
 
 		vpnIp := iputil.Ip2VpnIp(rip)
 		vpnIp := iputil.Ip2VpnIp(rip)
 		vals, ok := v.([]interface{})
 		vals, ok := v.([]interface{})
-		if ok {
-			for _, v := range vals {
-				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-				if err != nil {
-					return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-				}
-				lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
-			}
+		if !ok {
+			vals = []interface{}{v}
+		}
+		remoteAddrs := []string{}
+		for _, v := range vals {
+			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
+		}
 
 
-		} else {
-			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-			if err != nil {
-				return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-			}
-			lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
+		err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+		if err != nil {
+			return err
 		}
 		}
 		i++
 		i++
 	}
 	}
@@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
 	lh.Lock()
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
 	am.Lock()
 	defer am.Unlock()
 	defer am.Unlock()
+	ctx := lh.ctx
 	lh.Unlock()
 	lh.Unlock()
 
 
-	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
-		to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV4(vpnIp, to) {
-			return
-		}
-		am.unlockedPrependV4(lh.myVpnIp, to)
+	hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
+		// This callback runs whenever the DNS hostname resolver finds a different set of IP's
+		// in its resolution for hostnames.
+		am.Lock()
+		defer am.Unlock()
+		am.shouldRebuild = true
+	})
+	if err != nil {
+		return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
+	}
+	am.unlockedSetHostnamesResults(hr)
 
 
-	} else {
-		to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV6(vpnIp, to) {
-			return
+	for _, addrPort := range hr.GetIPs() {
+
+		switch {
+		case addrPort.Addr().Is4():
+			to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV4(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV4(lh.myVpnIp, to)
+		case addrPort.Addr().Is6():
+			to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV6(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV6(lh.myVpnIp, to)
 		}
 		}
-		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
 	}
 
 
 	// Mark it as static in the caller provided map
 	// Mark it as static in the caller provided map
 	staticList[vpnIp] = struct{}{}
 	staticList[vpnIp] = struct{}{}
+	return nil
 }
 }
 
 
 // addCalculatedRemotes adds any calculated remotes based on the
 // addCalculatedRemotes adds any calculated remotes based on the
@@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 	am, ok := lh.addrMap[vpnIp]
 	am, ok := lh.addrMap[vpnIp]
 	if !ok {
 	if !ok {
-		am = NewRemoteList()
+		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
 		lh.addrMap[vpnIp] = am
 		lh.addrMap[vpnIp] = am
 	}
 	}
 	return am
 	return am
 }
 }
 
 
+func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
+	switch {
+	case to.Is4():
+		ipBytes := to.As4()
+		ip := iputil.Ip2VpnIp(ipBytes[:])
+		allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+		if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
+			return false
+		}
+	case to.Is6():
+		ipBytes := to.As16()
+
+		hi := binary.BigEndian.Uint64(ipBytes[:8])
+		lo := binary.BigEndian.Uint64(ipBytes[8:])
+		allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+
+		// We don't check our vpn network here because nebula does not support ipv6 on the inside
+		if !allow {
+			return false
+		}
+	}
+	return true
+}
+
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
 func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
 	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
 	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
@@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 	return &ipp
 	return &ipp
 }
 }
 
 
+func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
+	v4Addr := ip.As4()
+	return &Ip4AndPort{
+		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+		Port: uint32(port),
+	}
+}
+
 func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	return &Ip6AndPort{
 	return &Ip6AndPort{
 		Hi:   binary.BigEndian.Uint64(ip[:8]),
 		Hi:   binary.BigEndian.Uint64(ip[:8]),
@@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	}
 	}
 }
 }
 
 
+func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
+	ip6Addr := ip.As16()
+	return &Ip6AndPort{
+		Hi:   binary.BigEndian.Uint64(ip6Addr[:8]),
+		Lo:   binary.BigEndian.Uint64(ip6Addr[8:]),
+		Port: uint32(port),
+	}
+}
 func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 	ip := ipp.Ip
 	ip := ipp.Ip
 	return udp.NewAddr(
 	return udp.NewAddr(

+ 8 - 7
lighthouse_test.go

@@ -1,6 +1,7 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"testing"
 	"testing"
@@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	lh2 := "10.128.0.3"
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 }
 
 
@@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
 	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	if !assert.NoError(b, err) {
 	if !assert.NoError(b, err) {
 		b.Fatal()
 		b.Fatal()
 	}
 	}
 
 
 	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
 	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
 	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
 	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList()
+	lh.addrMap[3] = NewRemoteList(nil)
 	lh.addrMap[3].unlockedSetV4(
 	lh.addrMap[3].unlockedSetV4(
 		3,
 		3,
 		3,
 		3,
@@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 
 	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
 	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
 	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
 	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList()
+	lh.addrMap[2] = NewRemoteList(nil)
 	lh.addrMap[2].unlockedSetV4(
 	lh.addrMap[2].unlockedSetV4(
 		3,
 		3,
 		3,
 		3,
@@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 	lhh := lh.NewRequestHandler()
 
 
@@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}

+ 1 - 1
main.go

@@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	*/
 	*/
 
 
 	punchy := NewPunchyFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
 	switch {
 	switch {
 	case errors.As(err, &util.ContextualError{}):
 	case errors.As(err, &util.ContextualError{}):
 		return nil, err
 		return nil, err

+ 2 - 0
overlay/tun.go

@@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 			c.GetInt("tun.mtu", DefaultMTU),
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			routes,
 			c.GetInt("tun.tx_queue", 500),
 			c.GetInt("tun.tx_queue", 500),
+			c.GetBool("tun.use_system_route_table", false),
 		)
 		)
 
 
 	default:
 	default:
@@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 			routes,
 			routes,
 			c.GetInt("tun.tx_queue", 500),
 			c.GetInt("tun.tx_queue", 500),
 			routines > 1,
 			routines > 1,
+			c.GetBool("tun.use_system_route_table", false),
 		)
 		)
 	}
 	}
 }
 }

+ 2 - 2
overlay/tun_android.go

@@ -22,7 +22,7 @@ type tun struct {
 	l         *logrus.Logger
 	l         *logrus.Logger
 }
 }
 
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
 	}, nil
 	}, nil
 }
 }
 
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 }
 
 

+ 2 - 2
overlay/tun_darwin.go

@@ -77,7 +77,7 @@ type ifreqMTU struct {
 	pad  [8]byte
 	pad  [8]byte
 }
 }
 
 
-func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 }
 
 

+ 2 - 2
overlay/tun_freebsd.go

@@ -38,11 +38,11 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 2 - 2
overlay/tun_ios.go

@@ -23,11 +23,11 @@ type tun struct {
 	routeTree *cidr.Tree4
 	routeTree *cidr.Tree4
 }
 }
 
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 }
 
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 109 - 15
overlay/tun_linux.go

@@ -4,11 +4,13 @@
 package overlay
 package overlay
 
 
 import (
 import (
+	"bytes"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
 	"os"
 	"os"
 	"strings"
 	"strings"
+	"sync/atomic"
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
@@ -26,9 +28,13 @@ type tun struct {
 	MaxMTU     int
 	MaxMTU     int
 	DefaultMTU int
 	DefaultMTU int
 	TXQueueLen int
 	TXQueueLen int
-	Routes     []Route
-	routeTree  *cidr.Tree4
-	l          *logrus.Logger
+
+	Routes          []Route
+	routeTree       atomic.Pointer[cidr.Tree4]
+	routeChan       chan struct{}
+	useSystemRoutes bool
+
+	l *logrus.Logger
 }
 }
 
 
 type ifReq struct {
 type ifReq struct {
@@ -63,7 +69,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, true)
 	routeTree, err := makeRouteTree(l, routes, true)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
 
 
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
-	return &tun{
+	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
 		Device:          "tun0",
 		Device:          "tun0",
@@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		Routes:          routes,
-		routeTree:       routeTree,
+		useSystemRoutes: useSystemRoutes,
 		l:               l,
 		l:               l,
-	}, nil
+	}
+	t.routeTree.Store(routeTree)
+	return t, nil
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	return &tun{
+	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
 		Device:          name,
 		Device:          name,
@@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		Routes:          routes,
-		routeTree:       routeTree,
+		useSystemRoutes: useSystemRoutes,
 		l:               l,
 		l:               l,
-	}, nil
+	}
+	t.routeTree.Store(routeTree)
+	return t, nil
 }
 }
 
 
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
@@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 }
 }
 
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
+	r := t.routeTree.Load().MostSpecificContains(ip)
 	if r != nil {
 	if r != nil {
 		return r.(iputil.VpnIp)
 		return r.(iputil.VpnIp)
 	}
 	}
@@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) {
 	}
 	}
 }
 }
 
 
-func (t tun) deviceBytes() (o [16]byte) {
+func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 	for i, c := range t.Device {
 		o[i] = byte(c)
 		o[i] = byte(c)
 	}
 	}
 	return
 	return
 }
 }
 
 
-func (t tun) Activate() error {
+func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 	devName := t.deviceBytes()
 
 
+	if t.useSystemRoutes {
+		t.watchRoutes()
+	}
+
 	var addr, mask [4]byte
 	var addr, mask [4]byte
 
 
 	copy(addr[:], t.cidr.IP.To4())
 	copy(addr[:], t.cidr.IP.To4())
@@ -318,7 +332,7 @@ func (t *tun) Name() string {
 	return t.Device
 	return t.Device
 }
 }
 
 
-func (t tun) advMSS(r Route) int {
+func (t *tun) advMSS(r Route) int {
 	mtu := r.MTU
 	mtu := r.MTU
 	if r.MTU == 0 {
 	if r.MTU == 0 {
 		mtu = t.DefaultMTU
 		mtu = t.DefaultMTU
@@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int {
 	}
 	}
 	return 0
 	return 0
 }
 }
+
+func (t *tun) watchRoutes() {
+	rch := make(chan netlink.RouteUpdate)
+	doneChan := make(chan struct{})
+
+	if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
+		t.l.WithError(err).Errorf("failed to subscribe to system route changes")
+		return
+	}
+
+	t.routeChan = doneChan
+
+	go func() {
+		for {
+			select {
+			case r := <-rch:
+				t.updateRoutes(r)
+			case <-doneChan:
+				// netlink.RouteSubscriber will close the rch for us
+				return
+			}
+		}
+	}()
+}
+
+func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+	if r.Gw == nil {
+		// Not a gateway route, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
+		return
+	}
+
+	if !t.cidr.Contains(r.Gw) {
+		// Gateway isn't in our overlay network, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+		return
+	}
+
+	if x := r.Dst.IP.To4(); x == nil {
+		// Nebula only handles ipv4 on the overlay currently
+		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
+		return
+	}
+
+	newTree := cidr.NewTree4()
+	if r.Type == unix.RTM_NEWROUTE {
+		for _, oldR := range t.routeTree.Load().List() {
+			newTree.AddCIDR(oldR.CIDR, oldR.Value)
+		}
+
+		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
+		newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+
+	} else {
+		gw := iputil.Ip2VpnIp(r.Gw)
+		for _, oldR := range t.routeTree.Load().List() {
+			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
+				// This is the record to delete
+				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
+				continue
+			}
+
+			newTree.AddCIDR(oldR.CIDR, oldR.Value)
+		}
+	}
+
+	t.routeTree.Store(newTree)
+}
+
+func (t *tun) Close() error {
+	if t.routeChan != nil {
+		close(t.routeChan)
+	}
+
+	if t.ReadWriteCloser != nil {
+		t.ReadWriteCloser.Close()
+	}
+
+	return nil
+}

+ 7 - 7
overlay/tun_linux_test.go

@@ -7,19 +7,19 @@ import "testing"
 
 
 var runAdvMSSTests = []struct {
 var runAdvMSSTests = []struct {
 	name     string
 	name     string
-	tun      tun
+	tun      *tun
 	r        Route
 	r        Route
 	expected int
 	expected int
 }{
 }{
 	// Standard case, default MTU is the device max MTU
 	// Standard case, default MTU is the device max MTU
-	{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
-	{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
-	{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
+	{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
+	{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
+	{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
 
 
 	// Case where we have a route MTU set higher than the default
 	// Case where we have a route MTU set higher than the default
-	{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
-	{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
-	{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
+	{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
+	{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
+	{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
 }
 }
 
 
 func TestTunAdvMSS(t *testing.T) {
 func TestTunAdvMSS(t *testing.T) {

+ 2 - 2
overlay/tun_tester.go

@@ -25,7 +25,7 @@ type TestTun struct {
 	TxPackets chan []byte // Packets transmitted outside by nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
 	}, nil
 	}, nil
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 }
 
 

+ 2 - 2
overlay/tun_windows.go

@@ -14,11 +14,11 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
 	useWintun := true
 	useWintun := true
 	if err := checkWinTunExists(); err != nil {
 	if err := checkWinTunExists(); err != nil {
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

+ 166 - 4
remote_list.go

@@ -2,10 +2,16 @@ package nebula
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"context"
 	"net"
 	"net"
+	"net/netip"
 	"sort"
 	"sort"
+	"strconv"
 	"sync"
 	"sync"
+	"sync/atomic"
+	"time"
 
 
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
@@ -55,6 +61,132 @@ type cacheV6 struct {
 	reported []*Ip6AndPort
 	reported []*Ip6AndPort
 }
 }
 
 
+type hostnamePort struct {
+	name string
+	port uint16
+}
+
+type hostnamesResults struct {
+	hostnames     []hostnamePort
+	network       string
+	lookupTimeout time.Duration
+	stop          chan struct{}
+	l             *logrus.Logger
+	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
+}
+
+func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
+	r := &hostnamesResults{
+		hostnames:     make([]hostnamePort, len(hostPorts)),
+		network:       network,
+		lookupTimeout: timeout,
+		stop:          make(chan (struct{})),
+		l:             l,
+	}
+
+	// Fastrack IP addresses to ensure they're immediately available for use.
+	// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
+	performBackgroundLookup := false
+	ips := map[netip.AddrPort]struct{}{}
+	for idx, hostPort := range hostPorts {
+
+		rIp, sPort, err := net.SplitHostPort(hostPort)
+		if err != nil {
+			return nil, err
+		}
+
+		iPort, err := strconv.Atoi(sPort)
+		if err != nil {
+			return nil, err
+		}
+
+		r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
+		addr, err := netip.ParseAddr(rIp)
+		if err != nil {
+			// This address is a hostname, not an IP address
+			performBackgroundLookup = true
+			continue
+		}
+
+		// Save the IP address immediately
+		ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
+	}
+	r.ips.Store(&ips)
+
+	// Time for the DNS lookup goroutine
+	if performBackgroundLookup {
+		ticker := time.NewTicker(d)
+		go func() {
+			defer ticker.Stop()
+			for {
+				netipAddrs := map[netip.AddrPort]struct{}{}
+				for _, hostPort := range r.hostnames {
+					timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
+					addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
+					timeoutCancel()
+					if err != nil {
+						l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
+						continue
+					}
+					for _, a := range addrs {
+						netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+					}
+				}
+				origSet := r.ips.Load()
+				different := false
+				for a := range *origSet {
+					if _, ok := netipAddrs[a]; !ok {
+						different = true
+						break
+					}
+				}
+				if !different {
+					for a := range netipAddrs {
+						if _, ok := (*origSet)[a]; !ok {
+							different = true
+							break
+						}
+					}
+				}
+				if different {
+					l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
+					r.ips.Store(&netipAddrs)
+					onUpdate()
+				}
+				select {
+				case <-ctx.Done():
+					return
+				case <-r.stop:
+					return
+				case <-ticker.C:
+					continue
+				}
+			}
+		}()
+	}
+
+	return r, nil
+}
+
+func (hr *hostnamesResults) Cancel() {
+	if hr != nil {
+		hr.stop <- struct{}{}
+	}
+}
+
+func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
+	var retSlice []netip.AddrPort
+	if hr != nil {
+		p := hr.ips.Load()
+		if p != nil {
+			for k := range *p {
+				retSlice = append(retSlice, k)
+			}
+		}
+	}
+	return retSlice
+}
+
 // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
 // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
 // It serves as a local cache of query replies, host update notifications, and locally learned addresses
 // It serves as a local cache of query replies, host update notifications, and locally learned addresses
 type RemoteList struct {
 type RemoteList struct {
@@ -72,6 +204,9 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[iputil.VpnIp]*cache
 	cache map[iputil.VpnIp]*cache
 
 
+	hr        *hostnamesResults
+	shouldAdd func(netip.Addr) bool
+
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
 	// They should not be tried again during a handshake
 	badRemotes []*udp.Addr
 	badRemotes []*udp.Addr
@@ -81,14 +216,21 @@ type RemoteList struct {
 }
 }
 
 
 // NewRemoteList creates a new empty RemoteList
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList() *RemoteList {
+func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 	return &RemoteList{
 	return &RemoteList{
-		addrs:  make([]*udp.Addr, 0),
-		relays: make([]*iputil.VpnIp, 0),
-		cache:  make(map[iputil.VpnIp]*cache),
+		addrs:     make([]*udp.Addr, 0),
+		relays:    make([]*iputil.VpnIp, 0),
+		cache:     make(map[iputil.VpnIp]*cache),
+		shouldAdd: shouldAdd,
 	}
 	}
 }
 }
 
 
+func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
+	// Cancel any existing hostnamesResults DNS goroutine to release resources
+	r.hr.Cancel()
+	r.hr = hr
+}
+
 // Len locks and reports the size of the deduplicated address list
 // Len locks and reports the size of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 // The deduplication work may need to occur here, so you must pass preferredRanges
 func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
 func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
@@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
 		}
 		}
 	}
 	}
 
 
+	dnsAddrs := r.hr.GetIPs()
+	for _, addr := range dnsAddrs {
+		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+			switch {
+			case addr.Addr().Is4():
+				v4 := addr.Addr().As4()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v4[:],
+					Port: addr.Port(),
+				})
+			case addr.Addr().Is6():
+				v6 := addr.Addr().As16()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v6[:],
+					Port: addr.Port(),
+				})
+			}
+		}
+	}
+
 	r.addrs = addrs
 	r.addrs = addrs
 	r.relays = relays
 	r.relays = relays
 
 

+ 3 - 3
remote_list_test.go

@@ -9,7 +9,7 @@ import (
 )
 )
 
 
 func TestRemoteList_Rebuild(t *testing.T) {
 func TestRemoteList_Rebuild(t *testing.T) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		0,
 		0,
 		0,
 		0,
@@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 }
 }
 
 
 func BenchmarkFullRebuild(b *testing.B) {
 func BenchmarkFullRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		0,
 		0,
 		0,
 		0,
@@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
 }
 }
 
 
 func BenchmarkSortRebuild(b *testing.B) {
 func BenchmarkSortRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		0,
 		0,
 		0,
 		0,