2
0
Эх сурвалжийг харах

Merge remote-tracking branch 'origin/master' into holepunch-remote-allow-list

Wade Simmons 1 өдөр өмнө
parent
commit
c338f52ec9

+ 15 - 7
.github/ISSUE_TEMPLATE/config.yml

@@ -1,13 +1,21 @@
 blank_issues_enabled: true
 contact_links:
+  - name: 💨 Performance Issues
+    url: https://github.com/slackhq/nebula/discussions/new/choose
+    about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.'
+
+  - name: 📄 Documentation Issues
+    url: https://github.com/definednet/nebula-docs
+    about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository."
+
+  - name: 📱 Mobile Nebula Issues
+    url: https://github.com/definednet/mobile_nebula
+    about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository."
+
   - name: 📘 Documentation
     url: https://nebula.defined.net/docs/
-    about: Review documentation.
+    about: 'The documentation is the best place to start if you are new to Nebula.'
 
   - name: 💁 Support/Chat
-    url: https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU
-    about: 'This issue tracker is not for support questions. Join us on Slack for assistance!'
-
-  - name: 📱 Mobile Nebula
-    url: https://github.com/definednet/mobile_nebula
-    about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!'
+    url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA
+    about: 'For faster support, join us on Slack for assistance!'

+ 11 - 0
.github/pull_request_template.md

@@ -0,0 +1,11 @@
+<!--
+Thank you for taking the time to submit a pull request!
+
+Please be sure to provide a clear description of what you're trying to achieve with the change.
+
+- If you're submitting a new feature, please explain how to use it and document any new config options in the example config.
+- If you're submitting a bugfix, please link the related issue or describe the circumstances surrounding the issue.
+- If you're changing a default, explain why you believe the new default is appropriate for most users.
+
+P.S. If you're only updating the README or other docs, please file a pull request here instead: https://github.com/DefinedNet/nebula-docs
+-->

+ 4 - 4
.github/workflows/test.yml

@@ -32,9 +32,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v7
+      uses: golangci/golangci-lint-action@v8
       with:
-        version: v2.0
+        version: v2.1
 
     - name: Test
       run: make test
@@ -115,9 +115,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v7
+      uses: golangci/golangci-lint-action@v8
       with:
-        version: v2.0
+        version: v2.1
 
     - name: Test
       run: make test

+ 38 - 29
README.md

@@ -4,7 +4,7 @@ It lets you seamlessly connect computers anywhere in the world. Nebula is portab
 It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers.
 
 Nebula incorporates a number of existing concepts like encryption, security groups, certificates,
-and tunneling, and each of those individual pieces existed before Nebula in various forms.
+and tunneling.
 What makes Nebula different to existing offerings is that it brings all of these ideas together,
 resulting in a sum that is greater than its individual parts.
 
@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
 
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
 
-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA).
 
 ## Supported Platforms
 
@@ -28,33 +28,33 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
 #### Distribution Packages
 
 - [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
-    ```
-    $ sudo pacman -S nebula
+    ```sh
+    sudo pacman -S nebula
     ```
 
 - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
-    ```
-    $ sudo dnf install nebula
+    ```sh
+    sudo dnf install nebula
     ```
 
 - [Debian Linux](https://packages.debian.org/source/stable/nebula)
-    ```
-    $ sudo apt install nebula
+    ```sh
+    sudo apt install nebula
     ```
 
 - [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
-    ```
-    $ sudo apk add nebula
+    ```sh
+    sudo apk add nebula
     ```
 
 - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
-    ```
-    $ brew install nebula
+    ```sh
+    brew install nebula
     ```
 
 - [Docker](https://hub.docker.com/r/nebulaoss/nebula)
-    ```
-    $ docker pull nebulaoss/nebula
+    ```sh
+    docker pull nebulaoss/nebula
     ```
 
 #### Mobile
@@ -64,10 +64,10 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
 
 ## Technical Overview
 
-Nebula is a mutually authenticated peer-to-peer software defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
+Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/).
 Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups.
 Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes.
-Discovery nodes allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
+Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs.
 Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme.
 
 Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration.
@@ -82,28 +82,34 @@ To set up a Nebula network, you'll need:
 
 #### 2. (Optional, but you really should..) At least one discovery node with a routable IP address, which we call a lighthouse.
 
-Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $5/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
-
-  Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
+Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses.
 
+Once you have launched an instance, ensure that Nebula udp traffic (default port udp/4242) can reach it over the internet.
 
 #### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network.
 
-  ```
-  ./nebula-cert ca -name "Myorganization, Inc"
-  ```
-  This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
+```sh
+./nebula-cert ca -name "Myorganization, Inc"
+```
+
+This will create files named `ca.key` and `ca.cert` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
+
+**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA.
 
 #### 4. Nebula host keys and certificates generated from that certificate authority
+
 This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network.
-```
+```sh
 ./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24"
 ./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh"
 ./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers"
 ./nebula-cert sign -name "host3" -ip "192.168.100.10/24"
 ```
 
+By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime.
+
 #### 5. Configuration files for each host
+
 Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml).
 
 * On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set.
@@ -118,10 +124,13 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
 **DO NOT COPY `ca.key` TO INDIVIDUAL NODES.**
 
 #### 7. Run nebula on each host
-```
+
+```sh
 ./nebula -config /path/to/config.yml
 ```
 
+For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/).
+
 ## Building Nebula from source
 
 Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
@@ -140,8 +149,10 @@ The default curve used for cryptographic handshakes and signatures is Curve25519
 
 In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
 
-    make bin-boringcrypto
-    make release-boringcrypto
+```sh
+make bin-boringcrypto
+make release-boringcrypto
+```
 
 This is not the recommended default deployment, but may be useful based on your compliance requirements.
 
@@ -149,5 +160,3 @@ This is not the recommended default deployment, but may be useful based on your
 
 Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.
 
-
-

+ 3 - 8
calculated_remote.go

@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 
 	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
 
-	rawMap, ok := value.(map[any]any)
+	rawMap, ok := value.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
 	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
 }
 
 func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
-	rawMap, ok := raw.(map[any]any)
+	rawMap, ok := raw.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
 	}

+ 1 - 2
cert/cert.go

@@ -135,8 +135,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
 	case Version2:
 		c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
 	default:
-		//TODO: CERT-V2 make a static var
-		return nil, fmt.Errorf("unknown certificate version %d", v)
+		return nil, ErrUnknownVersion
 	}
 
 	if err != nil {

+ 9 - 9
cert/crypto_test.go

@@ -26,21 +26,21 @@ func TestNewArgon2Parameters(t *testing.T) {
 }
 
 func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
-	passphrase := []byte("DO NOT USE THIS KEY")
+	passphrase := []byte("DO NOT USE")
 	privKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
+CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI
+c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0
+KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU
+0BGkUHmIO7daP24=
 -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
 `)
 	shortKey := []byte(`# A key which, once decrypted, is too short
 -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
-k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
-GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
-rQr3bdH3Oy/WiYU=
+CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo
+hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z
+cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7
+hqzIyrRqfUgVuA==
 -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
 `)
 	invalidBanner := []byte(`# Invalid banner (not encrypted)

+ 1 - 0
cert/errors.go

@@ -20,6 +20,7 @@ var (
 	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
 	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
 	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+	ErrUnknownVersion             = errors.New("certificate version unrecognized")
 
 	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")

+ 1 - 1
config/config.go

@@ -243,7 +243,7 @@ func (c *C) GetInt(k string, d int) int {
 // GetUint32 will get the uint32 for k or return the default d if not found or invalid
 func (c *C) GetUint32(k string, d uint32) uint32 {
 	r := c.GetInt(k, int(d))
-	if uint64(r) > uint64(math.MaxUint32) {
+	if r < 0 || uint64(r) > uint64(math.MaxUint32) {
 		return d
 	}
 	return uint32(r)

+ 226 - 193
connection_manager.go

@@ -4,13 +4,16 @@ import (
 	"bytes"
 	"context"
 	"encoding/binary"
+	"fmt"
 	"net/netip"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
 
@@ -27,130 +30,124 @@ const (
 )
 
 type connectionManager struct {
-	in     map[uint32]struct{}
-	inLock *sync.RWMutex
-
-	out     map[uint32]struct{}
-	outLock *sync.RWMutex
-
 	// relayUsed holds which relay localIndexs are in use
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex
 
-	hostMap                 *HostMap
-	trafficTimer            *LockingTimerWheel[uint32]
-	intf                    *Interface
-	pendingDeletion         map[uint32]struct{}
-	punchy                  *Punchy
+	hostMap      *HostMap
+	trafficTimer *LockingTimerWheel[uint32]
+	intf         *Interface
+	punchy       *Punchy
+
+	// Configuration settings
 	checkInterval           time.Duration
 	pendingDeletionInterval time.Duration
-	metricsTxPunchy         metrics.Counter
+	inactivityTimeout       atomic.Int64
+	dropInactive            atomic.Bool
+
+	metricsTxPunchy metrics.Counter
 
 	l *logrus.Logger
 }
 
-func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
-	var max time.Duration
-	if checkInterval < pendingDeletionInterval {
-		max = pendingDeletionInterval
-	} else {
-		max = checkInterval
+func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
+	cm := &connectionManager{
+		hostMap:         hm,
+		l:               l,
+		punchy:          p,
+		relayUsed:       make(map[uint32]struct{}),
+		relayUsedLock:   &sync.RWMutex{},
+		metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
 	}
 
-	nc := &connectionManager{
-		hostMap:                 intf.hostMap,
-		in:                      make(map[uint32]struct{}),
-		inLock:                  &sync.RWMutex{},
-		out:                     make(map[uint32]struct{}),
-		outLock:                 &sync.RWMutex{},
-		relayUsed:               make(map[uint32]struct{}),
-		relayUsedLock:           &sync.RWMutex{},
-		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
-		intf:                    intf,
-		pendingDeletion:         make(map[uint32]struct{}),
-		checkInterval:           checkInterval,
-		pendingDeletionInterval: pendingDeletionInterval,
-		punchy:                  punchy,
-		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
-		l:                       l,
-	}
+	cm.reload(c, true)
+	c.RegisterReloadCallback(func(c *config.C) {
+		cm.reload(c, false)
+	})
 
-	nc.Start(ctx)
-	return nc
+	return cm
 }
 
-func (n *connectionManager) In(localIndex uint32) {
-	n.inLock.RLock()
-	// If this already exists, return
-	if _, ok := n.in[localIndex]; ok {
-		n.inLock.RUnlock()
-		return
+func (cm *connectionManager) reload(c *config.C, initial bool) {
+	if initial {
+		cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
+		cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
+
+		// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
+		// pretty close to their configured duration.
+		// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
+		minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
+		maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
+		cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
 	}
-	n.inLock.RUnlock()
-	n.inLock.Lock()
-	n.in[localIndex] = struct{}{}
-	n.inLock.Unlock()
-}
 
-func (n *connectionManager) Out(localIndex uint32) {
-	n.outLock.RLock()
-	// If this already exists, return
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.RUnlock()
-		return
+	if initial || c.HasChanged("tunnels.inactivity_timeout") {
+		old := cm.getInactivityTimeout()
+		cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
+		if !initial {
+			cm.l.WithField("oldDuration", old).
+				WithField("newDuration", cm.getInactivityTimeout()).
+				Info("Inactivity timeout has changed")
+		}
+	}
+
+	if initial || c.HasChanged("tunnels.drop_inactive") {
+		old := cm.dropInactive.Load()
+		cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
+		if !initial {
+			cm.l.WithField("oldBool", old).
+				WithField("newBool", cm.dropInactive.Load()).
+				Info("Drop inactive setting has changed")
+		}
 	}
-	n.outLock.RUnlock()
-	n.outLock.Lock()
-	n.out[localIndex] = struct{}{}
-	n.outLock.Unlock()
 }
 
-func (n *connectionManager) RelayUsed(localIndex uint32) {
-	n.relayUsedLock.RLock()
+func (cm *connectionManager) getInactivityTimeout() time.Duration {
+	return (time.Duration)(cm.inactivityTimeout.Load())
+}
+
+func (cm *connectionManager) In(h *HostInfo) {
+	h.in.Store(true)
+}
+
+func (cm *connectionManager) Out(h *HostInfo) {
+	h.out.Store(true)
+}
+
+func (cm *connectionManager) RelayUsed(localIndex uint32) {
+	cm.relayUsedLock.RLock()
 	// If this already exists, return
-	if _, ok := n.relayUsed[localIndex]; ok {
-		n.relayUsedLock.RUnlock()
+	if _, ok := cm.relayUsed[localIndex]; ok {
+		cm.relayUsedLock.RUnlock()
 		return
 	}
-	n.relayUsedLock.RUnlock()
-	n.relayUsedLock.Lock()
-	n.relayUsed[localIndex] = struct{}{}
-	n.relayUsedLock.Unlock()
+	cm.relayUsedLock.RUnlock()
+	cm.relayUsedLock.Lock()
+	cm.relayUsed[localIndex] = struct{}{}
+	cm.relayUsedLock.Unlock()
 }
 
 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
-func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
-	n.inLock.Lock()
-	n.outLock.Lock()
-	_, in := n.in[localIndex]
-	_, out := n.out[localIndex]
-	delete(n.in, localIndex)
-	delete(n.out, localIndex)
-	n.inLock.Unlock()
-	n.outLock.Unlock()
+func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
+	in := h.in.Swap(false)
+	out := h.out.Swap(false)
+	if in || out {
+		h.lastUsed = now
+	}
 	return in, out
 }
 
-func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
-	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
-	n.outLock.Lock()
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.Unlock()
-		return
+// AddTrafficWatch must be called for every new HostInfo.
+// We will continue to monitor the HostInfo until the tunnel is dropped.
+func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
+	if h.out.Swap(true) == false {
+		cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
 	}
-	n.out[localIndex] = struct{}{}
-	n.trafficTimer.Add(localIndex, n.checkInterval)
-	n.outLock.Unlock()
 }
 
-func (n *connectionManager) Start(ctx context.Context) {
-	go n.Run(ctx)
-}
-
-func (n *connectionManager) Run(ctx context.Context) {
-	//TODO: this tick should be based on the min wheel tick? Check firewall
-	clockSource := time.NewTicker(500 * time.Millisecond)
+func (cm *connectionManager) Start(ctx context.Context) {
+	clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
 	defer clockSource.Stop()
 
 	p := []byte("")
@@ -163,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) {
 			return
 
 		case now := <-clockSource.C:
-			n.trafficTimer.Advance(now)
+			cm.trafficTimer.Advance(now)
 			for {
-				localIndex, has := n.trafficTimer.Purge()
+				localIndex, has := cm.trafficTimer.Purge()
 				if !has {
 					break
 				}
 
-				n.doTrafficCheck(localIndex, p, nb, out, now)
+				cm.doTrafficCheck(localIndex, p, nb, out, now)
 			}
 		}
 	}
 }
 
-func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
+func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+	decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
 
 	switch decision {
 	case deleteTunnel:
-		if n.hostMap.DeleteHostInfo(hostinfo) {
+		if cm.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
+			cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 
 	case closeTunnel:
-		n.intf.sendCloseTunnel(hostinfo)
-		n.intf.closeTunnel(hostinfo)
+		cm.intf.sendCloseTunnel(hostinfo)
+		cm.intf.closeTunnel(hostinfo)
 
 	case swapPrimary:
-		n.swapPrimary(hostinfo, primary)
+		cm.swapPrimary(hostinfo, primary)
 
 	case migrateRelays:
-		n.migrateRelayUsed(hostinfo, primary)
+		cm.migrateRelayUsed(hostinfo, primary)
 
 	case tryRehandshake:
-		n.tryRehandshake(hostinfo)
+		cm.tryRehandshake(hostinfo)
 
 	case sendTestPacket:
-		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 	}
 
-	n.resetRelayTrafficCheck(hostinfo)
+	cm.resetRelayTrafficCheck(hostinfo)
 }
 
-func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
 	if hostinfo != nil {
-		n.relayUsedLock.Lock()
-		defer n.relayUsedLock.Unlock()
+		cm.relayUsedLock.Lock()
+		defer cm.relayUsedLock.Unlock()
 		// No need to migrate any relays, delete usage info now.
 		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
-			delete(n.relayUsed, idx)
+			delete(cm.relayUsed, idx)
 		}
 	}
 }
 
-func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 	for _, r := range relayFor {
@@ -227,46 +224,51 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		var relayFrom netip.Addr
 		var relayTo netip.Addr
 		switch {
-		case ok && existing.State == Established:
-			// This relay already exists in newhostinfo, then do nothing.
-			continue
-		case ok && existing.State == Requested:
-			// The relay exists in a Requested state; re-send the request
-			index = existing.LocalIndex
-			switch r.Type {
-			case TerminalType:
-				relayFrom = n.intf.myVpnAddrs[0]
-				relayTo = existing.PeerAddr
-			case ForwardingType:
-				relayFrom = existing.PeerAddr
-				relayTo = newhostinfo.vpnAddrs[0]
-			default:
-				// should never happen
+		case ok:
+			switch existing.State {
+			case Established, PeerRequested, Disestablished:
+				// This relay already exists in newhostinfo, then do nothing.
+				continue
+			case Requested:
+				// The relay exists in a Requested state; re-send the request
+				index = existing.LocalIndex
+				switch r.Type {
+				case TerminalType:
+					relayFrom = cm.intf.myVpnAddrs[0]
+					relayTo = existing.PeerAddr
+				case ForwardingType:
+					relayFrom = existing.PeerAddr
+					relayTo = newhostinfo.vpnAddrs[0]
+				default:
+					// should never happen
+					panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
+				}
 			}
 		case !ok:
-			n.relayUsedLock.RLock()
-			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
+			cm.relayUsedLock.RLock()
+			if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
 				// The relay hasn't been used; don't migrate it.
-				n.relayUsedLock.RUnlock()
+				cm.relayUsedLock.RUnlock()
 				continue
 			}
-			n.relayUsedLock.RUnlock()
+			cm.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
+			index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
-				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+				cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnAddrs[0]
+				relayFrom = cm.intf.myVpnAddrs[0]
 				relayTo = r.PeerAddr
 			case ForwardingType:
 				relayFrom = r.PeerAddr
 				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
+				panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type))
 			}
 		}
 
@@ -279,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		switch newhostinfo.GetCert().Certificate.Version() {
 		case cert.Version1:
 			if !relayFrom.Is4() {
-				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
 				continue
 			}
 
 			if !relayTo.Is4() {
-				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
 				continue
 			}
 
@@ -296,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayToAddr = netAddrToProtoAddr(relayTo)
 		default:
-			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
 			continue
 		}
 
 		msg, err := req.Marshal()
 		if err != nil {
-			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+			cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
-			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
-			n.l.WithFields(logrus.Fields{
+			cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+			cm.l.WithFields(logrus.Fields{
 				"relayFrom":           req.RelayFromAddr,
 				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -316,46 +318,45 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	}
 }
 
-func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
-	n.hostMap.RLock()
-	defer n.hostMap.RUnlock()
+func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
+	cm.hostMap.RLock()
+	defer cm.hostMap.RUnlock()
 
-	hostinfo := n.hostMap.Indexes[localIndex]
+	hostinfo := cm.hostMap.Indexes[localIndex]
 	if hostinfo == nil {
-		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-		delete(n.pendingDeletion, localIndex)
+		cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
 		return doNothing, nil, nil
 	}
 
-	if n.isInvalidCertificate(now, hostinfo) {
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+	if cm.isInvalidCertificate(now, hostinfo) {
 		return closeTunnel, hostinfo, nil
 	}
 
-	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
+	primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 	}
 
 	// Check for traffic on this hostinfo
-	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
+	inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
 
 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
 		decision := doNothing
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 				Debug("Tunnel status")
 		}
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+		hostinfo.pendingDeletion.Store(false)
 
 		if mainHostInfo {
 			decision = tryRehandshake
 
 		} else {
-			if n.shouldSwapPrimary(hostinfo, primary) {
+			if cm.shouldSwapPrimary(hostinfo) {
 				decision = swapPrimary
 			} else {
 				// migrate the relays to the primary, if in use.
@@ -363,46 +364,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 			}
 		}
 
-		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+		cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
 
 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}
 
 		return decision, hostinfo, primary
 	}
 
-	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
+	if hostinfo.pendingDeletion.Load() {
 		// We have already sent a test packet and nothing was returned, this hostinfo is dead
-		hostinfo.logger(n.l).
+		hostinfo.logger(cm.l).
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")
 
-		delete(n.pendingDeletion, hostinfo.localIndexId)
 		return deleteTunnel, hostinfo, nil
 	}
 
 	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
+			inactiveFor, isInactive := cm.isInactive(hostinfo, now)
+			if isInactive {
+				// Tunnel is inactive, tear it down
+				hostinfo.logger(cm.l).
+					WithField("inactiveDuration", inactiveFor).
+					WithField("primary", mainHostInfo).
+					Info("Dropping tunnel due to inactivity")
+
+				return closeTunnel, hostinfo, primary
+			}
+
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
 			// Just maintain NAT state if configured to do so.
-			n.sendPunch(hostinfo)
-			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+			cm.sendPunch(hostinfo)
+			cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
 			return doNothing, nil, nil
-
 		}
 
-		if n.punchy.GetTargetEverything() {
+		if cm.punchy.GetTargetEverything() {
 			// This is similar to the old punchy behavior with a slight optimization.
 			// We aren't receiving traffic but we are sending it, punch on all known
 			// ips in case we need to re-prime NAT state
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}
 
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 				Debug("Tunnel status")
 		}
@@ -411,17 +421,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		decision = sendTestPacket
 
 	} else {
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
 		}
 	}
 
-	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
-	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
+	hostinfo.pendingDeletion.Store(true)
+	cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
 	return decision, hostinfo, nil
 }
 
-func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
+	if cm.dropInactive.Load() == false {
+		// We aren't configured to drop inactive tunnels
+		return 0, false
+	}
+
+	inactiveDuration := now.Sub(hostinfo.lastUsed)
+	if inactiveDuration < cm.getInactivityTimeout() {
+		// It's not considered inactive
+		return inactiveDuration, false
+	}
+
+	// The tunnel is inactive
+	return inactiveDuration, true
+}
+
+func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
@@ -429,83 +455,90 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
 	// vpn addr is static across all tunnels for this host pair so lets
 	// use that to determine if we should consider swapping.
-	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+	if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
 		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
-func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
-	n.hostMap.Lock()
+func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
+	cm.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
-		n.hostMap.unlockedMakePrimary(current)
+	if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
+		cm.hostMap.unlockedMakePrimary(current)
 	}
-	n.hostMap.Unlock()
+	cm.hostMap.Unlock()
 }
 
 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
 // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
 // check and return true.
-func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
+func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 		return false
 	}
 
-	caPool := n.intf.pki.GetCAPool()
+	caPool := cm.intf.pki.GetCAPool()
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
 		return false
 	}
 
-	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
 		return false
 	}
 
-	hostinfo.logger(n.l).WithError(err).
+	hostinfo.logger(cm.l).WithError(err).
 		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
 	return true
 }
 
-func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
-	if !n.punchy.GetPunch() {
+func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !cm.punchy.GetPunch() {
 		// Punching is disabled
 		return
 	}
 
-	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
-			n.metricsTxPunchy.Inc(1)
-			n.intf.outside.WriteTo([]byte{1}, addr)
+	if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
+		// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
+		// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
+		// would lose the ability to notify us and punchy.respond would become unreliable.
+		return
+	}
+
+	if cm.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
+			cm.metricsTxPunchy.Inc(1)
+			cm.intf.outside.WriteTo([]byte{1}, addr)
 		})
 
 	} else if hostinfo.remote.IsValid() {
-		n.metricsTxPunchy.Inc(1)
-		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+		cm.metricsTxPunchy.Inc(1)
+		cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
 
-func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	cs := n.intf.pki.getCertState()
+func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
 	myCrt := cs.getCertificate(curCrt.Version())
-	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
 		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
 
-	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+	cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

+ 145 - 49
connection_manager_test.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"context"
 	"crypto/ed25519"
 	"crypto/rand"
 	"net/netip"
@@ -23,7 +22,7 @@ func newTestLighthouse() *LightHouse {
 		addrMap:   map[netip.Addr]*RemoteList{},
 		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := map[netip.Addr]struct{}{}
+	lighthouses := []netip.Addr{}
 	staticList := map[netip.Addr]struct{}{}
 
 	lh.lighthouses.Store(&lighthouses)
@@ -44,10 +43,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		defaultVersion:   cert.Version1,
-		privateKey:       []byte{},
-		v1Cert:           &dummyCert{version: cert.Version1},
-		v1HandshakeBytes: []byte{},
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -64,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	ifce.pki.cs.Store(cs)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -85,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.out, hostinfo.localIndexId)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 
 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	assert.True(t, hostinfo.out.Load())
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 
@@ -126,10 +126,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		defaultVersion:   cert.Version1,
-		privateKey:       []byte{},
-		v1Cert:           &dummyCert{version: cert.Version1},
-		v1HandshakeBytes: []byte{},
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	ifce.pki.cs.Store(cs)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -167,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.in.Load())
+	assert.True(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 
 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// We saw traffic, should no longer be pending deletion
-	nc.In(hostinfo.localIndexId)
+	nc.In(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+}
+
+func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
+	l := test.NewLogger()
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
+	preferredRanges := []netip.Prefix{localrange}
+
+	// Very incomplete mock objects
+	hostMap := newHostMap(l)
+	hostMap.preferredRanges.Store(&preferredRanges)
+
+	cs := &CertState{
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
+	}
+
+	lh := newTestLighthouse()
+	ifce := &Interface{
+		hostMap:          hostMap,
+		inside:           &test.NoopTun{},
+		outside:          &udp.NoopConn{},
+		firewall:         &Firewall{},
+		lightHouse:       lh,
+		pki:              &PKI{},
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		l:                l,
+	}
+	ifce.pki.cs.Store(cs)
+
+	// Create manager
+	conf := config.NewC(l)
+	conf.Settings["tunnels"] = map[string]any{
+		"drop_inactive": true,
+	}
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	assert.True(t, nc.dropInactive.Load())
+	nc.intf = ifce
+
+	// Add an ip we have established a connection w/ to hostmap
+	hostinfo := &HostInfo{
+		vpnAddrs:      vpnAddrs,
+		localIndexId:  1099,
+		remoteIndexId: 9901,
+	}
+	hostinfo.ConnectionState = &ConnectionState{
+		myCert: &dummyCert{version: cert.Version1},
+		H:      &noise.HandshakeState{},
+	}
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
+
+	// Do a traffic check tick, in and out should be cleared but should not be pending deletion
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())
+
+	now := time.Now()
+	decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
+	assert.Equal(t, tryRehandshake, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	// Do another traffic check tick, should still not be pending deletion
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+
+	// Finally advance beyond the inactivity timeout
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
+	assert.Equal(t, closeTunnel, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
@@ -264,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.disconnectInvalid.Store(true)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	ifce.connectionManager = nc
 
 	hostinfo := &HostInfo{

+ 13 - 10
control.go

@@ -26,14 +26,15 @@ type controlHostLister interface {
 }
 
 type Control struct {
-	f               *Interface
-	l               *logrus.Logger
-	ctx             context.Context
-	cancel          context.CancelFunc
-	sshStart        func()
-	statsStart      func()
-	dnsStart        func()
-	lighthouseStart func()
+	f                      *Interface
+	l                      *logrus.Logger
+	ctx                    context.Context
+	cancel                 context.CancelFunc
+	sshStart               func()
+	statsStart             func()
+	dnsStart               func()
+	lighthouseStart        func()
+	connectionManagerStart func(context.Context)
 }
 
 type ControlHostInfo struct {
@@ -63,6 +64,9 @@ func (c *Control) Start() {
 	if c.dnsStart != nil {
 		go c.dnsStart()
 	}
+	if c.connectionManagerStart != nil {
+		go c.connectionManagerStart(c.ctx)
+	}
 	if c.lighthouseStart != nil {
 		c.lighthouseStart()
 	}
@@ -131,8 +135,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
-	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
-	if found {
+	if c.f.myVpnAddrsTable.Contains(vpnIp) {
 		// Only returning the default certificate since its impossible
 		// for any other host but ourselves to have more than 1
 		return c.f.pki.getCertState().GetDefaultCertificate().Copy()

+ 2 - 2
control_test.go

@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},

+ 3 - 3
dns_server.go

@@ -26,7 +26,7 @@ type dnsRecords struct {
 	dnsMap4         map[string]netip.Addr
 	dnsMap6         map[string]netip.Addr
 	hostMap         *HostMap
-	myVpnAddrsTable *bart.Table[struct{}]
+	myVpnAddrsTable *bart.Lite
 }
 
 func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
@@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
 		return true
 	}
 
-	_, found := d.myVpnAddrsTable.Lookup(b)
-	return found //if we found it in this table, it's good
+	//if we found it in this table, it's good
+	return d.myVpnAddrsTable.Contains(b)
 }
 
 func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {

+ 4 - 1
e2e/handshakes_test.go

@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
 	curIndexes := len(myControl.GetHostmap().Indexes)
 	for curIndexes >= start {
 		curIndexes = len(myControl.GetHostmap().Indexes)
-		r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
+		r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
 		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
 
 		r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1052,6 +1052,9 @@ func TestRehandshakingLoser(t *testing.T) {
 	t.Log("Stand up a tunnel between me and them")
 	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
+	myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+	theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew their certificate and spin until mine sees it")

+ 1 - 0
e2e/router/router.go

@@ -700,6 +700,7 @@ func (r *R) FlushAll() {
 			r.Unlock()
 			panic("Can't FlushAll for host: " + p.To.String())
 		}
+		receiver.InjectUDPPacket(p)
 		r.Unlock()
 	}
 }

+ 57 - 0
e2e/tunnels_test.go

@@ -0,0 +1,57 @@
+//go:build e2e_testing
+// +build e2e_testing
+
+package e2e
+
+import (
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
+	"github.com/slackhq/nebula/e2e/router"
+)
+
+func TestDropInactiveTunnels(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+
+	r.Log("Go inactive and wait for the tunnels to get dropped")
+	waitStart := time.Now()
+	for {
+		myIndexes := len(myControl.GetHostmap().Indexes)
+		theirIndexes := len(theirControl.GetHostmap().Indexes)
+		if myIndexes == 0 && theirIndexes == 0 {
+			break
+		}
+
+		since := time.Since(waitStart)
+		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
+		if since > time.Second*30 {
+			t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
+		}
+
+		time.Sleep(1 * time.Second)
+		r.FlushAll()
+	}
+
+	r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 18 - 2
examples/config.yml

@@ -13,11 +13,11 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
 
-  # default_version controls which certificate version is used in handshakes.
+  # initiating_version controls which certificate version is used when initiating handshakes.
   # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
   # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
   # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
-  # default_version: 1
+  # initiating_version: 1
 
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
@@ -275,6 +275,10 @@ tun:
   # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
   # in nebula configuration files. Default false, not reloadable.
   #use_system_route_table: false
+  # Buffer size for reading routes updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default).
+  # If using massive routes updates, for example BGP, you may need to increase this value to avoid packet loss.
+  # SO_RCVBUFFORCE is used to avoid having to raise the system wide max
+  #use_system_route_table_buffer_size: 0
 
 # Configure logging level
 logging:
@@ -334,6 +338,18 @@ logging:
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
 
+# Tunnel manager settings
+#tunnels:
+  # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
+  # elapsed.
+  # In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
+  # This setting is reloadable
+  #drop_inactive: false
+
+  # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
+  # inactive and eligible to be dropped.
+  # This setting is reloadable
+  #inactivity_timeout: 10m
 
 # Nebula security group configuration
 firewall:

+ 14 - 1
examples/go_service/main.go

@@ -5,8 +5,12 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"os"
 
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/service"
 )
 
@@ -59,7 +63,16 @@ pki:
 	if err := cfg.LoadString(configStr); err != nil {
 		return err
 	}
-	svc, err := service.New(&cfg)
+
+	logger := logrus.New()
+	logger.Out = os.Stdout
+
+	ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
+	if err != nil {
+		return err
+	}
+
+	svc, err := service.New(ctrl)
 	if err != nil {
 		return err
 	}

+ 11 - 14
firewall.go

@@ -53,7 +53,7 @@ type Firewall struct {
 
 	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
 	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
-	routableNetworks *bart.Table[struct{}]
+	routableNetworks *bart.Lite
 
 	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
 	assignedNetworks  []netip.Prefix
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
 
 type firewallLocalCIDR struct {
 	Any       bool
-	LocalCIDR *bart.Table[struct{}]
+	LocalCIDR *bart.Lite
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		tmax = defaultTimeout
 	}
 
-	routableNetworks := new(bart.Table[struct{}])
+	routableNetworks := new(bart.Lite)
 	var assignedNetworks []netip.Prefix
 	for _, network := range c.Networks() {
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		routableNetworks.Insert(nprefix, struct{}{})
+		routableNetworks.Insert(nprefix)
 		assignedNetworks = append(assignedNetworks, network)
 	}
 
 	hasUnsafeNetworks := false
 	for _, n := range c.UnsafeNetworks() {
-		routableNetworks.Insert(n, struct{}{})
+		routableNetworks.Insert(n)
 		hasUnsafeNetworks = true
 	}
 
@@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 
 	// Make sure remote address matches nebula certificate
 	if h.networks != nil {
-		_, ok := h.networks.Lookup(fp.RemoteAddr)
-		if !ok {
+		if !h.networks.Contains(fp.RemoteAddr) {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
-	if !ok {
+	if !f.routableNetworks.Contains(fp.LocalAddr) {
 		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 	}
@@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
 func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
-			LocalCIDR: new(bart.Table[struct{}]),
+			LocalCIDR: new(bart.Lite),
 		}
 	}
 
@@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 		}
 
 		for _, network := range f.assignedNetworks {
-			flc.LocalCIDR.Insert(network, struct{}{})
+			flc.LocalCIDR.Insert(network)
 		}
 		return nil
 
@@ -888,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 		return nil
 	}
 
-	flc.LocalCIDR.Insert(localIp, struct{}{})
+	flc.LocalCIDR.Insert(localIp)
 	return nil
 }
 
@@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
 		return true
 	}
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
-	return ok
+	return flc.LocalCIDR.Contains(p.LocalAddr)
 }
 
 type rule struct {

+ 197 - 0
firewall_test.go

@@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) {
 	ti, err := netip.ParsePrefix("1.2.3.4/32")
 	require.NoError(t, err)
 
+	ti6, err := netip.ParsePrefix("fd12::34/128")
+	require.NoError(t, err)
+
 	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) {
 	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
+	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
+	_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
+	assert.True(t, ok)
+
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
 	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
 	assert.True(t, ok)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
+	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
+	assert.True(t, ok)
+
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) {
 	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	anyIp6, err := netip.ParsePrefix("::/0")
+	require.NoError(t, err)
+
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) {
 	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
+func TestFirewall_DropV6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  10,
+		RemotePort: 90,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	c := dummyCert{
+		name:     "host1",
+		networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
+		},
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
+	}
+	h.buildNetworks(c.networks, c.unsafeNetworks)
+
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
+	// Allow inbound
+	resetConntrack(fw)
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	// Allow outbound because conntrack
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
+
+	// test remote mismatch
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("fd12::56")
+	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
+	p.RemoteAddr = oldRemote
+
+	// ensure signer doesn't get in the way of group checks
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caSha doesn't drop on match
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+
+	// ensure ca name doesn't get in the way of group checks
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caName doesn't drop on match
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
 func BenchmarkFirewallTable_match(b *testing.B) {
 	f := &Firewall{}
 	ft := FirewallTable{
@@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	pfix := netip.MustParsePrefix("172.1.1.1/32")
 	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
 	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
+
+	pfix6 := netip.MustParsePrefix("fd11::11/128")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
+		ip := netip.MustParsePrefix("fd99::99/128")
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass on group on specific local cidr6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
+			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on name", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) {
 	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 
+func TestFirewall_Drop3V6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  1,
+		RemotePort: 1,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	network := netip.MustParsePrefix("fd12::34/120")
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{network},
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		vpnAddrs: []netip.Addr{network.Addr()},
+	}
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+
+	// Test a remote address match
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+	cp := cert.NewCAPool()
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
 func TestFirewall_DropConntrackReload(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
@@ -727,6 +909,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
+	// Test adding rule with cidr ipv6
+	cidr6 := netip.MustParsePrefix("fd00::/8")
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
+
+	// Test adding rule with local_cidr ipv6
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
+
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}

+ 12 - 13
go.mod

@@ -1,35 +1,35 @@
 module github.com/slackhq/nebula
 
-go 1.23.6
+go 1.23.0
 
 toolchain go1.24.1
 
 require (
-	dario.cat/mergo v1.0.1
+	dario.cat/mergo v1.0.2
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
-	github.com/gaissmai/bart v0.20.1
+	github.com/gaissmai/bart v0.20.4
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.64
+	github.com/miekg/dns v1.1.65
 	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.21.1
+	github.com/prometheus/client_golang v1.22.0
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.10.0
-	github.com/vishvananda/netlink v1.3.0
-	golang.org/x/crypto v0.36.0
+	github.com/vishvananda/netlink v1.3.1
+	golang.org/x/crypto v0.37.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.38.0
-	golang.org/x/sync v0.12.0
-	golang.org/x/sys v0.31.0
-	golang.org/x/term v0.30.0
+	golang.org/x/net v0.39.0
+	golang.org/x/sync v0.13.0
+	golang.org/x/sys v0.32.0
+	golang.org/x/term v0.31.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
@@ -43,13 +43,12 @@ require (
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
-	github.com/klauspost/compress v1.17.11 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.6.1 // indirect
 	github.com/prometheus/common v0.62.0 // indirect
 	github.com/prometheus/procfs v0.15.1 // indirect
-	github.com/vishvananda/netns v0.0.4 // indirect
+	github.com/vishvananda/netns v0.0.5 // indirect
 	golang.org/x/mod v0.23.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/tools v0.30.0 // indirect

+ 26 - 26
go.sum

@@ -1,6 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
-dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
+dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
+dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo=
-github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
+github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
+github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -53,8 +53,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -68,8 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
 github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
-github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
+github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
+github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ=
-github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
+github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
+github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,8 +106,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
-github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
+github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
+github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -145,10 +145,10 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
 github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
-github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
-github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
-github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
+github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
+github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
+github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
+github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -156,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
-golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
-golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
+golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
+golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -185,8 +185,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
-golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
+golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -204,11 +204,11 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
-golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
+golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
-golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
+golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
+golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

+ 10 - 11
handshake_ix.go

@@ -25,7 +25,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 	// If we're connecting to a v6 address we must use a v2 cert
 	cs := f.pki.getCertState()
-	v := cs.defaultVersion
+	v := cs.initiatingVersion
 	for _, a := range hh.hostinfo.vpnAddrs {
 		if a.Is6() {
 			v = cert.Version2
@@ -101,7 +101,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if crt == nil {
 		f.l.WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
-			WithField("certVersion", cs.defaultVersion).
+			WithField("certVersion", cs.initiatingVersion).
 			Error("Unable to handshake with host because no certificate is available")
 	}
 
@@ -192,8 +192,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	for _, network := range remoteCert.Certificate.Networks() {
 		vpnAddr := network.Addr()
-		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
-		if found {
+		if f.myVpnAddrsTable.Contains(vpnAddr) {
 			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
@@ -204,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		}
 
 		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+		if !f.myVpnNetworksTable.Contains(vpnAddr) {
 			continue
 		}
 
@@ -250,7 +249,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},
@@ -458,9 +457,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			Info("Handshake message sent")
 	}
 
-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)
 
-	hostinfo.remotes.ResetBlockedRemotes()
+	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
 
 	return
 }
@@ -579,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	for _, network := range vpnNetworks {
 		// vpnAddrs outside our vpn networks are of no use to us, filter them out
 		vpnAddr := network.Addr()
-		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+		if !f.myVpnNetworksTable.Contains(vpnAddr) {
 			continue
 		}
 
@@ -653,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)
 
 	if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
@@ -668,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
 	}
 
-	hostinfo.remotes.ResetBlockedRemotes()
+	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
 	f.metricHandshakes.Update(duration)
 
 	return false

+ 2 - 3
handshake_manager.go

@@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 			}
 
 			// Don't relay through the host I'm trying to connect to
-			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
-			if found {
+			if hm.f.myVpnAddrsTable.Contains(relay) {
 				continue
 			}
 
@@ -451,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
 		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},

+ 5 - 5
handshake_manager_test.go

@@ -24,10 +24,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	lh := newTestLighthouse()
 
 	cs := &CertState{
-		defaultVersion:   cert.Version1,
-		privateKey:       []byte{},
-		v1Cert:           &dummyCert{version: cert.Version1},
-		v1HandshakeBytes: []byte{},
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
 	}
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -98,5 +98,5 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
 }
 
 func (mw *mockEncWriter) GetCertState() *CertState {
-	return &CertState{defaultVersion: cert.Version2}
+	return &CertState{initiatingVersion: cert.Version2}
 }

+ 26 - 22
hostmap.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"net"
 	"net/netip"
+	"slices"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -16,12 +17,10 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-// const ProbeLen = 100
 const defaultPromoteEvery = 1000       // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
 const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo to the lighthouse
 const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
 const MaxRemotes = 10
-const maxRecvError = 4
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
 // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -68,7 +67,7 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
 	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
 	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
 	// the RelayState Lock held)
@@ -79,7 +78,12 @@ type RelayState struct {
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	delete(rs.relays, ip)
+	for idx, val := range rs.relays {
+		if val == ip {
+			rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
+			return
+		}
+	}
 }
 
 func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -124,16 +128,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relays[ip] = struct{}{}
+	if !slices.Contains(rs.relays, ip) {
+		rs.relays = append(rs.relays, ip)
+	}
 }
 
 func (rs *RelayState) CopyRelayIps() []netip.Addr {
+	ret := make([]netip.Addr, len(rs.relays))
 	rs.RLock()
 	defer rs.RUnlock()
-	ret := make([]netip.Addr, 0, len(rs.relays))
-	for ip := range rs.relays {
-		ret = append(ret, ip)
-	}
+	copy(ret, rs.relays)
 	return ret
 }
 
@@ -219,11 +223,10 @@ type HostInfo struct {
 	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
 	// The host may have other vpn addresses that are outside our
 	// vpn networks but were removed because they are not usable
-	vpnAddrs  []netip.Addr
-	recvError atomic.Uint32
+	vpnAddrs []netip.Addr
 
 	// networks are both all vpn and unsafe networks assigned to this host
-	networks   *bart.Table[struct{}]
+	networks   *bart.Lite
 	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -250,6 +253,14 @@ type HostInfo struct {
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
 	next, prev *HostInfo
+
+	//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
+	in, out, pendingDeletion atomic.Bool
+
+	// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
+	// This value will be behind against actual tunnel utilization in the hot path.
+	// This should only be used by the ConnectionManagers ticker routine.
+	lastUsed time.Time
 }
 
 type ViaSender struct {
@@ -719,26 +730,19 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 	return false
 }
 
-func (i *HostInfo) RecvErrorExceeded() bool {
-	if i.recvError.Add(1) >= maxRecvError {
-		return true
-	}
-	return true
-}
-
 func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
 	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		return
 	}
 
-	i.networks = new(bart.Table[struct{}])
+	i.networks = new(bart.Lite)
 	for _, network := range networks {
-		i.networks.Insert(network, struct{}{})
+		i.networks.Insert(network)
 	}
 
 	for _, network := range unsafeNetworks {
-		i.networks.Insert(network, struct{}{})
+		i.networks.Insert(network)
 	}
 }
 

+ 29 - 0
hostmap_test.go

@@ -7,6 +7,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestHostMap_MakePrimary(t *testing.T) {
@@ -215,3 +216,31 @@ func TestHostMap_reload(t *testing.T) {
 	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
 	assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
 }
+
+func TestHostMap_RelayState(t *testing.T) {
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	a1 := netip.MustParseAddr("::1")
+	a2 := netip.MustParseAddr("2001::1")
+
+	h1.relayState.InsertRelayTo(a1)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+	h1.relayState.InsertRelayTo(a2)
+	assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
+	// Ensure that the first relay added is the first one returned in the copy
+	currentRelays := h1.relayState.CopyRelayIps()
+	require.Len(t, currentRelays, 2)
+	assert.Equal(t, a1, currentRelays[0])
+
+	// Deleting the last one in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+
+	// Deleting an element not in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+
+	// Deleting the only element in the list works ok
+	h1.relayState.DeleteRelay(a1)
+	assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
+
+}

+ 5 - 8
inside.go

@@ -22,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 	// Ignore local broadcast packets
 	if f.dropLocalBroadcast {
-		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
-		if found {
+		if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
 			return
 		}
 	}
 
-	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
-	if found {
+	if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// routes packets from the Nebula addr to the Nebula addr through the Nebula
@@ -130,8 +128,7 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) {
 // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
 func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
-	if found {
+	if f.myVpnNetworksTable.Contains(vpnAddr) {
 		return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 	}
 
@@ -291,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
 	c := via.ConnectionState.messageCounter.Add(1)
 
 	out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
-	f.connectionManager.Out(via.localIndexId)
+	f.connectionManager.Out(via)
 
 	// Authenticate the header and payload, but do not encrypt for this message type.
 	// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -359,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
 	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
-	f.connectionManager.Out(hostinfo.localIndexId)
+	f.connectionManager.Out(hostinfo)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// all our addrs and enable a faster roaming.

+ 29 - 26
interface.go

@@ -24,23 +24,23 @@ import (
 const mtu = 9001
 
 type InterfaceConfig struct {
-	HostMap                 *HostMap
-	Outside                 udp.Conn
-	Inside                  overlay.Device
-	pki                     *PKI
-	Firewall                *Firewall
-	ServeDns                bool
-	HandshakeManager        *HandshakeManager
-	lightHouse              *LightHouse
-	checkInterval           time.Duration
-	pendingDeletionInterval time.Duration
-	DropLocalBroadcast      bool
-	DropMulticast           bool
-	routines                int
-	MessageMetrics          *MessageMetrics
-	version                 string
-	relayManager            *relayManager
-	punchy                  *Punchy
+	HostMap            *HostMap
+	Outside            udp.Conn
+	Inside             overlay.Device
+	pki                *PKI
+	Cipher             string
+	Firewall           *Firewall
+	ServeDns           bool
+	HandshakeManager   *HandshakeManager
+	lightHouse         *LightHouse
+	connectionManager  *connectionManager
+	DropLocalBroadcast bool
+	DropMulticast      bool
+	routines           int
+	MessageMetrics     *MessageMetrics
+	version            string
+	relayManager       *relayManager
+	punchy             *Punchy
 
 	tryPromoteEvery uint32
 	reQueryEvery    uint32
@@ -61,11 +61,11 @@ type Interface struct {
 	serveDns              bool
 	createTime            time.Time
 	lightHouse            *LightHouse
-	myBroadcastAddrsTable *bart.Table[struct{}]
-	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
-	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
-	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
-	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	myBroadcastAddrsTable *bart.Lite
+	myVpnAddrs            []netip.Addr // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Lite
+	myVpnNetworks         []netip.Prefix // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Lite
 	dropLocalBroadcast    bool
 	dropMulticast         bool
 	routines              int
@@ -157,6 +157,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	if c.Firewall == nil {
 		return nil, errors.New("no firewall rules")
 	}
+	if c.connectionManager == nil {
+		return nil, errors.New("no connection manager")
+	}
 
 	cs := c.pki.getCertState()
 	ifce := &Interface{
@@ -181,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		myVpnAddrsTable:       cs.myVpnAddrsTable,
 		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
 		relayManager:          c.relayManager,
-
+		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -198,7 +201,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 
-	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
+	ifce.connectionManager.intf = ifce
 
 	return ifce, nil
 }
@@ -410,7 +413,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	udpStats := udp.NewUDPStatsEmitter(f.writers)
 
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
-	certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
+	certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil)
 	certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
 
 	for {
@@ -425,7 +428,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			certState := f.pki.getCertState()
 			defaultCrt := certState.GetDefaultCertificate()
 			certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
-			certDefaultVersion.Update(int64(defaultCrt.Version()))
+			certInitiatingVersion.Update(int64(defaultCrt.Version()))
 
 			// Report the max certificate version we are capable of using
 			if certState.v2Cert != nil {

+ 114 - 125
lighthouse.go

@@ -24,6 +24,7 @@ import (
 )
 
 var ErrHostNotKnown = errors.New("host not known")
+var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
 
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -32,7 +33,7 @@ type LightHouse struct {
 	amLighthouse bool
 
 	myVpnNetworks      []netip.Prefix
-	myVpnNetworksTable *bart.Table[struct{}]
+	myVpnNetworksTable *bart.Lite
 	punchConn          udp.Conn
 	punchy             *Punchy
 
@@ -56,7 +57,7 @@ type LightHouse struct {
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
 	staticList  atomic.Pointer[map[netip.Addr]struct{}]
-	lighthouses atomic.Pointer[map[netip.Addr]struct{}]
+	lighthouses atomic.Pointer[[]netip.Addr]
 
 	interval     atomic.Int64
 	updateCancel context.CancelFunc
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		queryChan:          make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
 		l:                  l,
 	}
-	lighthouses := make(map[netip.Addr]struct{})
+	lighthouses := make([]netip.Addr, 0)
 	h.lighthouses.Store(&lighthouses)
 	staticList := make(map[netip.Addr]struct{})
 	h.staticList.Store(&staticList)
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
 	return *lh.staticList.Load()
 }
 
-func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
+func (lh *LightHouse) GetLighthouses() []netip.Addr {
 	return *lh.lighthouses.Load()
 }
 
@@ -201,8 +202,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 
 			//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
 			addr := addrs[0].Unmap()
-			_, found := lh.myVpnNetworksTable.Lookup(addr)
-			if found {
+			if lh.myVpnNetworksTable.Contains(addr) {
 				lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
 					Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
 				continue
@@ -307,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
-		lhMap := make(map[netip.Addr]struct{})
-		err := lh.parseLighthouses(c, lhMap)
+		lhList, err := lh.parseLighthouses(c)
 		if err != nil {
 			return err
 		}
 
-		lh.lighthouses.Store(&lhMap)
+		lh.lighthouses.Store(&lhList)
 		if !initial {
 			//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
 			lh.l.Info("lighthouse.hosts has changed")
@@ -347,37 +346,37 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	if lh.amLighthouse && len(lhs) != 0 {
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
+	out := make([]netip.Addr, len(lhs))
 
 	for i, host := range lhs {
 		addr, err := netip.ParseAddr(host)
 		if err != nil {
-			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
+			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
 
-		_, found := lh.myVpnNetworksTable.Lookup(addr)
-		if !found {
-			return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
+		if !lh.myVpnNetworksTable.Contains(addr) {
+			return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
 		}
-		lhMap[addr] = struct{}{}
+		out[i] = addr
 	}
 
-	if !lh.amLighthouse && len(lhMap) == 0 {
+	if !lh.amLighthouse && len(out) == 0 {
 		lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
 	}
 
 	staticList := lh.GetStaticHostList()
-	for lhAddr, _ := range lhMap {
-		if _, ok := staticList[lhAddr]; !ok {
-			return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
+	for i := range out {
+		if _, ok := staticList[out[i]]; !ok {
+			return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
 		}
 	}
 
-	return nil
+	return out, nil
 }
 
 func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -431,8 +430,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
 		}
 
-		_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
-		if !found {
+		if !lh.myVpnNetworksTable.Contains(vpnAddr) {
 			return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
 		}
 
@@ -489,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
 	lh.Lock()
 	defer lh.Unlock()
 	// Add an entry if we don't already have one
-	return lh.unlockedGetRemoteList(vpnAddrs)
+	return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
 }
 
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -522,11 +520,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
 }
 
 func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
-	// First we check the static mapping
-	// and do nothing if it is there
-	if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
-		return
+	// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
+	staticList := lh.GetStaticHostList()
+	for _, addr := range allVpnAddrs {
+		if _, ok := staticList[addr]; ok {
+			return
+		}
 	}
+
+	// None of the VpnAddrs were present. Now we can do the deletes.
 	lh.Lock()
 	rm, ok := lh.addrMap[allVpnAddrs[0]]
 	if ok {
@@ -568,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 	am.unlockedSetHostnamesResults(hr)
 
 	for _, addrPort := range hr.GetAddrs() {
-		if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
+		if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
 			continue
 		}
 		switch {
@@ -630,31 +632,37 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
 	return len(calculatedV4) > 0 || len(calculatedV6) > 0
 }
 
-// unlockedGetRemoteList
-// assumes you have the lh lock
+// unlockedGetRemoteList assumes you have the lh lock
 func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
-	am, ok := lh.addrMap[allAddrs[0]]
-	if !ok {
-		am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
-		for _, addr := range allAddrs {
-			lh.addrMap[addr] = am
+	// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
+	for i, addr := range allAddrs {
+		am, ok := lh.addrMap[addr]
+		if ok {
+			if i != 0 {
+				lh.addrMap[allAddrs[0]] = am
+			}
+			return am
 		}
 	}
+
+	am := NewRemoteList(allAddrs, lh.shouldAdd)
+	for _, addr := range allAddrs {
+		lh.addrMap[addr] = am
+	}
 	return am
 }
 
-func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
-	allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
+func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
+	allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
 	if lh.l.Level >= logrus.TraceLevel {
-		lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
+		lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
 			Trace("remoteAllowList.Allow")
 	}
 	if !allow {
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(to)
-	if found {
+	if lh.myVpnNetworksTable.Contains(to) {
 		return false
 	}
 
@@ -674,8 +682,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
-	if found {
+	if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
 		return false
 	}
 
@@ -695,8 +702,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
-	if found {
+	if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
 		return false
 	}
 
@@ -704,19 +710,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 }
 
 func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
-	if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
-		return true
+	l := lh.GetLighthouses()
+	for i := range l {
+		if l[i] == vpnAddr {
+			return true
+		}
 	}
 	return false
 }
 
-// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
-// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
-func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
+func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
 	l := lh.GetLighthouses()
-	for _, a := range vpnAddr {
-		if _, ok := l[a]; ok {
-			return true
+	for i := range vpnAddrs {
+		for j := range l {
+			if l[j] == vpnAddrs[i] {
+				return true
+			}
 		}
 	}
 	return false
@@ -758,12 +767,12 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
 	queried := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
 			v = hi.ConnectionState.myCert.Version()
 		} else {
-			v = lh.ifce.GetCertState().defaultVersion
+			v = lh.ifce.GetCertState().initiatingVersion
 		}
 
 		if v == cert.Version1 {
@@ -856,8 +865,7 @@ func (lh *LightHouse) SendUpdate() {
 
 	lal := lh.GetLocalAllowList()
 	for _, e := range localAddrs(lh.l, lal) {
-		_, found := lh.myVpnNetworksTable.Lookup(e)
-		if found {
+		if lh.myVpnNetworksTable.Contains(e) {
 			continue
 		}
 
@@ -877,13 +885,13 @@ func (lh *LightHouse) SendUpdate() {
 	updated := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		var v cert.Version
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
 			v = hi.ConnectionState.myCert.Version()
 		} else {
-			v = lh.ifce.GetCertState().defaultVersion
+			v = lh.ifce.GetCertState().initiatingVersion
 		}
 		if v == cert.Version1 {
 			if v1Update == nil {
@@ -1055,17 +1063,8 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 		return
 	}
 
-	useVersion := cert.Version1
-	var queryVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		queryVpnAddr = netip.AddrFrom4(b)
-		useVersion = 1
-	} else if n.Details.VpnAddr != nil {
-		queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		useVersion = 2
-	} else {
+	queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
 		}
@@ -1114,7 +1113,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
 		targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
 		var useVersion cert.Version
 		if targetHI == nil {
-			useVersion = lhh.lh.ifce.GetCertState().defaultVersion
+			useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
 		} else {
 			crt := targetHI.GetCert().Certificate
 			useVersion = crt.Version()
@@ -1123,8 +1122,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
 			if ok {
 				whereToPunch = newDest
 			} else {
-				//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
-				//choosing to do nothing for now, but maybe we return an error?
+				if lhh.l.Level >= logrus.DebugLevel {
+					lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
+				}
 			}
 		}
 
@@ -1183,19 +1183,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
 				if !r.Is4() {
 					continue
 				}
-
 				b = r.As4()
 				n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
 			}
-
 		} else if v == cert.Version2 {
 			for _, r := range c.relay.relay {
 				n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
 			}
-
 		} else {
-			//TODO: CERT-V2 don't panic
-			panic("unsupported version")
+			if lhh.l.Level >= logrus.DebugLevel {
+				lhh.l.WithField("version", v).Debug("unsupported protocol version")
+			}
 		}
 	}
 }
@@ -1205,18 +1203,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
 		return
 	}
 
-	lhh.lh.Lock()
-
-	var certVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		certVpnAddr = netip.AddrFrom4(b)
-	} else if n.Details.VpnAddr != nil {
-		certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+	certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
+		}
+		return
 	}
 	relays := n.Details.GetRelays()
 
+	lhh.lh.Lock()
 	am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
 	am.Lock()
 	lhh.lh.Unlock()
@@ -1241,24 +1237,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 		return
 	}
 
-	var detailsVpnAddr netip.Addr
-	useVersion := cert.Version1
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		detailsVpnAddr = netip.AddrFrom4(b)
-		useVersion = cert.Version1
-	} else if n.Details.VpnAddr != nil {
-		detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		useVersion = cert.Version2
-	} else {
+	detailsVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
+			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostUpdateNotification")
 		}
-		return
 	}
 
-	//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
 	//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
 	//Simple check that the host sent this not someone else
 	if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
@@ -1310,13 +1295,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
 	//It's possible the lighthouse is communicating with us using a non primary vpn addr,
 	//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
-	//maybe one day we'll have a better idea, if it matters.
 	if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
 		return
 	}
 
+	detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
+		}
+		return
+	}
+
 	empty := []byte{0}
-	punch := func(vpnPeer netip.AddrPort) {
+	punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
 		if !vpnPeer.IsValid() {
 			return
 		}
@@ -1328,39 +1320,22 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
 		}()
 
 		if lhh.l.Level >= logrus.DebugLevel {
-			var logVpnAddr netip.Addr
-			if n.Details.OldVpnAddr != 0 {
-				b := [4]byte{}
-				binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-				logVpnAddr = netip.AddrFrom4(b)
-			} else if n.Details.VpnAddr != nil {
-				logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-			}
 			lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
 		}
 	}
 
-	var queryVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		queryVpnAddr = netip.AddrFrom4(b)
-	} else if n.Details.VpnAddr != nil {
-		queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-	}
-
 	remoteAllowList := lhh.lh.GetRemoteAllowList()
 	for _, a := range n.Details.V4AddrPorts {
 		b := protoV4AddrPortToNetAddrPort(a)
-		if remoteAllowList.Allow(queryVpnAddr, b.Addr()) {
-			punch(b)
+		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
+			punch(b, detailsVpnAddr)
 		}
 	}
 
 	for _, a := range n.Details.V6AddrPorts {
 		b := protoV6AddrPortToNetAddrPort(a)
-		if remoteAllowList.Allow(queryVpnAddr, b.Addr()) {
-			punch(b)
+		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
+			punch(b, detailsVpnAddr)
 		}
 	}
 
@@ -1371,12 +1346,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
 		go func() {
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			if lhh.l.Level >= logrus.DebugLevel {
-				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
+				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
 			}
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// managed by a channel.
-			w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}()
 	}
 }
@@ -1455,3 +1430,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
 	}
 	return netip.Addr{}, false
 }
+
+func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
+	if d.OldVpnAddr != 0 {
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
+		detailsVpnAddr := netip.AddrFrom4(b)
+		return detailsVpnAddr, cert.Version1, nil
+	} else if d.VpnAddr != nil {
+		detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
+		return detailsVpnAddr, cert.Version2, nil
+	} else {
+		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
+	}
+}

+ 131 - 11
lighthouse_test.go

@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) {
 	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) {
 	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -417,7 +417,7 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
 }
 
 func (tw *testEncWriter) GetCertState() *CertState {
-	return &CertState{defaultVersion: tw.protocolVersion}
+	return &CertState{initiatingVersion: tw.protocolVersion}
 }
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
 	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
 	assert.False(t, ok)
 }
+
+func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
+	l := test.NewLogger()
+
+	myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
+
+	testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
+	testStaticHost := netip.MustParseAddr("10.128.0.42")
+	//myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+	c := config.NewC(l)
+	lh1 := "10.128.0.2"
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
+		"interval": "1s",
+	}
+
+	c.Settings["listen"] = map[string]any{"port": 4242}
+	c.Settings["static_host_map"] = map[string]any{
+		lh1:           []any{"1.1.1.1:4242"},
+		"10.128.0.42": []any{"1.2.3.4:4242"},
+	}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
+	lh.ifce = &mockEncWriter{}
+
+	//test that we actually have the static entry:
+	out := lh.Query(testStaticHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testStaticHost)
+	out.Rebuild([]netip.Prefix{}) //why tho
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//bolt on a lower numbered primary IP
+	am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
+	am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
+	lh.addrMap[testSameHostNotStatic] = am
+	out.Rebuild([]netip.Prefix{}) //???
+
+	//test that we actually have the static entry:
+	out = lh.Query(testStaticHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
+	assert.Equal(t, out.vpnAddrs[1], testStaticHost)
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//test that we actually have the static entry for BOTH:
+	out2 := lh.Query(testSameHostNotStatic)
+	assert.Same(t, out2, out)
+
+	//now do the delete
+	lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
+	//verify
+	out = lh.Query(testSameHostNotStatic)
+	assert.NotNil(t, out)
+	if out == nil {
+		t.Fatal("expected non-nil query for the static host")
+	}
+	assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
+	assert.Equal(t, out.vpnAddrs[1], testStaticHost)
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+}
+
+func TestLighthouse_DeletesWork(t *testing.T) {
+	l := test.NewLogger()
+
+	myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
+	testHost := netip.MustParseAddr("10.128.0.42")
+
+	c := config.NewC(l)
+	lh1 := "10.128.0.2"
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
+		"interval": "1s",
+	}
+
+	c.Settings["listen"] = map[string]any{"port": 4242}
+	c.Settings["static_host_map"] = map[string]any{
+		lh1: []any{"1.1.1.1:4242"},
+	}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
+	lh.ifce = &mockEncWriter{}
+
+	//insert the host
+	am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
+	am.vpnAddrs = []netip.Addr{testHost}
+	am.addrs = []netip.AddrPort{myUdpAddr2}
+	lh.addrMap[testHost] = am
+	am.Rebuild([]netip.Prefix{}) //???
+
+	//test that we actually have the entry:
+	out := lh.Query(testHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testHost)
+	out.Rebuild([]netip.Prefix{}) //why tho
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//now do the delete
+	lh.DeleteVpnAddrs([]netip.Addr{testHost})
+	//verify
+	out = lh.Query(testHost)
+	assert.Nil(t, out)
+}

+ 21 - 24
main.go

@@ -185,6 +185,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
+	connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -220,31 +221,26 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
-
 	ifConfig := &InterfaceConfig{
-		HostMap:                 hostMap,
-		Inside:                  tun,
-		Outside:                 udpConns[0],
-		pki:                     pki,
-		Firewall:                fw,
-		ServeDns:                serveDns,
-		HandshakeManager:        handshakeManager,
-		lightHouse:              lightHouse,
-		checkInterval:           time.Second * time.Duration(checkInterval),
-		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
-		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
-		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
-		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
-		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
-		DropMulticast:           c.GetBool("tun.drop_multicast", false),
-		routines:                routines,
-		MessageMetrics:          messageMetrics,
-		version:                 buildVersion,
-		relayManager:            NewRelayManager(ctx, l, hostMap, c),
-		punchy:                  punchy,
-
+		HostMap:               hostMap,
+		Inside:                tun,
+		Outside:               udpConns[0],
+		pki:                   pki,
+		Firewall:              fw,
+		ServeDns:              serveDns,
+		HandshakeManager:      handshakeManager,
+		connectionManager:     connManager,
+		lightHouse:            lightHouse,
+		tryPromoteEvery:       c.GetUint32("counters.try_promote", defaultPromoteEvery),
+		reQueryEvery:          c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
+		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
+		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
+		DropMulticast:         c.GetBool("tun.drop_multicast", false),
+		routines:              routines,
+		MessageMetrics:        messageMetrics,
+		version:               buildVersion,
+		relayManager:          NewRelayManager(ctx, l, hostMap, c),
+		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
 	}
@@ -296,5 +292,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		statsStart,
 		dnsStart,
 		lightHouse.StartUpdateWorker,
+		connManager.Start,
 	}, nil
 }

+ 16 - 20
outside.go

@@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
-		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
-		if found {
+		if f.myVpnNetworksTable.Contains(ip.Addr()) {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
@@ -82,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
 
 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -214,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	f.handleHostRoaming(hostinfo, ip)
 
-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 }
 
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -255,16 +254,18 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
 
 }
 
+// handleEncrypted returns true if a packet should be processed, false otherwise
 func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
-	// If connectionstate exists and the replay protector allows, process packet
-	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
+	// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
+	if ci == nil {
 		if addr.IsValid() {
 			f.maybeSendRecvError(addr, h.RemoteIndex)
-			return false
-		} else {
-			return false
 		}
+		return false
+	}
+	// If the window check fails, refuse to process the packet, but don't send a recv error
+	if !ci.window.Check(f.l, h.MessageCounter) {
+		return false
 	}
 
 	return true
@@ -313,12 +314,11 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 	offset := ipv6.HeaderLen // Start at the end of the ipv6 header
 	next := 0
 	for {
-		if dataLen < offset {
+		if protoAt >= dataLen {
 			break
 		}
-
 		proto := layers.IPProtocol(data[protoAt])
-		//fmt.Println(proto, protoAt)
+
 		switch proto {
 		case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
 			fp.Protocol = uint8(proto)
@@ -366,7 +366,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 
 		case layers.IPProtocolAH:
 			// Auth headers, used by IPSec, have a different meaning for header length
-			if dataLen < offset+1 {
+			if dataLen <= offset+1 {
 				break
 			}
 
@@ -374,7 +374,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 
 		default:
 			// Normal ipv6 header length processing
-			if dataLen < offset+1 {
+			if dataLen <= offset+1 {
 				break
 			}
 
@@ -500,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
 
-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
@@ -539,10 +539,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 		return
 	}
 
-	if !hostinfo.RecvErrorExceeded() {
-		return
-	}
-
 	if hostinfo.remote.IsValid() && hostinfo.remote != addr {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return

+ 43 - 0
outside_test.go

@@ -117,6 +117,45 @@ func Test_newPacket_v6(t *testing.T) {
 	err = newPacket(buffer.Bytes(), true, p)
 	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
 
+	// A v6 packet with a hop-by-hop extension
+	// ICMPv6 Payload (Echo Request)
+	icmpLayer := layers.ICMPv6{
+		TypeCode: layers.ICMPv6TypeEchoRequest,
+	}
+	// Hop-by-Hop Extension Header
+	hopOption := layers.IPv6HopByHopOption{}
+	hopOption.OptionData = []byte{0, 0, 0, 0}
+	hopByHop := layers.IPv6HopByHop{}
+	hopByHop.Options = append(hopByHop.Options, &hopOption)
+
+	ip = layers.IPv6{
+		Version:    6,
+		HopLimit:   128,
+		NextHeader: layers.IPProtocolIPv6Destination,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       true,
+	}, &ip, &hopByHop, &icmpLayer)
+	if err != nil {
+		panic(err)
+	}
+	// Ensure buffer length checks during parsing with the next 2 tests.
+
+	// A full IPv6 header and 1 byte in the first extension, but missing
+	// the length byte.
+	err = newPacket(buffer.Bytes()[:41], true, p)
+	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A full IPv6 header plus 1 full extension, but only 1 byte of the
+	// next layer, missing length byte
+	err = newPacket(buffer.Bytes()[:49], true, p)
+	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
 	// A good ICMP packet
 	ip = layers.IPv6{
 		Version:    6,
@@ -288,6 +327,10 @@ func Test_newPacket_v6(t *testing.T) {
 	assert.Equal(t, uint16(22), p.LocalPort)
 	assert.False(t, p.Fragment)
 
+	// Ensure buffer bounds checking during processing
+	err = newPacket(b[:41], true, p)
+	require.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
 	// Invalid AH header
 	b = buffer.Bytes()
 	err = newPacket(b, true, p)

+ 37 - 14
overlay/tun_linux.go

@@ -34,10 +34,11 @@ type tun struct {
 	deviceIndex int
 	ioctlFd     uintptr
 
-	Routes          atomic.Pointer[[]Route]
-	routeTree       atomic.Pointer[bart.Table[routing.Gateways]]
-	routeChan       chan struct{}
-	useSystemRoutes bool
+	Routes                    atomic.Pointer[[]Route]
+	routeTree                 atomic.Pointer[bart.Table[routing.Gateways]]
+	routeChan                 chan struct{}
+	useSystemRoutes           bool
+	useSystemRoutesBufferSize int
 
 	l *logrus.Logger
 }
@@ -124,12 +125,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 
 func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
-		ReadWriteCloser: file,
-		fd:              int(file.Fd()),
-		vpnNetworks:     vpnNetworks,
-		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
-		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
-		l:               l,
+		ReadWriteCloser:           file,
+		fd:                        int(file.Fd()),
+		vpnNetworks:               vpnNetworks,
+		TXQueueLen:                c.GetInt("tun.tx_queue", 500),
+		useSystemRoutes:           c.GetBool("tun.use_system_route_table", false),
+		useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
+		l:                         l,
 	}
 
 	err := t.reload(c, true)
@@ -291,7 +293,6 @@ func (t *tun) addIPs(link netlink.Link) error {
 
 	//add all new addresses
 	for i := range newAddrs {
-		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
 		//AddrReplace still adds new IPs, but if their properties change it will change them as well
 		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
 			return err
@@ -359,6 +360,11 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 
+	const modeNone = 1
+	if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
+		t.l.WithError(err).Warn("Failed to disable link local address generation")
+	}
+
 	if err = t.addIPs(link); err != nil {
 		return err
 	}
@@ -531,7 +537,13 @@ func (t *tun) watchRoutes() {
 	rch := make(chan netlink.RouteUpdate)
 	doneChan := make(chan struct{})
 
-	if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
+	netlinkOptions := netlink.RouteSubscribeOptions{
+		ReceiveBufferSize:      t.useSystemRoutesBufferSize,
+		ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
+		ErrorCallback:          func(e error) { t.l.WithError(e).Errorf("netlink error") },
+	}
+
+	if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
 		t.l.WithError(err).Errorf("failed to subscribe to system route changes")
 		return
 	}
@@ -541,8 +553,14 @@ func (t *tun) watchRoutes() {
 	go func() {
 		for {
 			select {
-			case r := <-rch:
-				t.updateRoutes(r)
+			case r, ok := <-rch:
+				if ok {
+					t.updateRoutes(r)
+				} else {
+					// may be should do something here as
+					// netlink stops sending updates
+					return
+				}
 			case <-doneChan:
 				// netlink.RouteSubscriber will close the rch for us
 				return
@@ -624,6 +642,11 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
+	if r.Dst == nil {
+		t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
+		return
+	}
+
 	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")

+ 32 - 31
pki.go

@@ -33,16 +33,16 @@ type CertState struct {
 	v2Cert           cert.Certificate
 	v2HandshakeBytes []byte
 
-	defaultVersion cert.Version
-	privateKey     []byte
-	pkcs11Backed   bool
-	cipher         string
+	initiatingVersion cert.Version
+	privateKey        []byte
+	pkcs11Backed      bool
+	cipher            string
 
 	myVpnNetworks            []netip.Prefix
-	myVpnNetworksTable       *bart.Table[struct{}]
+	myVpnNetworksTable       *bart.Lite
 	myVpnAddrs               []netip.Addr
-	myVpnAddrsTable          *bart.Table[struct{}]
-	myVpnBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrsTable          *bart.Lite
+	myVpnBroadcastAddrsTable *bart.Lite
 }
 
 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -173,7 +173,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 
 	p.cs.Store(newState)
 
-	//TODO: CERT-V2 newState needs a stringer that does json
 	if initial {
 		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
@@ -194,7 +193,7 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
 }
 
 func (cs *CertState) GetDefaultCertificate() cert.Certificate {
-	c := cs.getCertificate(cs.defaultVersion)
+	c := cs.getCertificate(cs.initiatingVersion)
 	if c == nil {
 		panic("No default certificate found")
 	}
@@ -317,37 +316,37 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
 		return nil, errors.New("no certificates found in pki.cert")
 	}
 
-	useDefaultVersion := uint32(1)
+	useInitiatingVersion := uint32(1)
 	if v1 == nil {
 		// The only condition that requires v2 as the default is if only a v2 certificate is present
 		// We do this to avoid having to configure it specifically in the config file
-		useDefaultVersion = 2
+		useInitiatingVersion = 2
 	}
 
-	rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
-	var defaultVersion cert.Version
-	switch rawDefaultVersion {
+	rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion)
+	var initiatingVersion cert.Version
+	switch rawInitiatingVersion {
 	case 1:
 		if v1 == nil {
-			return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
+			return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert")
 		}
-		defaultVersion = cert.Version1
+		initiatingVersion = cert.Version1
 	case 2:
-		defaultVersion = cert.Version2
+		initiatingVersion = cert.Version2
 	default:
-		return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
+		return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
 	}
 
-	return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
+	return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
 }
 
 func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
 	cs := CertState{
 		privateKey:               privateKey,
 		pkcs11Backed:             pkcs11backed,
-		myVpnNetworksTable:       new(bart.Table[struct{}]),
-		myVpnAddrsTable:          new(bart.Table[struct{}]),
-		myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
+		myVpnNetworksTable:       new(bart.Lite),
+		myVpnAddrsTable:          new(bart.Lite),
+		myVpnBroadcastAddrsTable: new(bart.Lite),
 	}
 
 	if v1 != nil && v2 != nil {
@@ -359,9 +358,11 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 			return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
 		}
 
-		//TODO: CERT-V2 make sure v2 has v1s address
+		if v1.Networks()[0] != v2.Networks()[0] {
+			return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
+		}
 
-		cs.defaultVersion = dv
+		cs.initiatingVersion = dv
 	}
 
 	if v1 != nil {
@@ -380,8 +381,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 		cs.v1Cert = v1
 		cs.v1HandshakeBytes = v1hs
 
-		if cs.defaultVersion == 0 {
-			cs.defaultVersion = cert.Version1
+		if cs.initiatingVersion == 0 {
+			cs.initiatingVersion = cert.Version1
 		}
 	}
 
@@ -401,8 +402,8 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 		cs.v2Cert = v2
 		cs.v2HandshakeBytes = v2hs
 
-		if cs.defaultVersion == 0 {
-			cs.defaultVersion = cert.Version2
+		if cs.initiatingVersion == 0 {
+			cs.initiatingVersion = cert.Version2
 		}
 	}
 
@@ -415,16 +416,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 
 	for _, network := range crt.Networks() {
 		cs.myVpnNetworks = append(cs.myVpnNetworks, network)
-		cs.myVpnNetworksTable.Insert(network, struct{}{})
+		cs.myVpnNetworksTable.Insert(network)
 
 		cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
-		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
+		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
 
 		if network.Addr().Is4() {
 			addr := network.Masked().Addr().As4()
 			mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
 			binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
-			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
+			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
 		}
 	}
 

+ 2 - 4
relay_manager.go

@@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	_, found := f.myVpnAddrsTable.Lookup(from)
-	if found {
+	if f.myVpnAddrsTable.Contains(from) {
 		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 	}
 
 	// Is the target of the relay me?
-	_, found = f.myVpnAddrsTable.Lookup(target)
-	if found {
+	if f.myVpnAddrsTable.Contains(target) {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 			switch existingRelay.State {

+ 16 - 5
remote_list.go

@@ -190,7 +190,7 @@ type RemoteList struct {
 	// The full list of vpn addresses assigned to this host
 	vpnAddrs []netip.Addr
 
-	// A deduplicated set of addresses. Any accessor should lock beforehand.
+	// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
 	addrs []netip.AddrPort
 
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,8 +201,10 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[netip.Addr]*cache
 
-	hr        *hostnamesResults
-	shouldAdd func(netip.Addr) bool
+	hr *hostnamesResults
+
+	// shouldAdd is a nillable function that decides if x should be added to addrs.
+	shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
 	r := &RemoteList{
 		vpnAddrs:  make([]netip.Addr, len(vpnAddrs)),
 		addrs:     make([]netip.AddrPort, 0),
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
 	return c
 }
 
+// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
+func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
+	r.Lock()
+	r.badRemotes = nil
+	r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
+	copy(r.vpnAddrs, vpnAddrs)
+	r.Unlock()
+}
+
 // ResetBlockedRemotes locks and clears the blocked remotes list
 func (r *RemoteList) ResetBlockedRemotes() {
 	r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
 
 	dnsAddrs := r.hr.GetAddrs()
 	for _, addr := range dnsAddrs {
-		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+		if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
 			if !r.unlockedIsBad(addr) {
 				addrs = append(addrs, addr)
 			}

+ 1 - 11
service/service.go

@@ -9,13 +9,10 @@ import (
 	"math"
 	"net"
 	"net/netip"
-	"os"
 	"strings"
 	"sync"
 
-	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
-	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/sync/errgroup"
 	"gvisor.dev/gvisor/pkg/buffer"
@@ -46,14 +43,7 @@ type Service struct {
 	}
 }
 
-func New(config *config.C) (*Service, error) {
-	logger := logrus.New()
-	logger.Out = os.Stdout
-
-	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
-	if err != nil {
-		return nil, err
-	}
+func New(control *nebula.Control) (*Service, error) {
 	control.Start()
 
 	ctx := control.Context()

+ 13 - 1
service/service_test.go

@@ -5,13 +5,17 @@ import (
 	"context"
 	"errors"
 	"net/netip"
+	"os"
 	"testing"
 	"time"
 
 	"dario.cat/mergo"
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/sync/errgroup"
 	"gopkg.in/yaml.v3"
 )
@@ -71,7 +75,15 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
 		panic(err)
 	}
 
-	s, err := New(&c)
+	logger := logrus.New()
+	logger.Out = os.Stdout
+
+	control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
+	if err != nil {
+		panic(err)
+	}
+
+	s, err := New(control)
 	if err != nil {
 		panic(err)
 	}

+ 5 - 0
udp/errors.go

@@ -0,0 +1,5 @@
+package udp
+
+import "errors"
+
+var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote")

+ 153 - 11
udp/udp_darwin.go

@@ -3,20 +3,62 @@
 
 package udp
 
-// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
-
 import (
+	"context"
+	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"net/netip"
 	"syscall"
+	"unsafe"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 	"golang.org/x/sys/unix"
 )
 
+type StdConn struct {
+	*net.UDPConn
+	isV4  bool
+	sysFd uintptr
+	l     *logrus.Logger
+}
+
+var _ Conn = &StdConn{}
+
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
-	return NewGenericListener(l, ip, port, multi, batch)
+	lc := NewListenConfig(multi)
+	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
+	if err != nil {
+		return nil, err
+	}
+
+	if uc, ok := pc.(*net.UDPConn); ok {
+		c := &StdConn{UDPConn: uc, l: l}
+
+		rc, err := uc.SyscallConn()
+		if err != nil {
+			return nil, fmt.Errorf("failed to open udp socket: %w", err)
+		}
+
+		err = rc.Control(func(fd uintptr) {
+			c.sysFd = fd
+		})
+		if err != nil {
+			return nil, fmt.Errorf("failed to get udp fd: %w", err)
+		}
+
+		la, err := c.LocalAddr()
+		if err != nil {
+			return nil, err
+		}
+		c.isV4 = la.Addr().Is4()
+
+		return c, nil
+	}
+
+	return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
 }
 
 func NewListenConfig(multi bool) net.ListenConfig {
@@ -43,16 +85,116 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *GenericConn) Rebind() error {
-	rc, err := u.UDPConn.SyscallConn()
-	if err != nil {
-		return err
+//go:linkname sendto golang.org/x/sys/unix.sendto
+//go:noescape
+func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
+
+func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
+	var sa unsafe.Pointer
+	var addrLen int32
+
+	if u.isV4 {
+		if ap.Addr().Is6() {
+			return ErrInvalidIPv6RemoteForSocket
+		}
+
+		var rsa unix.RawSockaddrInet6
+		rsa.Family = unix.AF_INET6
+		rsa.Addr = ap.Addr().As16()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+		sa = unsafe.Pointer(&rsa)
+		addrLen = syscall.SizeofSockaddrInet4
+	} else {
+		var rsa unix.RawSockaddrInet6
+		rsa.Family = unix.AF_INET6
+		rsa.Addr = ap.Addr().As16()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
+		sa = unsafe.Pointer(&rsa)
+		addrLen = syscall.SizeofSockaddrInet6
+	}
+
+	// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
+	// See https://github.com/golang/go/issues/73919
+	for {
+		//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
+		err := sendto(int(u.sysFd), b, 0, sa, addrLen)
+		if err == nil {
+			// Written, get out before the error handling
+			return nil
+		}
+
+		if errors.Is(err, syscall.EINTR) {
+			// Write was interrupted, retry
+			continue
+		}
+
+		if errors.Is(err, syscall.EAGAIN) {
+			return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
+		}
+
+		if errors.Is(err, syscall.EBADF) {
+			return net.ErrClosed
+		}
+
+		return &net.OpError{Op: "sendto", Err: err}
 	}
+}
+
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
+	a := u.UDPConn.LocalAddr()
+
+	switch v := a.(type) {
+	case *net.UDPAddr:
+		addr, ok := netip.AddrFromSlice(v.IP)
+		if !ok {
+			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+		}
+		return netip.AddrPortFrom(addr, uint16(v.Port)), nil
+
+	default:
+		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
+	}
+}
+
+func (u *StdConn) ReloadConfig(c *config.C) {
+	// TODO
+}
 
-	return rc.Control(func(fd uintptr) {
-		err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+func NewUDPStatsEmitter(udpConns []Conn) func() {
+	// No UDP stats for non-linux
+	return func() {}
+}
+
+func (u *StdConn) ListenOut(r EncReader) {
+	buffer := make([]byte, MTU)
+
+	for {
+		// Just read one packet at a time
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to rebind udp socket")
+			if errors.Is(err, net.ErrClosed) {
+				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+				return
+			}
+
+			u.l.WithError(err).Error("unexpected udp socket receive error")
 		}
-	})
+
+		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
+	}
+}
+
+func (u *StdConn) Rebind() error {
+	var err error
+	if u.isV4 {
+		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
+	} else {
+		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
+	}
+
+	if err != nil {
+		u.l.WithError(err).Error("Failed to rebind udp socket")
+	}
+
+	return nil
 }

+ 2 - 1
udp/udp_generic.go

@@ -1,6 +1,7 @@
-//go:build (!linux || android) && !e2e_testing
+//go:build (!linux || android) && !e2e_testing && !darwin
 // +build !linux android
 // +build !e2e_testing
+// +build !darwin
 
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.

+ 1 - 1
udp/udp_linux.go

@@ -221,7 +221,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 
 func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
 	if !ip.Addr().Is4() {
-		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		return ErrInvalidIPv6RemoteForSocket
 	}
 
 	var rsa unix.RawSockaddrInet4

+ 25 - 2
udp/udp_rio_windows.go

@@ -92,6 +92,25 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
 	// Enable v4 for this socket
 	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
 
+	// Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call.
+	// These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable
+	// the UDP receive error returns with these ioctl calls.
+	ret := uint32(0)
+	flag := uint32(0)
+	size := uint32(unsafe.Sizeof(flag))
+	err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
+	if err != nil {
+		return err
+	}
+	ret = 0
+	flag = 0
+	size = uint32(unsafe.Sizeof(flag))
+	SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15)
+	err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0)
+	if err != nil {
+		return err
+	}
+
 	err = u.rx.Open()
 	if err != nil {
 		return err
@@ -122,8 +141,12 @@ func (u *RIOConn) ListenOut(r EncReader) {
 		// Just read one packet at a time
 		n, rua, err := u.receive(buffer)
 		if err != nil {
-			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
-			return
+			if errors.Is(err, net.ErrClosed) {
+				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+				return
+			}
+			u.l.WithError(err).Error("unexpected udp socket receive error")
+			continue
 		}
 
 		r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])