Pārlūkot izejas kodu

Merge branch 'master' into multiport

Wade Simmons 2 gadi atpakaļ
vecāks
revīzija
0e593ad582
68 mainītis faili ar 2560 papildinājumiem un 668 dzēšanām
  1. 7 8
      .github/workflows/gofmt.yml
  2. 6 6
      .github/workflows/release.yml
  3. 12 4
      .github/workflows/smoke.yml
  4. 1 1
      .github/workflows/smoke/build.sh
  5. 38 8
      .github/workflows/test.yml
  6. 15 1
      Makefile
  7. 11 0
      README.md
  8. 12 0
      SECURITY.md
  9. 8 0
      boring.go
  10. 2 2
      cert.go
  11. 246 36
      cert/cert.go
  12. 111 48
      cert/cert.pb.go
  13. 7 0
      cert/cert.proto
  14. 292 43
      cert/cert_test.go
  15. 9 4
      cert/errors.go
  16. 25 3
      cidr/tree4.go
  17. 14 0
      cidr/tree4_test.go
  18. 28 6
      cmd/nebula-cert/ca.go
  19. 5 1
      cmd/nebula-cert/ca_test.go
  20. 17 3
      cmd/nebula-cert/keygen.go
  21. 2 0
      cmd/nebula-cert/keygen_test.go
  22. 2 2
      cmd/nebula-cert/print_test.go
  23. 35 9
      cmd/nebula-cert/sign.go
  24. 1 1
      cmd/nebula-cert/sign_test.go
  25. 3 3
      cmd/nebula-cert/verify_test.go
  26. 244 33
      connection_manager.go
  27. 6 6
      connection_manager_test.go
  28. 14 3
      connection_state.go
  29. 1 1
      control_test.go
  30. 14 0
      control_tester.go
  31. 385 22
      e2e/handshakes_test.go
  32. 6 6
      e2e/helpers_test.go
  33. 5 5
      e2e/router/router.go
  34. 6 1
      examples/config.yml
  35. 65 36
      firewall.go
  36. 102 45
      firewall_test.go
  37. 14 13
      go.mod
  38. 27 27
      go.sum
  39. 5 4
      handshake_manager.go
  40. 5 1
      handshake_manager_test.go
  41. 34 41
      hostmap.go
  42. 12 24
      inside.go
  43. 2 0
      interface.go
  44. 162 33
      lighthouse.go
  45. 25 7
      lighthouse_test.go
  46. 1 1
      main.go
  47. 1 0
      message_metrics.go
  48. 83 80
      nebula.pb.go
  49. 1 0
      nebula.proto
  50. 7 0
      noiseutil/boring_test.go
  51. 68 0
      noiseutil/nist.go
  52. 7 8
      noiseutil/notboring_test.go
  53. 6 0
      notboring.go
  54. 12 13
      outside.go
  55. 2 0
      overlay/tun.go
  56. 2 2
      overlay/tun_android.go
  57. 2 2
      overlay/tun_darwin.go
  58. 2 2
      overlay/tun_freebsd.go
  59. 2 2
      overlay/tun_ios.go
  60. 109 15
      overlay/tun_linux.go
  61. 7 7
      overlay/tun_linux_test.go
  62. 2 2
      overlay/tun_tester.go
  63. 2 2
      overlay/tun_windows.go
  64. 1 1
      punchy.go
  65. 27 25
      relay_manager.go
  66. 166 4
      remote_list.go
  67. 3 3
      remote_list_test.go
  68. 4 2
      stats.go

+ 7 - 8
.github/workflows/gofmt.yml

@@ -14,10 +14,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.19
+    - name: Set up Go 1.20
       uses: actions/setup-go@v2
       with:
-        go-version: 1.19
+        go-version: "1.20"
       id: go
 
     - name: Check out code into the Go module directory
@@ -26,19 +26,18 @@ jobs:
     - uses: actions/cache@v2
       with:
         path: ~/go/pkg/mod
-        key: ${{ runner.os }}-gofmt1.19-${{ hashFiles('**/go.sum') }}
+        key: ${{ runner.os }}-gofmt1.20-${{ hashFiles('**/go.sum') }}
         restore-keys: |
-          ${{ runner.os }}-gofmt1.19-
+          ${{ runner.os }}-gofmt1.20-
 
     - name: Install goimports
       run: |
-        go get golang.org/x/tools/cmd/goimports
-        go build golang.org/x/tools/cmd/goimports
+        go install golang.org/x/tools/cmd/goimports@latest
 
     - name: gofmt
       run: |
-        if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -l)" ]
+        if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ]
         then
-          find . -iname '*.go' | grep -v '\.pb\.go$' | xargs ./goimports -d
+          find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d
           exit 1
         fi

+ 6 - 6
.github/workflows/release.yml

@@ -10,10 +10,10 @@ jobs:
     name: Build Linux All
     runs-on: ubuntu-latest
     steps:
-      - name: Set up Go 1.19
+      - name: Set up Go 1.20
         uses: actions/setup-go@v2
         with:
-          go-version: 1.19
+          go-version: "1.20"
 
       - name: Checkout code
         uses: actions/checkout@v2
@@ -34,10 +34,10 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - name: Set up Go 1.19
+      - name: Set up Go 1.20
         uses: actions/setup-go@v2
         with:
-          go-version: 1.19
+          go-version: "1.20"
 
       - name: Checkout code
         uses: actions/checkout@v2
@@ -68,10 +68,10 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-11
     steps:
-      - name: Set up Go 1.19
+      - name: Set up Go 1.20
         uses: actions/setup-go@v2
         with:
-          go-version: 1.19
+          go-version: "1.20"
 
       - name: Checkout code
         uses: actions/checkout@v2

+ 12 - 4
.github/workflows/smoke.yml

@@ -18,10 +18,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.19
+    - name: Set up Go 1.20
       uses: actions/setup-go@v2
       with:
-        go-version: 1.19
+        go-version: "1.20"
       id: go
 
     - name: Check out code into the Go module directory
@@ -30,9 +30,9 @@ jobs:
     - uses: actions/cache@v2
       with:
         path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
+        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
         restore-keys: |
-          ${{ runner.os }}-go1.19-
+          ${{ runner.os }}-go1.20-
 
     - name: build
       run: make bin-docker
@@ -53,6 +53,14 @@ jobs:
       working-directory: ./.github/workflows/smoke
       run: ./smoke-relay.sh
 
+    - name: setup docker image for P256
+      working-directory: ./.github/workflows/smoke
+      run: NAME="smoke-p256" CURVE=P256 ./build.sh
+
+    - name: run smoke-p256
+      working-directory: ./.github/workflows/smoke
+      run: NAME="smoke-p256" ./smoke.sh
+
     - name: setup docker image for multiport
       working-directory: ./.github/workflows/smoke
       run: NAME="smoke-multiport" MULTIPORT_TX=true MULTIPORT_RX=true MULTIPORT_HANDSHAKE=true ./build.sh

+ 1 - 1
.github/workflows/smoke/build.sh

@@ -29,7 +29,7 @@ mkdir ./build
         OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
         ../genconfig.sh >host4.yml
 
-    ../../../../nebula-cert ca -name "Smoke Test"
+    ../../../../nebula-cert ca -curve "${CURVE:-25519}" -name "Smoke Test"
     ../../../../nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24"
     ../../../../nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24"
     ../../../../nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24"

+ 38 - 8
.github/workflows/test.yml

@@ -18,10 +18,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.19
+    - name: Set up Go 1.20
       uses: actions/setup-go@v2
       with:
-        go-version: 1.19
+        go-version: "1.20"
       id: go
 
     - name: Check out code into the Go module directory
@@ -30,9 +30,9 @@ jobs:
     - uses: actions/cache@v2
       with:
         path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
+        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
         restore-keys: |
-          ${{ runner.os }}-go1.19-
+          ${{ runner.os }}-go1.20-
 
     - name: Build
       run: make all
@@ -52,6 +52,36 @@ jobs:
         path: e2e/mermaid/
         if-no-files-found: warn
 
+  test-linux-boringcrypto:
+    name: Build and test on linux with boringcrypto
+    runs-on: ubuntu-latest
+    steps:
+
+    - name: Set up Go 1.20
+      uses: actions/setup-go@v2
+      with:
+        go-version: "1.20"
+      id: go
+
+    - name: Check out code into the Go module directory
+      uses: actions/checkout@v2
+
+    - uses: actions/cache@v2
+      with:
+        path: ~/go/pkg/mod
+        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
+        restore-keys: |
+          ${{ runner.os }}-go1.20-
+
+    - name: Build
+      run: make bin-boringcrypto
+
+    - name: Test
+      run: make test-boringcrypto
+
+    - name: End 2 end
+      run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+
   test:
     name: Build and test on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
@@ -60,10 +90,10 @@ jobs:
         os: [windows-latest, macos-11]
     steps:
 
-    - name: Set up Go 1.19
+    - name: Set up Go 1.20
       uses: actions/setup-go@v2
       with:
-        go-version: 1.19
+        go-version: "1.20"
       id: go
 
     - name: Check out code into the Go module directory
@@ -72,9 +102,9 @@ jobs:
     - uses: actions/cache@v2
       with:
         path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.19-${{ hashFiles('**/go.sum') }}
+        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
         restore-keys: |
-          ${{ runner.os }}-go1.19-
+          ${{ runner.os }}-go1.20-
 
     - name: Build nebula
       run: go build ./cmd/nebula

+ 15 - 1
Makefile

@@ -1,4 +1,4 @@
-GOMINVERSION = 1.19
+GOMINVERSION = 1.20
 NEBULA_CMD_PATH = "./cmd/nebula"
 GO111MODULE = on
 export GO111MODULE
@@ -77,6 +77,8 @@ release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
 
 release-freebsd: build/nebula-freebsd-amd64.tar.gz
 
+release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
+
 BUILD_ARGS = -trimpath
 
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
@@ -91,6 +93,9 @@ bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
 bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
 	mv $? .
 
+bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
+	mv $? .
+
 bin:
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
@@ -105,6 +110,10 @@ build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))
 # Build an extra small binary for mips-softfloat
 build/linux-mips-softfloat/%: LDFLAGS += -s -w
 
+# boringcrypto
+build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+
 build/%/nebula: .FORCE
 	GOOS=$(firstword $(subst -, , $*)) \
 		GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
@@ -133,6 +142,9 @@ vet:
 test:
 	go test -v ./...
 
+test-boringcrypto:
+	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
+
 test-cov-html:
 	go test -coverprofile=coverage.out
 	go tool cover -html=coverage.out
@@ -170,6 +182,8 @@ bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert
 smoke-docker: bin-docker
 	cd .github/workflows/smoke/ && ./build.sh
 	cd .github/workflows/smoke/ && ./smoke.sh
+	cd .github/workflows/smoke/ && NAME="smoke-p256" CURVE="P256" ./build.sh
+	cd .github/workflows/smoke/ && NAME="smoke-p256" ./smoke.sh
 
 smoke-relay-docker: bin-docker
 	cd .github/workflows/smoke/ && ./build-relay.sh

+ 11 - 0
README.md

@@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows):
 
 See the [Makefile](Makefile) for more details on build targets
 
+## Curve P256 and BoringCrypto
+
+The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes.
+
+In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
+
+    make bin-boringcrypto
+    make release-boringcrypto
+
+This is not the recommended default deployment, but may be useful based on your compliance requirements.
+
 ## Credits
 
 Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

+ 12 - 0
SECURITY.md

@@ -0,0 +1,12 @@
+Security Policy
+===============
+
+Reporting a Vulnerability
+-------------------------
+
+If you believe you have found a security vulnerability with Nebula, please let
+us know right away. We will investigate all reports and do our best to quickly
+fix valid issues.
+
+You can submit your report on [HackerOne](https://hackerone.com/slack) and our
+security team will respond as soon as possible.

+ 8 - 0
boring.go

@@ -0,0 +1,8 @@
+//go:build boringcrypto
+// +build boringcrypto
+
+package nebula
+
+import "crypto/boring"
+
+var boringEnabled = boring.Enabled

+ 2 - 2
cert.go

@@ -66,7 +66,7 @@ func NewCertStateFromConfig(c *config.C) (*CertState, error) {
 		}
 	}
 
-	rawKey, _, err := cert.UnmarshalX25519PrivateKey(pemPrivateKey)
+	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
 	if err != nil {
 		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
 	}
@@ -102,7 +102,7 @@ func NewCertStateFromConfig(c *config.C) (*CertState, error) {
 		return nil, fmt.Errorf("no IPs encoded in certificate")
 	}
 
-	if err = nebulaCert.VerifyPrivateKey(rawKey); err != nil {
+	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
 		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
 	}
 

+ 246 - 36
cert/cert.go

@@ -2,7 +2,10 @@ package cert
 
 import (
 	"bytes"
-	"crypto"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/sha256"
 	"encoding/binary"
@@ -12,11 +15,11 @@ import (
 	"errors"
 	"fmt"
 	"math"
+	"math/big"
 	"net"
 	"time"
 
 	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
 	"google.golang.org/protobuf/proto"
 )
 
@@ -29,6 +32,11 @@ const (
 	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
 	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
 	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
+
+	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
+	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
+	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
 )
 
 type NebulaCertificate struct {
@@ -49,6 +57,8 @@ type NebulaCertificateDetails struct {
 
 	// Map of groups for faster lookup
 	InvertedGroups map[string]struct{}
+
+	Curve Curve
 }
 
 type NebulaEncryptedData struct {
@@ -100,6 +110,7 @@ func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
 			PublicKey:      make([]byte, len(rc.Details.PublicKey)),
 			IsCA:           rc.Details.IsCA,
 			InvertedGroups: make(map[string]struct{}),
+			Curve:          rc.Details.Curve,
 		},
 		Signature: make([]byte, len(rc.Signature)),
 	}
@@ -150,6 +161,28 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er
 	return nc, r, err
 }
 
+func MarshalPrivateKey(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func MarshalSigningPrivateKey(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
 // MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
 func MarshalX25519PrivateKey(b []byte) []byte {
 	return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
@@ -160,8 +193,58 @@ func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
 	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
 }
 
-// EncryptAndMarshalX25519PrivateKey is a simple helper to encrypt and PEM encode an X25519 private key
-func EncryptAndMarshalEd25519PrivateKey(b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
+func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var curve Curve
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
+	case EncryptedECDSAP256PrivateKeyBanner:
+		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
+	case Ed25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+		if len(k.Bytes) != ed25519.PrivateKeySize {
+			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case ECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+		if len(k.Bytes) != 32 {
+			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
+	}
+	return k.Bytes, r, curve, nil
+}
+
+// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
+func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
 	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
 	if err != nil {
 		return nil, err
@@ -181,7 +264,14 @@ func EncryptAndMarshalEd25519PrivateKey(b []byte, passphrase []byte, kdfParams *
 		Ciphertext: ciphertext,
 	})
 
-	return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
+	default:
+		return nil, fmt.Errorf("invalid curve: %v", curve)
+	}
 }
 
 // UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
@@ -282,21 +372,28 @@ func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parame
 
 }
 
-// DecryptAndUnmarshalEd25519PrivateKey will try to pem decode and decrypt an Ed25519 private key with
+// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
 // the given passphrase, returning any other bytes b or an error on failure
-func DecryptAndUnmarshalEd25519PrivateKey(passphrase, b []byte) (ed25519.PrivateKey, []byte, error) {
+func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
+	var curve Curve
+
 	k, r := pem.Decode(b)
 	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
+		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
 	}
 
-	if k.Type != EncryptedEd25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519 private key banner")
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+	case EncryptedECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+	default:
+		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
 	}
 
 	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
 	if err != nil {
-		return nil, r, err
+		return curve, nil, r, err
 	}
 
 	var bytes []byte
@@ -304,17 +401,35 @@ func DecryptAndUnmarshalEd25519PrivateKey(passphrase, b []byte) (ed25519.Private
 	case "AES-256-GCM":
 		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
 		if err != nil {
-			return nil, r, err
+			return curve, nil, r, err
 		}
 	default:
-		return nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
+		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
 	}
 
-	if len(bytes) != ed25519.PrivateKeySize {
-		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+	switch curve {
+	case Curve_CURVE25519:
+		if len(bytes) != ed25519.PrivateKeySize {
+			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case Curve_P256:
+		if len(bytes) != 32 {
+			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
 	}
 
-	return bytes, r, nil
+	return curve, bytes, r, nil
+}
+
+func MarshalPublicKey(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
 }
 
 // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
@@ -327,6 +442,30 @@ func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
 	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
 }
 
+func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PublicKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PublicKeyBanner:
+		// Uncompressed
+		expectedLen = 65
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
 // UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
 // or an error on failure
 func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
@@ -362,27 +501,65 @@ func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
 }
 
 // Sign signs a nebula cert with the provided private key
-func (nc *NebulaCertificate) Sign(key ed25519.PrivateKey) error {
+func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
+	if curve != nc.Details.Curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+
 	b, err := proto.Marshal(nc.getRawDetails())
 	if err != nil {
 		return err
 	}
 
-	sig, err := key.Sign(rand.Reader, b, crypto.Hash(0))
-	if err != nil {
-		return err
+	var sig []byte
+
+	switch curve {
+	case Curve_CURVE25519:
+		signer := ed25519.PrivateKey(key)
+		sig = ed25519.Sign(signer, b)
+	case Curve_P256:
+		signer := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
+
+		// We need to hash first for ECDSA
+		// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+		hashed := sha256.Sum256(b)
+		sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
+		if err != nil {
+			return err
+		}
+	default:
+		return fmt.Errorf("invalid curve: %s", nc.Details.Curve)
 	}
+
 	nc.Signature = sig
 	return nil
 }
 
 // CheckSignature verifies the signature against the provided public key
-func (nc *NebulaCertificate) CheckSignature(key ed25519.PublicKey) bool {
+func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
 	b, err := proto.Marshal(nc.getRawDetails())
 	if err != nil {
 		return false
 	}
-	return ed25519.Verify(key, b, nc.Signature)
+	switch nc.Details.Curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature)
+	default:
+		return false
+	}
 }
 
 // Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
@@ -393,7 +570,7 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
 // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
 func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
 	if ncp.IsBlocklisted(nc) {
-		return false, fmt.Errorf("certificate has been blocked")
+		return false, ErrBlockListed
 	}
 
 	signer, err := ncp.GetCAForCert(nc)
@@ -402,15 +579,15 @@ func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error
 	}
 
 	if signer.Expired(t) {
-		return false, fmt.Errorf("root certificate is expired")
+		return false, ErrRootExpired
 	}
 
 	if nc.Expired(t) {
-		return false, fmt.Errorf("certificate is expired")
+		return false, ErrExpired
 	}
 
 	if !nc.CheckSignature(signer.Details.PublicKey) {
-		return false, fmt.Errorf("certificate signature did not match")
+		return false, ErrSignatureMismatch
 	}
 
 	if err := nc.CheckRootConstrains(signer); err != nil {
@@ -463,22 +640,52 @@ func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) erro
 }
 
 // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
-func (nc *NebulaCertificate) VerifyPrivateKey(key []byte) error {
+func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != nc.Details.Curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
 	if nc.Details.IsCA {
-		// the call to PublicKey below will panic slice bounds out of range otherwise
-		if len(key) != ed25519.PrivateKeySize {
-			return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-		}
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+			}
 
-		if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
-			return fmt.Errorf("public key in cert and private key supplied don't match")
+			if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return fmt.Errorf("cannot parse private key as P256")
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, nc.Details.PublicKey) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
 		}
 		return nil
 	}
 
-	pub, err := curve25519.X25519(key, curve25519.Basepoint)
-	if err != nil {
-		return err
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return err
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return err
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
 	}
 	if !bytes.Equal(pub, nc.Details.PublicKey) {
 		return fmt.Errorf("public key in cert and private key supplied don't match")
@@ -532,6 +739,7 @@ func (nc *NebulaCertificate) String() string {
 	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
 	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
 	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
+	s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve)
 	s += "\t}\n"
 	fp, err := nc.Sha256Sum()
 	if err == nil {
@@ -552,6 +760,7 @@ func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
 		NotAfter:  nc.Details.NotAfter.Unix(),
 		PublicKey: make([]byte, len(nc.Details.PublicKey)),
 		IsCA:      nc.Details.IsCA,
+		Curve:     nc.Details.Curve,
 	}
 
 	for _, ipNet := range nc.Details.Ips {
@@ -621,6 +830,7 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
 			"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
 			"isCa":      nc.Details.IsCA,
 			"issuer":    nc.Details.Issuer,
+			"curve":     nc.Details.Curve.String(),
 		},
 		"fingerprint": fp,
 		"signature":   fmt.Sprintf("%x", nc.Signature),

+ 111 - 48
cert/cert.pb.go

@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// 	protoc-gen-go v1.28.0
-// 	protoc        v3.19.4
+// 	protoc-gen-go v1.30.0
+// 	protoc        v3.21.5
 // source: cert.proto
 
 package cert
@@ -20,6 +20,52 @@ const (
 	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
 )
 
+type Curve int32
+
+const (
+	Curve_CURVE25519 Curve = 0
+	Curve_P256       Curve = 1
+)
+
+// Enum value maps for Curve.
+var (
+	Curve_name = map[int32]string{
+		0: "CURVE25519",
+		1: "P256",
+	}
+	Curve_value = map[string]int32{
+		"CURVE25519": 0,
+		"P256":       1,
+	}
+)
+
+func (x Curve) Enum() *Curve {
+	p := new(Curve)
+	*p = x
+	return p
+}
+
+func (x Curve) String() string {
+	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
+}
+
+func (Curve) Descriptor() protoreflect.EnumDescriptor {
+	return file_cert_proto_enumTypes[0].Descriptor()
+}
+
+func (Curve) Type() protoreflect.EnumType {
+	return &file_cert_proto_enumTypes[0]
+}
+
+func (x Curve) Number() protoreflect.EnumNumber {
+	return protoreflect.EnumNumber(x)
+}
+
+// Deprecated: Use Curve.Descriptor instead.
+func (Curve) EnumDescriptor() ([]byte, []int) {
+	return file_cert_proto_rawDescGZIP(), []int{0}
+}
+
 type RawNebulaCertificate struct {
 	state         protoimpl.MessageState
 	sizeCache     protoimpl.SizeCache
@@ -91,6 +137,7 @@ type RawNebulaCertificateDetails struct {
 	IsCA      bool     `protobuf:"varint,8,opt,name=IsCA,proto3" json:"IsCA,omitempty"`
 	// sha-256 of the issuer certificate, if this field is blank the cert is self-signed
 	Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,proto3" json:"Issuer,omitempty"`
+	Curve  Curve  `protobuf:"varint,100,opt,name=curve,proto3,enum=cert.Curve" json:"curve,omitempty"`
 }
 
 func (x *RawNebulaCertificateDetails) Reset() {
@@ -188,6 +235,13 @@ func (x *RawNebulaCertificateDetails) GetIssuer() []byte {
 	return nil
 }
 
+func (x *RawNebulaCertificateDetails) GetCurve() Curve {
+	if x != nil {
+		return x.Curve
+	}
+	return Curve_CURVE25519
+}
+
 type RawNebulaEncryptedData struct {
 	state         protoimpl.MessageState
 	sizeCache     protoimpl.SizeCache
@@ -388,7 +442,7 @@ var file_cert_proto_rawDesc = []byte{
 	0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
 	0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
 	0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
-	0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xf9, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
+	0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
 	0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
 	0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
 	0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
@@ -404,38 +458,43 @@ var file_cert_proto_rawDesc = []byte{
 	0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
 	0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
 	0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
-	0x72, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
-	0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12,
+	0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
+	0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
+	0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
+	0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
+	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
+	0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
+	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
+	0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
+	0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
+	0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
 	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e,
-	0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
-	0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63,
-	0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12,
-	0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20,
-	0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22,
-	0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63,
-	0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12,
-	0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67,
-	0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e,
-	0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68,
-	0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d,
-	0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f,
-	0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, 0x10, 0x41, 0x72,
-	0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3,
-	0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f,
-	0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07,
-	0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76,
-	0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79,
-	0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20,
-	0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20,
-	0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
-	0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03,
-	0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73,
-	0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04,
-	0x73, 0x61, 0x6c, 0x74, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
-	0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c,
-	0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
+	0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
+	0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
+	0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
+	0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
+	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
+	0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
+	0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
+	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
+	0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
+	0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
+	0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
+	0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
+	0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
+	0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
+	0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
+	0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
+	0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
+	0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
+	0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
+	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
+	0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
+	0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
@@ -450,23 +509,26 @@ func file_cert_proto_rawDescGZIP() []byte {
 	return file_cert_proto_rawDescData
 }
 
+var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
 var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
 var file_cert_proto_goTypes = []interface{}{
-	(*RawNebulaCertificate)(nil),        // 0: cert.RawNebulaCertificate
-	(*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails
-	(*RawNebulaEncryptedData)(nil),      // 2: cert.RawNebulaEncryptedData
-	(*RawNebulaEncryptionMetadata)(nil), // 3: cert.RawNebulaEncryptionMetadata
-	(*RawNebulaArgon2Parameters)(nil),   // 4: cert.RawNebulaArgon2Parameters
+	(Curve)(0),                          // 0: cert.Curve
+	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
+	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
+	(*RawNebulaEncryptedData)(nil),      // 3: cert.RawNebulaEncryptedData
+	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
+	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
 }
 var file_cert_proto_depIdxs = []int32{
-	1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
-	3, // 1: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
-	4, // 2: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters
-	3, // [3:3] is the sub-list for method output_type
-	3, // [3:3] is the sub-list for method input_type
-	3, // [3:3] is the sub-list for extension type_name
-	3, // [3:3] is the sub-list for extension extendee
-	0, // [0:3] is the sub-list for field type_name
+	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
+	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
+	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
+	5, // 3: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters
+	4, // [4:4] is the sub-list for method output_type
+	4, // [4:4] is the sub-list for method input_type
+	4, // [4:4] is the sub-list for extension type_name
+	4, // [4:4] is the sub-list for extension extendee
+	0, // [0:4] is the sub-list for field type_name
 }
 
 func init() { file_cert_proto_init() }
@@ -541,13 +603,14 @@ func file_cert_proto_init() {
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: file_cert_proto_rawDesc,
-			NumEnums:      0,
+			NumEnums:      1,
 			NumMessages:   5,
 			NumExtensions: 0,
 			NumServices:   0,
 		},
 		GoTypes:           file_cert_proto_goTypes,
 		DependencyIndexes: file_cert_proto_depIdxs,
+		EnumInfos:         file_cert_proto_enumTypes,
 		MessageInfos:      file_cert_proto_msgTypes,
 	}.Build()
 	File_cert_proto = out.File

+ 7 - 0
cert/cert.proto

@@ -5,6 +5,11 @@ option go_package = "github.com/slackhq/nebula/cert";
 
 //import "google/protobuf/timestamp.proto";
 
+enum Curve {
+  CURVE25519 = 0;
+  P256 = 1;
+}
+
 message RawNebulaCertificate {
     RawNebulaCertificateDetails Details = 1;
     bytes Signature = 2;
@@ -26,6 +31,8 @@ message RawNebulaCertificateDetails {
 
     // sha-256 of the issuer certificate, if this field is blank the cert is self-signed
     bytes Issuer = 9;
+
+    Curve curve = 100;
 }
 
 message RawNebulaEncryptedData {

+ 292 - 43
cert/cert_test.go

@@ -1,6 +1,9 @@
 package cert
 
 import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rand"
 	"fmt"
 	"io"
@@ -101,7 +104,49 @@ func TestNebulaCertificate_Sign(t *testing.T) {
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	assert.Nil(t, err)
 	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(priv))
+	assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
+	assert.True(t, nc.CheckSignature(pub))
+
+	_, err = nc.Marshal()
+	assert.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+}
+
+func TestNebulaCertificate_SignP256(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
+
+	nc := NebulaCertificate{
+		Details: NebulaCertificateDetails{
+			Name: "testing",
+			Ips: []*net.IPNet{
+				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
+				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
+				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
+			},
+			Subnets: []*net.IPNet{
+				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
+				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
+				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
+			},
+			Groups:    []string{"test-group1", "test-group2", "test-group3"},
+			NotBefore: before,
+			NotAfter:  after,
+			PublicKey: pubKey,
+			IsCA:      false,
+			Curve:     Curve_P256,
+			Issuer:    "1234567890abcedfghij1234567890ab",
+		},
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	assert.Nil(t, err)
+	assert.False(t, nc.CheckSignature(pub))
+	assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
 	assert.True(t, nc.CheckSignature(pub))
 
 	_, err = nc.Marshal()
@@ -153,7 +198,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"{\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
 		string(b),
 	)
 }
@@ -177,7 +222,7 @@ func TestNebulaCertificate_Verify(t *testing.T) {
 
 	v, err := c.Verify(time.Now(), caPool)
 	assert.False(t, v)
-	assert.EqualError(t, err, "certificate has been blocked")
+	assert.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	v, err = c.Verify(time.Now(), caPool)
@@ -217,6 +262,65 @@ func TestNebulaCertificate_Verify(t *testing.T) {
 	assert.Nil(t, err)
 }
 
+func TestNebulaCertificate_VerifyP256(t *testing.T) {
+	ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+
+	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+
+	h, err := ca.Sha256Sum()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	caPool.CAs[h] = ca
+
+	f, err := c.Sha256Sum()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	v, err := c.Verify(time.Now(), caPool)
+	assert.False(t, v)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	v, err = c.Verify(time.Now(), caPool)
+	assert.True(t, v)
+	assert.Nil(t, err)
+
+	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
+	assert.False(t, v)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
+	assert.False(t, v)
+	assert.EqualError(t, err, "certificate is expired")
+
+	// Test group assertion
+	ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
+	assert.Nil(t, err)
+
+	caPem, err := ca.MarshalToPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	caPool.AddCACertificate(caPem)
+
+	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
+	assert.Nil(t, err)
+	v, err = c.Verify(time.Now(), caPool)
+	assert.False(t, v)
+	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
+
+	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
+	assert.Nil(t, err)
+	v, err = c.Verify(time.Now(), caPool)
+	assert.True(t, v)
+	assert.Nil(t, err)
+}
+
 func TestNebulaCertificate_Verify_IPs(t *testing.T) {
 	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
 	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
@@ -378,20 +482,40 @@ func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
 func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
 	ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
 	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(caKey)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
 	assert.Nil(t, err)
 
 	_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
 	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(caKey2)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
 	assert.NotNil(t, err)
 
 	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(priv)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
 	assert.Nil(t, err)
 
 	_, priv2 := x25519Keypair()
-	err = c.VerifyPrivateKey(priv2)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
+	err = c.VerifyPrivateKey(Curve_P256, priv)
+	assert.Nil(t, err)
+
+	_, priv2 := p256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
 	assert.NotNil(t, err)
 }
 
@@ -438,6 +562,16 @@ CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
 vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
 WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
 -----END NEBULA CERTIFICATE-----
+`
+
+	p256 := `
+# p256 certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
+6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H
+76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
+IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
+-----END NEBULA CERTIFICATE-----
 `
 
 	rootCA := NebulaCertificate{
@@ -452,6 +586,12 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
 		},
 	}
 
+	rootCAP256 := NebulaCertificate{
+		Details: NebulaCertificateDetails{
+			Name: "nebula P256 test",
+		},
+	}
+
 	p, err := NewCAPoolFromBytes([]byte(noNewLines))
 	assert.Nil(t, err)
 	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
@@ -474,6 +614,11 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
 	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
 	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
 	assert.Equal(t, len(pppp.CAs), 3)
+
+	ppppp, err := NewCAPoolFromBytes([]byte(p256))
+	assert.Nil(t, err)
+	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
+	assert.Equal(t, len(ppppp.CAs), 1)
 }
 
 func appendByteSlices(b ...[]byte) []byte {
@@ -529,11 +674,16 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
-func TestUnmarshalEd25519PrivateKey(t *testing.T) {
+func TestUnmarshalSigningPrivateKey(t *testing.T) {
 	privKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 PRIVATE KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
 -----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PRIVATE KEY-----
 `)
 	shortKey := []byte(`# A short key
 -----BEGIN NEBULA ED25519 PRIVATE KEY-----
@@ -550,35 +700,43 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
 -END NEBULA ED25519 PRIVATE KEY-----`)
 
-	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
 
 	// Success test case
-	k, rest, err := UnmarshalEd25519PrivateKey(keyBundle)
+	k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle)
 	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
+	assert.Len(t, k, 32)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
 	assert.Nil(t, err)
 
 	// Fail due to short key
-	k, rest, err = UnmarshalEd25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
 
 	// Fail due to invalid banner
-	k, rest, err = UnmarshalEd25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 private key banner")
+	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalEd25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
-func TestDecryptAndUnmarshalEd25519PrivateKey(t *testing.T) {
+func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
 	passphrase := []byte("DO NOT USE THIS KEY")
 	privKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
@@ -614,60 +772,67 @@ qrlJ69wer3ZUHFXA
 	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
 
 	// Success test case
-	k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, keyBundle)
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
 	assert.Nil(t, err)
+	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Len(t, k, 64)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 
 	// Fail due to short key
-	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
 	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
 
 	// Fail due to invalid banner
-	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519 private key banner")
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to invalid passphrase
-	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey([]byte("invalid passphrase"), privKey)
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
 	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, []byte{})
 }
 
-func TestEncryptAndMarshalEd25519PrivateKey(t *testing.T) {
+func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
 	// Having proved that decryption works correctly above, we can test the
 	// encryption function produces a value which can be decrypted
 	passphrase := []byte("passphrase")
 	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
 	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
-	key, err := EncryptAndMarshalEd25519PrivateKey(bytes, passphrase, kdfParams)
+	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
 	assert.Nil(t, err)
 
 	// Verify the "key" can be decrypted successfully
-	k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, key)
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
 	assert.Len(t, k, 64)
+	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Equal(t, rest, []byte{})
 	assert.Nil(t, err)
 
 	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
 }
 
-func TestUnmarshalX25519PrivateKey(t *testing.T) {
+func TestUnmarshalPrivateKey(t *testing.T) {
 	privKey := []byte(`# A good key
 -----BEGIN NEBULA X25519 PRIVATE KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -----END NEBULA X25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PRIVATE KEY-----
 `)
 	shortKey := []byte(`# A short key
 -----BEGIN NEBULA X25519 PRIVATE KEY-----
@@ -684,29 +849,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -END NEBULA X25519 PRIVATE KEY-----`)
 
-	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPrivateKey(keyBundle)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
 
 	// Success test case
-	k, rest, err := UnmarshalX25519PrivateKey(keyBundle)
+	k, rest, curve, err = UnmarshalPrivateKey(rest)
 	assert.Len(t, k, 32)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
 	assert.Nil(t, err)
 
 	// Fail due to short key
-	k, rest, err = UnmarshalX25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 private key")
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
 
 	// Fail due to invalid banner
-	k, rest, err = UnmarshalX25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 private key banner")
+	assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner")
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalX25519PrivateKey(rest)
+	k, rest, curve, err = UnmarshalPrivateKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
@@ -766,6 +939,12 @@ func TestUnmarshalX25519PublicKey(t *testing.T) {
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -----END NEBULA X25519 PUBLIC KEY-----
+`)
+	pubP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
 `)
 	shortKey := []byte(`# A short key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -782,29 +961,37 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -END NEBULA X25519 PUBLIC KEY-----`)
 
-	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
 
 	// Success test case
-	k, rest, err := UnmarshalX25519PublicKey(keyBundle)
+	k, rest, curve, err := UnmarshalPublicKey(keyBundle)
 	assert.Equal(t, len(k), 32)
 	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKey(rest)
+	assert.Equal(t, len(k), 65)
+	assert.Nil(t, err)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
 
 	// Fail due to short key
-	k, rest, err = UnmarshalX25519PublicKey(rest)
+	k, rest, curve, err = UnmarshalPublicKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid X25519 public key")
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
 
 	// Fail due to invalid banner
-	k, rest, err = UnmarshalX25519PublicKey(rest)
+	k, rest, curve, err = UnmarshalPublicKey(rest)
 	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula X25519 public key banner")
+	assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner")
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalX25519PublicKey(rest)
+	k, rest, curve, err = UnmarshalPublicKey(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
@@ -901,13 +1088,56 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 		nc.Details.Groups = groups
 	}
 
-	err = nc.Sign(priv)
+	err = nc.Sign(Curve_CURVE25519, priv)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 	return nc, pub, priv, nil
 }
 
+func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	nc := &NebulaCertificate{
+		Details: NebulaCertificateDetails{
+			Name:           "test ca",
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           true,
+			Curve:          Curve_P256,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	if len(ips) > 0 {
+		nc.Details.Ips = ips
+	}
+
+	if len(subnets) > 0 {
+		nc.Details.Subnets = subnets
+	}
+
+	if len(groups) > 0 {
+		nc.Details.Groups = groups
+	}
+
+	err = nc.Sign(Curve_P256, rawPriv)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	return nc, pub, rawPriv, nil
+}
+
 func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
 	issuer, err := ca.Sha256Sum()
 	if err != nil {
@@ -941,7 +1171,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 		}
 	}
 
-	pub, rawPriv := x25519Keypair()
+	var pub, rawPriv []byte
+
+	switch ca.Details.Curve {
+	case Curve_CURVE25519:
+		pub, rawPriv = x25519Keypair()
+	case Curve_P256:
+		pub, rawPriv = p256Keypair()
+	default:
+		return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve)
+	}
 
 	nc := &NebulaCertificate{
 		Details: NebulaCertificateDetails{
@@ -953,12 +1192,13 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 			NotAfter:       time.Unix(after.Unix(), 0),
 			PublicKey:      pub,
 			IsCA:           false,
+			Curve:          ca.Details.Curve,
 			Issuer:         issuer,
 			InvertedGroups: make(map[string]struct{}),
 		},
 	}
 
-	err = nc.Sign(key)
+	err = nc.Sign(ca.Details.Curve, key)
 	if err != nil {
 		return nil, nil, nil, err
 	}
@@ -979,3 +1219,12 @@ func x25519Keypair() ([]byte, []byte) {
 
 	return pubkey, privkey
 }
+
+func p256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 9 - 4
cert/errors.go

@@ -1,9 +1,14 @@
 package cert
 
-import "errors"
+import (
+	"errors"
+)
 
 var (
-	ErrExpired       = errors.New("certificate is expired")
-	ErrNotCA         = errors.New("certificate is not a CA")
-	ErrNotSelfSigned = errors.New("certificate is not self-signed")
+	ErrRootExpired       = errors.New("root certificate is expired")
+	ErrExpired           = errors.New("certificate is expired")
+	ErrNotCA             = errors.New("certificate is not a CA")
+	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
+	ErrBlockListed       = errors.New("certificate is in the block list")
+	ErrSignatureMismatch = errors.New("certificate signature did not match")
 )

+ 25 - 3
cidr/tree4.go

@@ -13,8 +13,14 @@ type Node struct {
 	value  interface{}
 }
 
+type entry struct {
+	CIDR  *net.IPNet
+	Value *interface{}
+}
+
 type Tree4 struct {
 	root *Node
+	list []entry
 }
 
 const (
@@ -24,6 +30,7 @@ const (
 func NewTree4() *Tree4 {
 	tree := new(Tree4)
 	tree.root = &Node{}
+	tree.list = []entry{}
 	return tree
 }
 
@@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// We already have this range so update the value
 	if next != nil {
+		addCIDR := cidr.String()
+		for i, v := range tree.list {
+			if addCIDR == v.CIDR.String() {
+				tree.list = append(tree.list[:i], tree.list[i+1:]...)
+				break
+			}
+		}
+
+		tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
 		node.value = val
 		return
 	}
@@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
+	tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
 }
 
-// Finds the first match, which may be the least specific
+// Contains finds the first match, which may be the least specific
 func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
@@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 	return value
 }
 
-// Finds the most specific match
+// MostSpecificContains finds the most specific match
 func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
@@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 	return value
 }
 
-// Finds the most specific match
+// Match finds the most specific match
 func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
@@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 	}
 	return value
 }
+
+// List will return all CIDRs and their current values. Do not modify the contents!
+func (tree *Tree4) List() []entry {
+	return tree.list
+}

+ 14 - 0
cidr/tree4_test.go

@@ -8,6 +8,20 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+func TestCIDRTree_List(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
+	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
+	tree.AddCIDR(Parse("1.0.0.0/16"), "4")
+	list := tree.List()
+	assert.Len(t, list, 2)
+	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
+	assert.Equal(t, "2", *list[0].Value)
+	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
+	assert.Equal(t, "4", *list[1].Value)
+}
+
 func TestCIDRTree_Contains(t *testing.T) {
 	tree := NewTree4()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")

+ 28 - 6
cmd/nebula-cert/ca.go

@@ -1,6 +1,8 @@
 package main
 
 import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rand"
 	"flag"
 	"fmt"
@@ -31,6 +33,8 @@ type caFlags struct {
 	argonIterations  *uint
 	argonParallelism *uint
 	encryption       *bool
+
+	curve *string
 }
 
 func newCaFlags() *caFlags {
@@ -48,6 +52,7 @@ func newCaFlags() *caFlags {
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
+	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
 	return &cf
 }
 
@@ -160,9 +165,25 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 	}
 
-	pub, rawPriv, err := ed25519.GenerateKey(rand.Reader)
-	if err != nil {
-		return fmt.Errorf("error while generating ed25519 keys: %s", err)
+	var curve cert.Curve
+	var pub, rawPriv []byte
+	switch *cf.curve {
+	case "25519", "X25519", "Curve25519", "CURVE25519":
+		curve = cert.Curve_CURVE25519
+		pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			return fmt.Errorf("error while generating ed25519 keys: %s", err)
+		}
+	case "P256":
+		var key *ecdsa.PrivateKey
+		curve = cert.Curve_P256
+		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			return fmt.Errorf("error while generating ecdsa keys: %s", err)
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L60
+		rawPriv = key.D.FillBytes(make([]byte, 32))
+		pub = elliptic.Marshal(elliptic.P256(), key.X, key.Y)
 	}
 
 	nc := cert.NebulaCertificate{
@@ -175,6 +196,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 			NotAfter:  time.Now().Add(*cf.duration),
 			PublicKey: pub,
 			IsCA:      true,
+			Curve:     curve,
 		},
 	}
 
@@ -186,20 +208,20 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 	}
 
-	err = nc.Sign(rawPriv)
+	err = nc.Sign(curve, rawPriv)
 	if err != nil {
 		return fmt.Errorf("error while signing: %s", err)
 	}
 
 	if *cf.encryption {
-		b, err := cert.EncryptAndMarshalEd25519PrivateKey(rawPriv, passphrase, kdfParams)
+		b, err := cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
 		if err != nil {
 			return fmt.Errorf("error while encrypting out-key: %s", err)
 		}
 
 		err = ioutil.WriteFile(*cf.outKeyPath, b, 0600)
 	} else {
-		err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600)
+		err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalSigningPrivateKey(curve, rawPriv), 0600)
 	}
 
 	if err != nil {

+ 5 - 1
cmd/nebula-cert/ca_test.go

@@ -35,6 +35,8 @@ func Test_caHelp(t *testing.T) {
 			"    \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+
 			"  -argon-parallelism uint\n"+
 			"    \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+
+			"  -curve string\n"+
+			"    \tEdDSA/ECDSA Curve (25519, P256) (default \"25519\")\n"+
 			"  -duration duration\n"+
 			"    \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+
 			"  -encrypt\n"+
@@ -174,7 +176,9 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations)
 
 	// verify the key is valid and decrypt-able
-	lKey, b, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rb)
+	var curve cert.Curve
+	curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Nil(t, err)
 	assert.Len(t, b, 0)
 	assert.Len(t, lKey, 64)

+ 17 - 3
cmd/nebula-cert/keygen.go

@@ -14,6 +14,8 @@ type keygenFlags struct {
 	set        *flag.FlagSet
 	outKeyPath *string
 	outPubPath *string
+
+	curve *string
 }
 
 func newKeygenFlags() *keygenFlags {
@@ -21,6 +23,7 @@ func newKeygenFlags() *keygenFlags {
 	cf.set.Usage = func() {}
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
+	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
 	return &cf
 }
 
@@ -38,14 +41,25 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	pub, rawPriv := x25519Keypair()
+	var pub, rawPriv []byte
+	var curve cert.Curve
+	switch *cf.curve {
+	case "25519", "X25519", "Curve25519", "CURVE25519":
+		pub, rawPriv = x25519Keypair()
+		curve = cert.Curve_CURVE25519
+	case "P256":
+		pub, rawPriv = p256Keypair()
+		curve = cert.Curve_P256
+	default:
+		return fmt.Errorf("invalid curve: %s", *cf.curve)
+	}
 
-	err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600)
+	err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalX25519PublicKey(pub), 0600)
+	err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}

+ 2 - 0
cmd/nebula-cert/keygen_test.go

@@ -22,6 +22,8 @@ func Test_keygenHelp(t *testing.T) {
 	assert.Equal(
 		t,
 		"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
+			"  -curve string\n"+
+			"    \tECDH Curve (25519, P256) (default \"25519\")\n"+
 			"  -out-key string\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"  -out-pub string\n"+

+ 2 - 2
cmd/nebula-cert/print_test.go

@@ -87,7 +87,7 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
+		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())
@@ -115,7 +115,7 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())

+ 35 - 9
cmd/nebula-cert/sign.go

@@ -1,7 +1,7 @@
 package main
 
 import (
-	"crypto/ed25519"
+	"crypto/ecdh"
 	"crypto/rand"
 	"flag"
 	"fmt"
@@ -78,10 +78,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while reading ca-key: %s", err)
 	}
 
-	var caKey ed25519.PrivateKey
+	var curve cert.Curve
+	var caKey []byte
 
 	// naively attempt to decode the private key as though it is not encrypted
-	caKey, _, err = cert.UnmarshalEd25519PrivateKey(rawCAKey)
+	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
 	if err == cert.ErrPrivateKeyEncrypted {
 		// ask for a passphrase until we get one
 		var passphrase []byte
@@ -103,7 +104,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
 		}
 
-		caKey, _, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rawCAKey)
+		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
 		if err != nil {
 			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
 		}
@@ -121,7 +122,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 	}
 
-	if err := caCert.VerifyPrivateKey(caKey); err != nil {
+	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
 		return fmt.Errorf("refusing to sign, root certificate does not match private key")
 	}
 
@@ -181,12 +182,16 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
-		pub, _, err = cert.UnmarshalX25519PublicKey(rawPub)
+		var pubCurve cert.Curve
+		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
 		if err != nil {
 			return fmt.Errorf("error while parsing in-pub: %s", err)
 		}
+		if pubCurve != curve {
+			return fmt.Errorf("curve of in-pub does not match ca")
+		}
 	} else {
-		pub, rawPriv = x25519Keypair()
+		pub, rawPriv = newKeypair(curve)
 	}
 
 	nc := cert.NebulaCertificate{
@@ -200,6 +205,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			PublicKey: pub,
 			IsCA:      false,
 			Issuer:    issuer,
+			Curve:     curve,
 		},
 	}
 
@@ -219,7 +225,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 
-	err = nc.Sign(caKey)
+	err = nc.Sign(curve, caKey)
 	if err != nil {
 		return fmt.Errorf("error while signing: %s", err)
 	}
@@ -229,7 +235,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 
-		err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalX25519PrivateKey(rawPriv), 0600)
+		err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
@@ -260,6 +266,17 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	return nil
 }
 
+func newKeypair(curve cert.Curve) ([]byte, []byte) {
+	switch curve {
+	case cert.Curve_CURVE25519:
+		return x25519Keypair()
+	case cert.Curve_P256:
+		return p256Keypair()
+	default:
+		return nil, nil
+	}
+}
+
 func x25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
@@ -274,6 +291,15 @@ func x25519Keypair() ([]byte, []byte) {
 	return pubkey, privkey
 }
 
+func p256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}
+
 func signSummary() string {
 	return "sign <flags>: create and sign a certificate"
 }

+ 1 - 1
cmd/nebula-cert/sign_test.go

@@ -359,7 +359,7 @@ func Test_signCert(t *testing.T) {
 	// generate the encrypted key
 	caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader)
 	kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3)
-	b, _ = cert.EncryptAndMarshalEd25519PrivateKey(caPriv, passphrase, kdfParams)
+	b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
 	caKeyF.Write(b)
 
 	ca = cert.NebulaCertificate{

+ 3 - 3
cmd/nebula-cert/verify_test.go

@@ -77,7 +77,7 @@ func Test_verify(t *testing.T) {
 			IsCA:      true,
 		},
 	}
-	ca.Sign(caPriv)
+	ca.Sign(cert.Curve_CURVE25519, caPriv)
 	b, _ := ca.MarshalToPEM()
 	caFile.Truncate(0)
 	caFile.Seek(0, 0)
@@ -117,7 +117,7 @@ func Test_verify(t *testing.T) {
 		},
 	}
 
-	crt.Sign(badPriv)
+	crt.Sign(cert.Curve_CURVE25519, badPriv)
 	b, _ = crt.MarshalToPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
@@ -129,7 +129,7 @@ func Test_verify(t *testing.T) {
 	assert.EqualError(t, err, "certificate signature did not match")
 
 	// verified cert at path
-	crt.Sign(caPriv)
+	crt.Sign(cert.Curve_CURVE25519, caPriv)
 	b, _ = crt.MarshalToPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)

+ 244 - 33
connection_manager.go

@@ -1,16 +1,30 @@
 package nebula
 
 import (
+	"bytes"
 	"context"
 	"sync"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 )
 
+type trafficDecision int
+
+const (
+	doNothing      trafficDecision = 0
+	deleteTunnel   trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
+	closeTunnel    trafficDecision = 2 // delete the hostinfo and notify the remote
+	swapPrimary    trafficDecision = 3
+	migrateRelays  trafficDecision = 4
+	tryRehandshake trafficDecision = 5
+)
+
 type connectionManager struct {
 	in     map[uint32]struct{}
 	inLock *sync.RWMutex
@@ -18,6 +32,10 @@ type connectionManager struct {
 	out     map[uint32]struct{}
 	outLock *sync.RWMutex
 
+	// relayUsed holds which relay localIndexs are in use
+	relayUsed     map[uint32]struct{}
+	relayUsedLock *sync.RWMutex
+
 	hostMap                 *HostMap
 	trafficTimer            *LockingTimerWheel[uint32]
 	intf                    *Interface
@@ -44,6 +62,8 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		inLock:                  &sync.RWMutex{},
 		out:                     make(map[uint32]struct{}),
 		outLock:                 &sync.RWMutex{},
+		relayUsed:               make(map[uint32]struct{}),
+		relayUsedLock:           &sync.RWMutex{},
 		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:                    intf,
 		pendingDeletion:         make(map[uint32]struct{}),
@@ -84,6 +104,19 @@ func (n *connectionManager) Out(localIndex uint32) {
 	n.outLock.Unlock()
 }
 
+func (n *connectionManager) RelayUsed(localIndex uint32) {
+	n.relayUsedLock.RLock()
+	// If this already exists, return
+	if _, ok := n.relayUsed[localIndex]; ok {
+		n.relayUsedLock.RUnlock()
+		return
+	}
+	n.relayUsedLock.RUnlock()
+	n.relayUsedLock.Lock()
+	n.relayUsed[localIndex] = struct{}{}
+	n.relayUsedLock.Unlock()
+}
+
 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
 func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
@@ -99,8 +132,15 @@ func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bo
 }
 
 func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
-	n.Out(localIndex)
+	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
+	n.outLock.Lock()
+	if _, ok := n.out[localIndex]; ok {
+		n.outLock.Unlock()
+		return
+	}
+	n.out[localIndex] = struct{}{}
 	n.trafficTimer.Add(localIndex, n.checkInterval)
+	n.outLock.Unlock()
 }
 
 func (n *connectionManager) Start(ctx context.Context) {
@@ -136,18 +176,136 @@ func (n *connectionManager) Run(ctx context.Context) {
 }
 
 func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	hostinfo, err := n.hostMap.QueryIndex(localIndex)
-	if err != nil {
+	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
+
+	switch decision {
+	case deleteTunnel:
+		if n.hostMap.DeleteHostInfo(hostinfo) {
+			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
+			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+		}
+
+	case closeTunnel:
+		n.intf.sendCloseTunnel(hostinfo)
+		n.intf.closeTunnel(hostinfo)
+
+	case swapPrimary:
+		n.swapPrimary(hostinfo, primary)
+
+	case migrateRelays:
+		n.migrateRelayUsed(hostinfo, primary)
+
+	case tryRehandshake:
+		n.tryRehandshake(hostinfo)
+	}
+
+	n.resetRelayTrafficCheck(hostinfo)
+}
+
+func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+	if hostinfo != nil {
+		n.relayUsedLock.Lock()
+		defer n.relayUsedLock.Unlock()
+		// No need to migrate any relays, delete usage info now.
+		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
+			delete(n.relayUsed, idx)
+		}
+	}
+}
+
+func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
+
+	for _, r := range relayFor {
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+
+		var index uint32
+		var relayFrom iputil.VpnIp
+		var relayTo iputil.VpnIp
+		switch {
+		case ok && existing.State == Established:
+			// This relay already exists in newhostinfo, then do nothing.
+			continue
+		case ok && existing.State == Requested:
+			// The relay exists in a Requested state; re-send the request
+			index = existing.LocalIndex
+			switch r.Type {
+			case TerminalType:
+				relayFrom = newhostinfo.vpnIp
+				relayTo = existing.PeerIp
+			case ForwardingType:
+				relayFrom = existing.PeerIp
+				relayTo = newhostinfo.vpnIp
+			default:
+				// should never happen
+			}
+		case !ok:
+			n.relayUsedLock.RLock()
+			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
+				// The relay hasn't been used; don't migrate it.
+				n.relayUsedLock.RUnlock()
+				continue
+			}
+			n.relayUsedLock.RUnlock()
+			// The relay doesn't exist at all; create some relay state and send the request.
+			var err error
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			if err != nil {
+				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+				continue
+			}
+			switch r.Type {
+			case TerminalType:
+				relayFrom = newhostinfo.vpnIp
+				relayTo = r.PeerIp
+			case ForwardingType:
+				relayFrom = r.PeerIp
+				relayTo = newhostinfo.vpnIp
+			default:
+				// should never happen
+			}
+		}
+
+		// Send a CreateRelayRequest to the peer.
+		req := NebulaControl{
+			Type:                NebulaControl_CreateRelayRequest,
+			InitiatorRelayIndex: index,
+			RelayFromIp:         uint32(relayFrom),
+			RelayToIp:           uint32(relayTo),
+		}
+		msg, err := req.Marshal()
+		if err != nil {
+			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+		} else {
+			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+			n.l.WithFields(logrus.Fields{
+				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
+				"relayTo":             iputil.VpnIp(req.RelayToIp),
+				"initiatorRelayIndex": req.InitiatorRelayIndex,
+				"responderRelayIndex": req.ResponderRelayIndex,
+				"vpnIp":               newhostinfo.vpnIp}).
+				Info("send CreateRelayRequest")
+		}
+	}
+}
+
+func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	n.hostMap.RLock()
+	defer n.hostMap.RUnlock()
+
+	hostinfo := n.hostMap.Indexes[localIndex]
+	if hostinfo == nil {
 		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
 		delete(n.pendingDeletion, localIndex)
-		return
+		return doNothing, nil, nil
 	}
 
-	if n.handleInvalidCertificate(now, hostinfo) {
-		return
+	if n.isInvalidCertificate(now, hostinfo) {
+		delete(n.pendingDeletion, hostinfo.localIndexId)
+		return closeTunnel, hostinfo, nil
 	}
 
-	primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+	primary := n.hostMap.Hosts[hostinfo.vpnIp]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
@@ -158,6 +316,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 
 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
+		decision := doNothing
 		if n.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(n.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
@@ -165,11 +324,15 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 		}
 		delete(n.pendingDeletion, hostinfo.localIndexId)
 
-		if !mainHostInfo {
-			if hostinfo.vpnIp > n.intf.myVpnIp {
-				// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
-				// This the primary and prime the old primary hostinfo for testing
-				n.hostMap.MakePrimary(hostinfo)
+		if mainHostInfo {
+			decision = tryRehandshake
+
+		} else {
+			if n.shouldSwapPrimary(hostinfo, primary) {
+				decision = swapPrimary
+			} else {
+				// migrate the relays to the primary, if in use.
+				decision = migrateRelays
 			}
 		}
 
@@ -180,7 +343,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 			n.sendPunch(hostinfo)
 		}
 
-		return
+		return decision, hostinfo, primary
 	}
 
 	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
@@ -189,22 +352,17 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")
 
-		n.hostMap.DeleteHostInfo(hostinfo)
 		delete(n.pendingDeletion, hostinfo.localIndexId)
-		return
+		return deleteTunnel, hostinfo, nil
 	}
 
-	hostinfo.logger(n.l).
-		WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
-		Debug("Tunnel status")
-
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
 			// Just maintain NAT state if configured to do so.
 			n.sendPunch(hostinfo)
 			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
-			return
+			return doNothing, nil, nil
 
 		}
 
@@ -215,29 +373,55 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 			n.sendPunch(hostinfo)
 		}
 
-		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
-			// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
-			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
-			return
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
+				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
+				Debug("Tunnel status")
 		}
 
 		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 
 	} else {
-		hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		}
 	}
 
 	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
 	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
+	return doNothing, nil, nil
 }
 
-// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
-func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
-	if !n.intf.disconnectInvalid {
+func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
+	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
+	// Let's sort this out.
+
+	if current.vpnIp < n.intf.myVpnIp {
+		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
+		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
+		// The remotes vpn ip is lower than mine. I will not flip.
 		return false
 	}
 
+	certState := n.intf.certState.Load()
+	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
+}
+
+func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
+	n.hostMap.Lock()
+	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
+	if n.hostMap.Hosts[current.vpnIp] == primary {
+		n.hostMap.unlockedMakePrimary(current)
+	}
+	n.hostMap.Unlock()
+}
+
+// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
+// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
+// check and return true.
+func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 		return false
@@ -248,15 +432,16 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
 		return false
 	}
 
+	if !n.intf.disconnectInvalid && err != cert.ErrBlockListed {
+		// Block listed certificates should always be disconnected
+		return false
+	}
+
 	fingerprint, _ := remoteCert.Sha256Sum()
 	hostinfo.logger(n.l).WithError(err).
 		WithField("fingerprint", fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
-	// Inform the remote and close the tunnel locally
-	n.intf.sendCloseTunnel(hostinfo)
-	n.intf.closeTunnel(hostinfo)
-	delete(n.pendingDeletion, hostinfo.localIndexId)
 	return true
 }
 
@@ -277,3 +462,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
+
+func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	certState := n.intf.certState.Load()
+	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
+		return
+	}
+
+	n.l.WithField("vpnIp", hostinfo.vpnIp).
+		WithField("reason", "local certificate is not current").
+		Info("Re-handshaking with remote")
+
+	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
+	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
+	if !newHostinfo.HandshakeReady {
+		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
+	}
+
+	//If this is a static host, we don't need to wait for the HostQueryReply
+	//We can trigger the handshake right now
+	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
+		select {
+		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
+		default:
+		}
+	}
+}

+ 6 - 6
connection_manager_test.go

@@ -220,7 +220,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			PublicKey: pubCA,
 		},
 	}
-	caCert.Sign(privCA)
+	caCert.Sign(cert.Curve_CURVE25519, privCA)
 	ncp := &cert.NebulaCAPool{
 		CAs: cert.NewCAPool().CAs,
 	}
@@ -239,7 +239,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			Issuer:    "ca",
 		},
 	}
-	peerCert.Sign(privCA)
+	peerCert.Sign(cert.Curve_CURVE25519, privCA)
 
 	cs := &CertState{
 		rawCertificate:      []byte{},
@@ -279,13 +279,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Check if to disconnect with invalid certificate.
 	// Should be alive.
 	nextTick := now.Add(45 * time.Second)
-	destroyed := nc.handleInvalidCertificate(nextTick, hostinfo)
-	assert.False(t, destroyed)
+	invalid := nc.isInvalidCertificate(nextTick, hostinfo)
+	assert.False(t, invalid)
 
 	// Move ahead 61s.
 	// Check if to disconnect with invalid certificate.
 	// Should be disconnected.
 	nextTick = now.Add(61 * time.Second)
-	destroyed = nc.handleInvalidCertificate(nextTick, hostinfo)
-	assert.True(t, destroyed)
+	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
+	assert.True(t, invalid)
 }

+ 14 - 3
connection_state.go

@@ -29,12 +29,23 @@ type ConnectionState struct {
 }
 
 func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
-	cs := noise.NewCipherSuite(noise.DH25519, noiseutil.CipherAESGCM, noise.HashSHA256)
+	var dhFunc noise.DHFunc
+	curCertState := f.certState.Load()
+
+	switch curCertState.certificate.Details.Curve {
+	case cert.Curve_CURVE25519:
+		dhFunc = noise.DH25519
+	case cert.Curve_P256:
+		dhFunc = noiseutil.DHP256
+	default:
+		l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
+		return nil
+	}
+	cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	if f.cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
+		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	}
 
-	curCertState := f.certState.Load()
 	static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
 
 	b := NewBits(ReplayWindow)

+ 1 - 1
control_test.go

@@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 	}
 
-	remotes := NewRemoteList()
+	remotes := NewRemoteList(nil)
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{

+ 14 - 0
control_tester.go

@@ -163,3 +163,17 @@ func (c *Control) GetHostmap() *HostMap {
 func (c *Control) GetCert() *cert.NebulaCertificate {
 	return c.f.certState.Load().certificate
 }
+
+func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
+	ixHandshakeStage0(c.f, vpnIp, hostinfo)
+
+	// If this is a static host, we don't need to wait for the HostQueryReply
+	// We can trigger the handshake right now
+	if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
+		select {
+		case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
+		default:
+		}
+	}
+}

+ 385 - 22
e2e/handshakes_test.go

@@ -4,6 +4,7 @@
 package e2e
 
 import (
+	"fmt"
 	"net"
 	"testing"
 	"time"
@@ -15,12 +16,13 @@ import (
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v2"
 )
 
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -43,8 +45,8 @@ func BenchmarkHotPath(b *testing.B) {
 
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -98,9 +100,9 @@ func TestWrongResponderHandshake(t *testing.T) {
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
-	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 
 	// Add their real udp addr, which should be tried after evil.
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -163,8 +165,8 @@ func TestStage1Race(t *testing.T) {
 	// But will eventually collapse down to a single tunnel
 
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse and vice versa
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -240,8 +242,8 @@ func TestStage1Race(t *testing.T) {
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -289,8 +291,8 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -340,9 +342,9 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 
 func TestRelays(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
@@ -371,9 +373,9 @@ func TestRelays(t *testing.T) {
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
@@ -418,9 +420,9 @@ func TestStage1RaceRelays(t *testing.T) {
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 	l := NewTestLogger()
 
 	// Teach my how to get to the relay and that their can be reached via the relay
@@ -503,5 +505,366 @@ func TestStage1RaceRelays2(t *testing.T) {
 	//
 	////TODO: assert hostmaps
 }
+func TestRehandshakingRelays(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+
+	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
+	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
+	r.Log("Renew relay certificate and spin until me and them sees it")
+	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	relayConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(relayConfig.Settings)
+	assert.NoError(t, err)
+	relayConfig.ReloadConfigString(string(rc))
+
+	for {
+		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between my and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	for {
+		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between their and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	r.Log("Assert the relay tunnel still works")
+	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+	// We should have two hostinfos on all sides
+	for len(myControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("myControl hostinfos got cleaned up!")
+	for len(theirControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("theirControl hostinfos got cleaned up!")
+	for len(relayControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("relayControl hostinfos got cleaned up!")
+}
+
+func TestRehandshaking(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up a tunnel between me and them")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Renew my certificate and spin until their sees it")
+	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	myConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(myConfig.Settings)
+	assert.NoError(t, err)
+	myConfig.ReloadConfigString(string(rc))
+
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
+	rc, err = yaml.Marshal(theirConfig.Settings)
+	assert.NoError(t, err)
+	var theirNewConfig m
+	assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
+	theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
+	theirFirewall["inbound"] = []m{{
+		"proto": "any",
+		"port":  "any",
+		"group": "new group",
+	}}
+	rc, err = yaml.Marshal(theirNewConfig)
+	assert.NoError(t, err)
+	theirConfig.ReloadConfigString(string(rc))
+
+	r.Log("Spin until there is only 1 tunnel")
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// Make sure the correct tunnel won
+	c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	assert.Contains(t, c.Cert.Details.Groups, "new group")
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestRehandshakingLoser(t *testing.T) {
+	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
+	// Should be the one with the new certificate
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up a tunnel between me and them")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Renew their certificate and spin until mine sees it")
+	_, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	theirConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(theirNextPEM),
+		"key":  string(theirNextPrivKey),
+	}
+	rc, err := yaml.Marshal(theirConfig.Settings)
+	assert.NoError(t, err)
+	theirConfig.ReloadConfigString(string(rc))
+
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+
+		_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
+		if theirNewGroup {
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
+	rc, err = yaml.Marshal(myConfig.Settings)
+	assert.NoError(t, err)
+	var myNewConfig m
+	assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
+	theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
+	theirFirewall["inbound"] = []m{{
+		"proto": "any",
+		"port":  "any",
+		"group": "their new group",
+	}}
+	rc, err = yaml.Marshal(myNewConfig)
+	assert.NoError(t, err)
+	myConfig.ReloadConfigString(string(rc))
+
+	r.Log("Spin until there is only 1 tunnel")
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// Make sure the correct tunnel won
+	theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestRaceRegression(t *testing.T) {
+	// This test forces stage 1, stage 2, stage 1 to be received by me from them
+	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
+	// caused a cross-linked hostinfo
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Put their info in our lighthouse
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
+	//me rx   stage:1 initiatorIndex=120607833 responderIndex=0
+	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
+	//me rx   stage:2 initiatorIndex=642843150 responderIndex=3701775874
+	//me rx   stage:1 initiatorIndex=120607833 responderIndex=0
+	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
+
+	t.Log("Start both handshakes")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+
+	t.Log("Get both stage 1")
+	myStage1ForThem := myControl.GetFromUDP(true)
+	theirStage1ForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Inject them in a special way")
+	theirControl.InjectUDPPacket(myStage1ForThem)
+	myControl.InjectUDPPacket(theirStage1ForMe)
+	theirControl.InjectUDPPacket(myStage1ForThem)
+
+	//TODO: ensure stage 2
+	t.Log("Get both stage 2")
+	myStage2ForThem := myControl.GetFromUDP(true)
+	theirStage2ForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Inject them in a special way again")
+	myControl.InjectUDPPacket(theirStage2ForMe)
+	myControl.InjectUDPPacket(theirStage1ForMe)
+	theirControl.InjectUDPPacket(myStage2ForThem)
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	t.Log("Flush the packets")
+	r.RouteForAllUntilTxTun(myControl)
+	r.RouteForAllUntilTxTun(theirControl)
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	t.Log("Make sure the tunnel still works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
 
+//TODO: test
+// Race winner renews and handshakes
+// Race loser renews and handshakes
+// Does race winner repin the cert to old?
 //TODO: add a test with many lies

+ 6 - 6
e2e/helpers_test.go

@@ -30,7 +30,7 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
 	l := NewTestLogger()
 
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@@ -78,8 +78,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 			"level":            l.Level.String(),
 		},
 		"timers": m{
-			"pending_deletion_interval": 4,
-			"connection_alive_interval": 4,
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
 		},
 	}
 
@@ -105,7 +105,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	return control, vpnIpNet, &udpAddr
+	return control, vpnIpNet, &udpAddr, c
 }
 
 // newTestCaCert will generate a CA cert
@@ -141,7 +141,7 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 		nc.Details.Groups = groups
 	}
 
-	err = nc.Sign(priv)
+	err = nc.Sign(cert.Curve_CURVE25519, priv)
 	if err != nil {
 		panic(err)
 	}
@@ -187,7 +187,7 @@ func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
 		},
 	}
 
-	err = nc.Sign(key)
+	err = nc.Sign(ca.Details.Curve, key)
 	if err != nil {
 		panic(err)
 	}

+ 5 - 5
e2e/router/router.go

@@ -215,7 +215,7 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr, ":", "#58;", 1)
+		sanAddr := strings.Replace(addr, ":", "-", 1)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
@@ -252,9 +252,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr(), ":", "#58;", 1),
+				strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
 				line,
-				strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1),
+				strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -758,8 +758,8 @@ func (r *R) formatUdpPacket(p *packet) string {
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "#58;", 1),
-		strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1),
+		strings.Replace(from, ":", "-", 1),
+		strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
 		udp.SrcPort,
 		udp.DstPort,
 		string(data.Payload()),

+ 6 - 1
examples/config.yml

@@ -223,6 +223,10 @@ tun:
     #  metric: 100
     #  install: true
 
+  # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
+  # in nebula configuration files. Default false, not reloadable.
+  #use_system_route_table: false
+
   # EXPERIMENTAL: This option may change or disappear in the future.
   # Multiport spreads outgoing UDP packets across multiple UDP send ports,
   # which allows nebula to work around any issues on the underlay network.
@@ -342,7 +346,8 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a CIDR, `0.0.0.0/0` is any.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 

+ 65 - 36
firewall.go

@@ -25,7 +25,7 @@ const tcpACK = 0x10
 const tcpFIN = 0x01
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
 }
 
 type conn struct {
@@ -106,11 +106,12 @@ type FirewallCA struct {
 }
 
 type FirewallRule struct {
-	// Any makes Hosts, Groups, and CIDR irrelevant
-	Any    bool
-	Hosts  map[string]struct{}
-	Groups [][]string
-	CIDR   *cidr.Tree4
+	// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
+	Any       bool
+	Hosts     map[string]struct{}
+	Groups    [][]string
+	CIDR      *cidr.Tree4
+	LocalCIDR *cidr.Tree4
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
 }
 
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// https://github.com/golang/go/issues/14131
 	sIp := ""
 	if ip != nil {
 		sIp = ip.String()
 	}
+	lIp := ""
+	if localIp != nil {
+		lIp = localIp.String()
+	}
 
 	// We need this rule string because we generate a hash. Removing this will break firewall reload.
 	ruleString := fmt.Sprintf(
-		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
-		incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
+		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
+		incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
 	)
 	f.rules += ruleString + "\n"
 
@@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	if !incoming {
 		direction = "outgoing"
 	}
-	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
+	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
 		Info("Firewall rule added")
 
 	var (
@@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		return fmt.Errorf("unknown protocol %v", proto)
 	}
 
-	return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
+	return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
 }
 
 // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
 		}
 
-		if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
-			return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
+		if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
+			return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
 		}
 
 		if len(r.Groups) > 0 {
@@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			}
 		}
 
-		err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
+		var localCidr *net.IPNet
+		if r.LocalCidr != "" {
+			_, localCidr, err = net.ParseCIDR(r.LocalCidr)
+			if err != nil {
+				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
+			}
+		}
+
+		err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
 		if err != nil {
 			return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
 		}
@@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 }
 
-func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 	}
@@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 			}
 		}
 
-		if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
+		if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
 			return err
 		}
 	}
@@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
-func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
-			Hosts:  make(map[string]struct{}),
-			Groups: make([][]string, 0),
-			CIDR:   cidr.NewTree4(),
+			Hosts:     make(map[string]struct{}),
+			Groups:    make([][]string, 0),
+			CIDR:      cidr.NewTree4(),
+			LocalCIDR: cidr.NewTree4(),
 		}
 	}
 
@@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 			fc.Any = fr()
 		}
 
-		return fc.Any.addRule(groups, host, ip)
+		return fc.Any.addRule(groups, host, ip, localIp)
 	}
 
 	if caSha != "" {
 		if _, ok := fc.CAShas[caSha]; !ok {
 			fc.CAShas[caSha] = fr()
 		}
-		err := fc.CAShas[caSha].addRule(groups, host, ip)
+		err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
 		if err != nil {
 			return err
 		}
@@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 		if _, ok := fc.CANames[caName]; !ok {
 			fc.CANames[caName] = fr()
 		}
-		err := fc.CANames[caName].addRule(groups, host, ip)
+		err := fc.CANames[caName].addRule(groups, host, ip, localIp)
 		if err != nil {
 			return err
 		}
@@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
+func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
 	if fr.Any {
 		return nil
 	}
 
-	if fr.isAny(groups, host, ip) {
+	if fr.isAny(groups, host, ip, localIp) {
 		fr.Any = true
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
 		fr.CIDR = cidr.NewTree4()
+		fr.LocalCIDR = cidr.NewTree4()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
 		if ip != nil {
 			fr.CIDR.AddCIDR(ip, struct{}{})
 		}
+
+		if localIp != nil {
+			fr.LocalCIDR.AddCIDR(localIp, struct{}{})
+		}
 	}
 
 	return nil
 }
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
+	if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
 		return true
 	}
 
@@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 		return true
 	}
 
+	if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+		return true
+	}
+
 	return false
 }
 
@@ -790,20 +813,25 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		return true
 	}
 
+	if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
+		return true
+	}
+
 	// No host, group, or cidr matched, bye bye
 	return false
 }
 
 type rule struct {
-	Port   string
-	Code   string
-	Proto  string
-	Host   string
-	Group  string
-	Groups []string
-	Cidr   string
-	CAName string
-	CASha  string
+	Port      string
+	Code      string
+	Proto     string
+	Host      string
+	Group     string
+	Groups    []string
+	Cidr      string
+	LocalCidr string
+	CAName    string
+	CASha     string
 }
 
 func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
@@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 	r.Proto = toString("proto", m)
 	r.Host = toString("host", m)
 	r.Cidr = toString("cidr", m)
+	r.LocalCidr = toString("local_cidr", m)
 	r.CAName = toString("ca_name", m)
 	r.CASha = toString("ca_sha", m)
 

+ 102 - 45
firewall_test.go

@@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) {
 
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	// Set any and clear fields
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
 	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
-	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
+	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
@@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 
@@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 	})
 
-	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
+	b.Run("pass on local ip", func(b *testing.B) {
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+		c := &cert.NebulaCertificate{
+			Details: cert.NebulaCertificateDetails{
+				InvertedGroups: map[string]struct{}{"nope": {}},
+				Name:           "good-host",
+			},
+		}
+		for n := 0; n < b.N; n++ {
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
+		}
+	})
+
+	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
 		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
@@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
 		}
 	})
+
+	b.Run("pass on local ip with any port", func(b *testing.B) {
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+		c := &cert.NebulaCertificate{
+			Details: cert.NebulaCertificateDetails{
+				InvertedGroups: map[string]struct{}{"nope": {}},
+				Name:           "good-host",
+			},
+		}
+		for n := 0; n < b.N; n++ {
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
+		}
+	})
 }
 
 func TestFirewall_Drop2(t *testing.T) {
@@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
@@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	h3.CreateRemoteCIDR(&c3)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
@@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
+	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
 	conf = config.NewC(l)
@@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 
+	// Test local_cidr parse error
+	conf = config.NewC(l)
+	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
+	_, err = NewFirewallFromConfig(l, c, conf)
+	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
@@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test adding udp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test adding any rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+
+	// Test adding rule with cidr
+	cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+
+	// Test adding rule with local_cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
 
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
 
 	// Test single group
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test single groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
 
 	// Test Add error
 	conf = config.NewC(l)
@@ -892,6 +947,7 @@ type addRuleCall struct {
 	groups    []string
 	host      string
 	ip        *net.IPNet
+	localIp   *net.IPNet
 	caName    string
 	caSha     string
 }
@@ -901,7 +957,7 @@ type mockFirewall struct {
 	nextCallReturn error
 }
 
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	mf.lastCall = addRuleCall{
 		incoming:  incoming,
 		proto:     proto,
@@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 		groups:    groups,
 		host:      host,
 		ip:        ip,
+		localIp:   localIp,
 		caName:    caName,
 		caSha:     caSha,
 	}

+ 14 - 13
go.mod

@@ -1,6 +1,6 @@
 module github.com/slackhq/nebula
 
-go 1.19
+go 1.20
 
 require (
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
@@ -9,25 +9,25 @@ require (
 	github.com/flynn/noise v1.0.0
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
-	github.com/imdario/mergo v0.3.13
+	github.com/imdario/mergo v0.3.15
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.52
+	github.com/miekg/dns v1.1.54
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.14.0
+	github.com/prometheus/client_golang v1.15.1
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.0
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stretchr/testify v1.8.2
 	github.com/vishvananda/netlink v1.1.0
-	golang.org/x/crypto v0.7.0
-	golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0
-	golang.org/x/net v0.8.0
-	golang.org/x/sys v0.6.0
-	golang.org/x/term v0.6.0
+	golang.org/x/crypto v0.8.0
+	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
+	golang.org/x/net v0.9.0
+	golang.org/x/sys v0.8.0
+	golang.org/x/term v0.8.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.29.0
+	google.golang.org/protobuf v1.30.0
 	gopkg.in/yaml.v2 v2.4.0
 )
 
@@ -38,11 +38,12 @@ require (
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.3.0 // indirect
+	github.com/prometheus/client_model v0.4.0 // indirect
 	github.com/prometheus/common v0.42.0 // indirect
 	github.com/prometheus/procfs v0.9.0 // indirect
+	github.com/rogpeppe/go-internal v1.10.0 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
-	golang.org/x/mod v0.9.0 // indirect
-	golang.org/x/tools v0.7.0 // indirect
+	golang.org/x/mod v0.10.0 // indirect
+	golang.org/x/tools v0.8.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 27 - 27
go.sum

@@ -35,7 +35,6 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
 github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
 github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
 github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
@@ -55,8 +54,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
-github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
-github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
+github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
+github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
 github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
 github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -71,16 +70,16 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
 github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
 github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
-github.com/miekg/dns v1.1.52 h1:Bmlc/qsNNULOe6bpXcUTsuOajd0DzRHwup6D9k1An0c=
-github.com/miekg/dns v1.1.52/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
+github.com/miekg/dns v1.1.54 h1:5jon9mWcb0sFJGpnI99tOMhCPyJ+RPVz5b63MQG0VWI=
+github.com/miekg/dns v1.1.54/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -98,13 +97,13 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw=
-github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y=
+github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI=
+github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4=
-github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w=
+github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
+github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
@@ -118,6 +117,8 @@ github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJf
 github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
 github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -151,16 +152,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
-golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
-golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw=
-golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
+golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
+golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
+golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
+golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
-golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
+golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -171,8 +172,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
-golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
+golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
+golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -197,11 +198,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
-golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
+golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -210,8 +211,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
-golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
+golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
+golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -229,8 +230,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
-google.golang.org/protobuf v1.29.0 h1:44S3JjaKmLEE4YIkjzexaP+NzZsudE3Zin5Njn/pYX0=
-google.golang.org/protobuf v1.29.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
+google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -244,6 +245,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 5 - 4
handshake_manager.go

@@ -257,7 +257,8 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
+						// This must send over the hostinfo, not over hm.Hosts[ip]
+						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -292,7 +293,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
+						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -354,8 +355,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
 			// Is it just a delayed handshake packet?
-			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
-				return existingHostInfo, ErrAlreadySeen
+			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) {
+				return testHostInfo, ErrAlreadySeen
 			}
 
 			testHostInfo = testHostInfo.next

+ 5 - 1
handshake_manager_test.go

@@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.False(t, initCalled)
 	assert.Same(t, i, i2)
 
-	i.remotes = NewRemoteList()
+	i.remotes = NewRemoteList(nil)
 	i.HandshakeReady = true
 
 	// Adding something to pending should not affect the main hostmap
@@ -88,4 +88,8 @@ func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte
 	return
 }
 
+func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+	return
+}
+
 func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}

+ 34 - 41
hostmap.go

@@ -32,6 +32,7 @@ const RoamingSuppressSeconds = 2
 
 const (
 	Requested = iota
+	PeerRequested
 	Established
 )
 
@@ -79,6 +80,16 @@ func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
 	delete(rs.relays, ip)
 }
 
+func (rs *RelayState) CopyAllRelayFor() []*Relay {
+	rs.RLock()
+	defer rs.RUnlock()
+	ret := make([]*Relay, 0, len(rs.relayForByIdx))
+	for _, r := range rs.relayForByIdx {
+		ret = append(ret, r)
+	}
+	return ret
+}
+
 func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
@@ -281,29 +292,13 @@ func (hm *HostMap) EmitStats(name string) {
 
 func (hm *HostMap) RemoveRelay(localIdx uint32) {
 	hm.Lock()
-	hiRelay, ok := hm.Relays[localIdx]
+	_, ok := hm.Relays[localIdx]
 	if !ok {
 		hm.Unlock()
 		return
 	}
 	delete(hm.Relays, localIdx)
 	hm.Unlock()
-	ip, ok := hiRelay.relayState.RemoveRelay(localIdx)
-	if !ok {
-		return
-	}
-	hiPeer, err := hm.QueryVpnIp(ip)
-	if err != nil {
-		return
-	}
-	var otherPeerIdx uint32
-	hiPeer.relayState.DeleteRelay(hiRelay.vpnIp)
-	relay, ok := hiPeer.relayState.GetRelayForByIp(hiRelay.vpnIp)
-	if ok {
-		otherPeerIdx = relay.LocalIndex
-	}
-	// I am a relaying host. I need to remove the other relay, too.
-	hm.RemoveRelay(otherPeerIdx)
 }
 
 func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
@@ -397,29 +392,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	hm.unlockedDeleteHostInfo(hostinfo)
 	hm.Unlock()
 
-	// And tear down all the relays going through this host, if final
-	for _, localIdx := range hostinfo.relayState.CopyRelayForIdxs() {
-		hm.RemoveRelay(localIdx)
-	}
-
-	if final {
-		// And tear down the relays this deleted hostInfo was using to be reached
-		teardownRelayIdx := []uint32{}
-		for _, relayIp := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, err := hm.QueryVpnIp(relayIp)
-			if err != nil {
-				hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap")
-			} else {
-				if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok {
-					teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex)
-				}
-			}
-		}
-		for _, localIdx := range teardownRelayIdx {
-			hm.RemoveRelay(localIdx)
-		}
-	}
-
 	return final
 }
 
@@ -510,6 +482,10 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
+
+	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
+		delete(hm.Relays, localRelayIdx)
+	}
 }
 
 func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
@@ -564,6 +540,24 @@ func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
 	return hm.queryVpnIp(vpnIp, nil)
 }
 
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	h, ok := hm.Hosts[relayHostIp]
+	if !ok {
+		return nil, nil, errors.New("unable to find host")
+	}
+	for h != nil {
+		r, ok := h.relayState.QueryRelayForByIp(targetIp)
+		if ok && r.State == Established {
+			return h, r, nil
+		}
+		h = h.next
+	}
+	return nil, nil, errors.New("unable to find host with relay")
+}
+
 // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
 // `PromoteEvery` calls to this function for a given host.
 func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
@@ -711,7 +705,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
 	i.packetStore = make([]*cachedPacket, 0)
 	i.ConnectionState.ready = true
 	i.ConnectionState.queueLock.Unlock()
-	i.ConnectionState.certState = nil
 }
 
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {

+ 12 - 24
inside.go

@@ -57,7 +57,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 	ci := hostinfo.ConnectionState
 
-	if ci.ready == false {
+	if !ci.ready {
 		// Because we might be sending stored packets, lock here to stop new things going to
 		// the packet queue.
 		ci.queueLock.Lock()
@@ -177,7 +177,7 @@ func (f *Interface) initHostInfo(hostinfo *HostInfo) {
 	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
 }
 
-func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)
 	if err != nil {
@@ -186,7 +186,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
+	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).
@@ -196,7 +196,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 	}
 
-	f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, nil, p, nb, out, 0, nil)
+	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0, nil)
 }
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@@ -215,19 +215,18 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu
 		// the packet queue.
 		hostInfo.ConnectionState.queueLock.Lock()
 		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp, f.cachedPacketMetrics)
+			hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 			hostInfo.ConnectionState.queueLock.Unlock()
 			return
 		}
 		hostInfo.ConnectionState.queueLock.Unlock()
 	}
 
-	f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out)
-	return
+	f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out)
 }
 
-func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
-	f.send(t, st, hostInfo.ConnectionState, hostInfo, p, nb, out)
+func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) {
+	f.send(t, st, hi.ConnectionState, hi, p, nb, out)
 }
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
@@ -302,6 +301,7 @@ func (f *Interface) SendVia(via *HostInfo,
 	if err != nil {
 		via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
 	}
+	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
@@ -401,31 +401,19 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, err := f.hostMap.QueryVpnIp(relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
 			if err != nil {
+				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
 				continue
 			}
-			relay, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp)
-			if !ok {
-				hostinfo.logger(f.l).
-					WithField("relay", relayHostInfo.vpnIp).
-					WithField("relayTo", hostinfo.vpnIp).
-					Info("sendNoMetrics relay missing object for target")
-				continue
-			}
 			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
 			break
 		}
 	}
-	return
 }
 
 func isMulticast(ip iputil.VpnIp) bool {
 	// Class D multicast
-	if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
-		return true
-	}
-
-	return false
+	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
 }

+ 2 - 0
interface.go

@@ -111,6 +111,7 @@ type EncWriter interface {
 		nocopy bool,
 	)
 	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
 	Handshake(vpnIp iputil.VpnIp)
 }
 
@@ -216,6 +217,7 @@ func (f *Interface) activate() {
 
 	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
 		WithField("build", f.version).WithField("udpAddr", addr).
+		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
 
 	metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))

+ 162 - 33
lighthouse.go

@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"net/netip"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -33,6 +34,7 @@ type netIpAndPort struct {
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
+	ctx          context.Context
 	amLighthouse bool
 	myVpnIp      iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
@@ -82,7 +84,7 @@ type LightHouse struct {
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
@@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
 
 	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
+		ctx:          ctx,
 		amLighthouse: amLighthouse,
 		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
 		myVpnZeros:   iputil.VpnIp(32 - ones),
@@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
-	if initial || c.HasChanged("static_host_map") {
+	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
 		staticList := make(map[iputil.VpnIp]struct{})
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		if err != nil {
@@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 		lh.staticList.Store(&staticList)
 		if !initial {
 			//TODO: we should remove any remote list entries for static hosts that were removed/modified?
-			lh.l.Info("static_host_map has changed")
+			if c.HasChanged("static_host_map") {
+				lh.l.Info("static_host_map has changed")
+			}
+			if c.HasChanged("static_map.cadence") {
+				lh.l.Info("static_map.cadence has changed")
+			}
+			if c.HasChanged("static_map.network") {
+				lh.l.Info("static_map.network has changed")
+			}
+			if c.HasChanged("static_map.lookup_timeout") {
+				lh.l.Info("static_map.lookup_timeout has changed")
+			}
 		}
-
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
@@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
 	return nil
 }
 
+func getStaticMapCadence(c *config.C) (time.Duration, error) {
+	cadence := c.GetString("static_map.cadence", "30s")
+	d, err := time.ParseDuration(cadence)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
+	lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
+	d, err := time.ParseDuration(lookupTimeout)
+	if err != nil {
+		return 0, err
+	}
+	return d, nil
+}
+
+func getStaticMapNetwork(c *config.C) (string, error) {
+	network := c.GetString("static_map.network", "ip4")
+	if network != "ip" && network != "ip4" && network != "ip6" {
+		return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
+	}
+	return network, nil
+}
+
 func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+	d, err := getStaticMapCadence(c)
+	if err != nil {
+		return err
+	}
+
+	network, err := getStaticMapNetwork(c)
+	if err != nil {
+		return err
+	}
+
+	lookup_timeout, err := getStaticMapLookupTimeout(c)
+	if err != nil {
+		return err
+	}
+
 	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
 	i := 0
 
@@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 
 		vpnIp := iputil.Ip2VpnIp(rip)
 		vals, ok := v.([]interface{})
-		if ok {
-			for _, v := range vals {
-				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-				if err != nil {
-					return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-				}
-				lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
-			}
+		if !ok {
+			vals = []interface{}{v}
+		}
+		remoteAddrs := []string{}
+		for _, v := range vals {
+			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
+		}
 
-		} else {
-			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
-			if err != nil {
-				return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
-			}
-			lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
+		err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+		if err != nil {
+			return err
 		}
 		i++
 	}
@@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
 	defer am.Unlock()
+	ctx := lh.ctx
 	lh.Unlock()
 
-	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
-		to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV4(vpnIp, to) {
-			return
-		}
-		am.unlockedPrependV4(lh.myVpnIp, to)
+	hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
+		// This callback runs whenever the DNS hostname resolver finds a different set of IP's
+		// in its resolution for hostnames.
+		am.Lock()
+		defer am.Unlock()
+		am.shouldRebuild = true
+	})
+	if err != nil {
+		return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
+	}
+	am.unlockedSetHostnamesResults(hr)
 
-	} else {
-		to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
-		if !lh.unlockedShouldAddV6(vpnIp, to) {
-			return
+	for _, addrPort := range hr.GetIPs() {
+
+		switch {
+		case addrPort.Addr().Is4():
+			to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV4(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV4(lh.myVpnIp, to)
+		case addrPort.Addr().Is6():
+			to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
+			if !lh.unlockedShouldAddV6(vpnIp, to) {
+				continue
+			}
+			am.unlockedPrependV6(lh.myVpnIp, to)
 		}
-		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
 
 	// Mark it as static in the caller provided map
 	staticList[vpnIp] = struct{}{}
+	return nil
 }
 
 // addCalculatedRemotes adds any calculated remotes based on the
@@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 	am, ok := lh.addrMap[vpnIp]
 	if !ok {
-		am = NewRemoteList()
+		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
 		lh.addrMap[vpnIp] = am
 	}
 	return am
 }
 
+func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
+	switch {
+	case to.Is4():
+		ipBytes := to.As4()
+		ip := iputil.Ip2VpnIp(ipBytes[:])
+		allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+		if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
+			return false
+		}
+	case to.Is6():
+		ipBytes := to.As16()
+
+		hi := binary.BigEndian.Uint64(ipBytes[:8])
+		lo := binary.BigEndian.Uint64(ipBytes[8:])
+		allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
+		if lh.l.Level >= logrus.TraceLevel {
+			lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		}
+
+		// We don't check our vpn network here because nebula does not support ipv6 on the inside
+		if !allow {
+			return false
+		}
+	}
+	return true
+}
+
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
 	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
@@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 	return &ipp
 }
 
+func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
+	v4Addr := ip.As4()
+	return &Ip4AndPort{
+		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+		Port: uint32(port),
+	}
+}
+
 func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	return &Ip6AndPort{
 		Hi:   binary.BigEndian.Uint64(ip[:8]),
@@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	}
 }
 
+func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
+	ip6Addr := ip.As16()
+	return &Ip6AndPort{
+		Hi:   binary.BigEndian.Uint64(ip6Addr[:8]),
+		Lo:   binary.BigEndian.Uint64(ip6Addr[8:]),
+		Port: uint32(port),
+	}
+}
 func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 	ip := ipp.Ip
 	return udp.NewAddr(
@@ -793,11 +906,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
 		lhh.handleHostQueryReply(n, vpnIp)
 
 	case NebulaMeta_HostUpdateNotification:
-		lhh.handleHostUpdateNotification(n, vpnIp)
+		lhh.handleHostUpdateNotification(n, vpnIp, w)
 
 	case NebulaMeta_HostMovedNotification:
 	case NebulaMeta_HostPunchNotification:
 		lhh.handleHostPunchNotification(n, vpnIp, w)
+
+	case NebulaMeta_HostUpdateNotificationAck:
+		// noop
 	}
 }
 
@@ -906,7 +1022,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.V
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -932,6 +1048,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
 	am.Unlock()
+
+	n = lhh.resetMeta()
+	n.Type = NebulaMeta_HostUpdateNotificationAck
+	n.Details.VpnIp = uint32(vpnIp)
+	ln, err := n.MarshalTo(lhh.pb)
+
+	if err != nil {
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack")
+		return
+	}
+
+	lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
+	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {

+ 25 - 7
lighthouse_test.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"testing"
@@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
@@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
 
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	if !assert.NoError(b, err) {
 		b.Fatal()
 	}
 
 	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
 	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList()
+	lh.addrMap[3] = NewRemoteList(nil)
 	lh.addrMap[3].unlockedSetV4(
 		3,
 		3,
@@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
 	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList()
+	lh.addrMap[2] = NewRemoteList(nil)
 	lh.addrMap[2].unlockedSetV4(
 		3,
 		3,
@@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
@@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
@@ -377,6 +378,23 @@ func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte
 func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
 }
 
+func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
+	msg := &NebulaMeta{}
+	err := msg.Unmarshal(p)
+	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
+		tw.lastReply = testLhReply{
+			nebType:    t,
+			nebSubType: st,
+			vpnIp:      hostinfo.vpnIp,
+			msg:        msg,
+		}
+	}
+
+	if err != nil {
+		panic(err)
+	}
+}
+
 func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)

+ 1 - 1
main.go

@@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	*/
 
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
 	switch {
 	case errors.As(err, &util.ContextualError{}):
 		return nil, err

+ 1 - 0
message_metrics.go

@@ -84,6 +84,7 @@ func newLighthouseMetrics() *MessageMetrics {
 			NebulaMeta_HostQueryReply,
 			NebulaMeta_HostUpdateNotification,
 			NebulaMeta_HostPunchNotification,
+			NebulaMeta_HostUpdateNotificationAck,
 		}
 		for _, i := range used {
 			h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)}

+ 83 - 80
nebula.pb.go

@@ -25,42 +25,45 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
 type NebulaMeta_MessageType int32
 
 const (
-	NebulaMeta_None                   NebulaMeta_MessageType = 0
-	NebulaMeta_HostQuery              NebulaMeta_MessageType = 1
-	NebulaMeta_HostQueryReply         NebulaMeta_MessageType = 2
-	NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3
-	NebulaMeta_HostMovedNotification  NebulaMeta_MessageType = 4
-	NebulaMeta_HostPunchNotification  NebulaMeta_MessageType = 5
-	NebulaMeta_HostWhoami             NebulaMeta_MessageType = 6
-	NebulaMeta_HostWhoamiReply        NebulaMeta_MessageType = 7
-	NebulaMeta_PathCheck              NebulaMeta_MessageType = 8
-	NebulaMeta_PathCheckReply         NebulaMeta_MessageType = 9
+	NebulaMeta_None                      NebulaMeta_MessageType = 0
+	NebulaMeta_HostQuery                 NebulaMeta_MessageType = 1
+	NebulaMeta_HostQueryReply            NebulaMeta_MessageType = 2
+	NebulaMeta_HostUpdateNotification    NebulaMeta_MessageType = 3
+	NebulaMeta_HostMovedNotification     NebulaMeta_MessageType = 4
+	NebulaMeta_HostPunchNotification     NebulaMeta_MessageType = 5
+	NebulaMeta_HostWhoami                NebulaMeta_MessageType = 6
+	NebulaMeta_HostWhoamiReply           NebulaMeta_MessageType = 7
+	NebulaMeta_PathCheck                 NebulaMeta_MessageType = 8
+	NebulaMeta_PathCheckReply            NebulaMeta_MessageType = 9
+	NebulaMeta_HostUpdateNotificationAck NebulaMeta_MessageType = 10
 )
 
 var NebulaMeta_MessageType_name = map[int32]string{
-	0: "None",
-	1: "HostQuery",
-	2: "HostQueryReply",
-	3: "HostUpdateNotification",
-	4: "HostMovedNotification",
-	5: "HostPunchNotification",
-	6: "HostWhoami",
-	7: "HostWhoamiReply",
-	8: "PathCheck",
-	9: "PathCheckReply",
+	0:  "None",
+	1:  "HostQuery",
+	2:  "HostQueryReply",
+	3:  "HostUpdateNotification",
+	4:  "HostMovedNotification",
+	5:  "HostPunchNotification",
+	6:  "HostWhoami",
+	7:  "HostWhoamiReply",
+	8:  "PathCheck",
+	9:  "PathCheckReply",
+	10: "HostUpdateNotificationAck",
 }
 
 var NebulaMeta_MessageType_value = map[string]int32{
-	"None":                   0,
-	"HostQuery":              1,
-	"HostQueryReply":         2,
-	"HostUpdateNotification": 3,
-	"HostMovedNotification":  4,
-	"HostPunchNotification":  5,
-	"HostWhoami":             6,
-	"HostWhoamiReply":        7,
-	"PathCheck":              8,
-	"PathCheckReply":         9,
+	"None":                      0,
+	"HostQuery":                 1,
+	"HostQueryReply":            2,
+	"HostUpdateNotification":    3,
+	"HostMovedNotification":     4,
+	"HostPunchNotification":     5,
+	"HostWhoami":                6,
+	"HostWhoamiReply":           7,
+	"PathCheck":                 8,
+	"PathCheckReply":            9,
+	"HostUpdateNotificationAck": 10,
 }
 
 func (x NebulaMeta_MessageType) String() string {
@@ -722,56 +725,56 @@ func init() {
 func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) }
 
 var fileDescriptor_2d65afa7693df5ef = []byte{
-	// 775 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xd3, 0x4a,
-	0x14, 0x8e, 0x1d, 0xe7, 0xef, 0xa4, 0x49, 0x7d, 0xa7, 0xf7, 0xe6, 0xa6, 0xd5, 0x95, 0x95, 0xeb,
-	0xc5, 0x55, 0x56, 0x69, 0x95, 0xf6, 0x56, 0x2c, 0xa1, 0x41, 0x28, 0x91, 0x9a, 0x2a, 0x0c, 0x01,
-	0x24, 0x36, 0x68, 0x9a, 0x0c, 0x8d, 0x15, 0xc7, 0xe3, 0xda, 0x63, 0xd4, 0xbc, 0x05, 0xe2, 0x59,
-	0x58, 0xf2, 0x00, 0x2c, 0x90, 0xe8, 0x82, 0x05, 0x4b, 0xd4, 0xbe, 0x08, 0x9a, 0xf1, 0x6f, 0x7e,
-	0x80, 0xdd, 0x9c, 0x73, 0xbe, 0xef, 0xcc, 0x37, 0xe7, 0x7c, 0x71, 0x60, 0xc7, 0xa1, 0x97, 0x81,
-	0x4d, 0x3a, 0xae, 0xc7, 0x38, 0x43, 0xc5, 0x30, 0x32, 0x3f, 0xab, 0x00, 0x17, 0xf2, 0x38, 0xa4,
-	0x9c, 0xa0, 0x2e, 0x68, 0xe3, 0xa5, 0x4b, 0x9b, 0x4a, 0x4b, 0x69, 0xd7, 0xbb, 0x46, 0x27, 0xe2,
-	0xa4, 0x88, 0xce, 0x90, 0xfa, 0x3e, 0xb9, 0xa2, 0x02, 0x85, 0x25, 0x16, 0x1d, 0x43, 0xe9, 0x31,
-	0xe5, 0xc4, 0xb2, 0xfd, 0xa6, 0xda, 0x52, 0xda, 0xd5, 0xee, 0xfe, 0x26, 0x2d, 0x02, 0xe0, 0x18,
-	0x69, 0x7e, 0x55, 0xa0, 0x9a, 0x69, 0x85, 0xca, 0xa0, 0x5d, 0x30, 0x87, 0xea, 0x39, 0x54, 0x83,
-	0x4a, 0x9f, 0xf9, 0xfc, 0x69, 0x40, 0xbd, 0xa5, 0xae, 0x20, 0x04, 0xf5, 0x24, 0xc4, 0xd4, 0xb5,
-	0x97, 0xba, 0x8a, 0x0e, 0xa0, 0x21, 0x72, 0xcf, 0xdd, 0x29, 0xe1, 0xf4, 0x82, 0x71, 0xeb, 0x8d,
-	0x35, 0x21, 0xdc, 0x62, 0x8e, 0x9e, 0x47, 0xfb, 0xf0, 0x97, 0xa8, 0x0d, 0xd9, 0x5b, 0x3a, 0x5d,
-	0x29, 0x69, 0x71, 0x69, 0x14, 0x38, 0x93, 0xd9, 0x4a, 0xa9, 0x80, 0xea, 0x00, 0xa2, 0xf4, 0x72,
-	0xc6, 0xc8, 0xc2, 0xd2, 0x8b, 0x68, 0x0f, 0x76, 0xd3, 0x38, 0xbc, 0xb6, 0x24, 0x94, 0x8d, 0x08,
-	0x9f, 0xf5, 0x66, 0x74, 0x32, 0xd7, 0xcb, 0x42, 0x59, 0x12, 0x86, 0x90, 0x8a, 0xf9, 0x45, 0x81,
-	0x3f, 0x36, 0x5e, 0x8d, 0xfe, 0x84, 0xc2, 0x0b, 0xd7, 0x19, 0xb8, 0x72, 0xac, 0x35, 0x1c, 0x06,
-	0xe8, 0x04, 0xaa, 0x03, 0xf7, 0xe4, 0x91, 0x33, 0x1d, 0x31, 0x8f, 0x8b, 0xd9, 0xe5, 0xdb, 0xd5,
-	0x2e, 0x8a, 0x67, 0x97, 0x96, 0x70, 0x16, 0x16, 0xb2, 0x4e, 0x13, 0x96, 0xb6, 0xce, 0x3a, 0xcd,
-	0xb0, 0x12, 0x18, 0x32, 0x00, 0x30, 0xb5, 0xc9, 0x32, 0x94, 0x51, 0x68, 0xe5, 0xdb, 0x35, 0x9c,
-	0xc9, 0xa0, 0x26, 0x94, 0x26, 0x2c, 0x70, 0x38, 0xf5, 0x9a, 0x79, 0xa9, 0x31, 0x0e, 0xcd, 0x23,
-	0x80, 0xf4, 0x7a, 0x54, 0x07, 0x35, 0x79, 0x86, 0x3a, 0x70, 0x11, 0x02, 0x4d, 0xe4, 0xe5, 0xe2,
-	0x6b, 0x58, 0x9e, 0xcd, 0x87, 0x82, 0x71, 0x9a, 0x61, 0xf4, 0x2d, 0xc9, 0xd0, 0xb0, 0xda, 0xb7,
-	0x44, 0x7c, 0xce, 0x24, 0x5e, 0xc3, 0xea, 0x39, 0x4b, 0x3a, 0xe4, 0x33, 0x1d, 0x6e, 0x62, 0x4f,
-	0x8e, 0x2c, 0xe7, 0xea, 0xd7, 0x9e, 0x14, 0x88, 0x2d, 0x9e, 0x44, 0xa0, 0x8d, 0xad, 0x05, 0x8d,
-	0xee, 0x91, 0x67, 0xd3, 0xdc, 0x70, 0x9c, 0x20, 0xeb, 0x39, 0x54, 0x81, 0x42, 0xb8, 0x3f, 0xc5,
-	0x7c, 0x0d, 0xbb, 0x61, 0xdf, 0x3e, 0x71, 0xa6, 0xfe, 0x8c, 0xcc, 0x29, 0x7a, 0x90, 0xda, 0x5b,
-	0x91, 0xf6, 0x5e, 0x53, 0x90, 0x20, 0xd7, 0x3d, 0x2e, 0x44, 0xf4, 0x17, 0x64, 0x22, 0x45, 0xec,
-	0x60, 0x79, 0x36, 0xdf, 0x2b, 0xa0, 0x0f, 0x03, 0x9b, 0x5b, 0xe2, 0xa1, 0x31, 0xb0, 0x05, 0x55,
-	0x7c, 0xf3, 0x2c, 0x70, 0x5d, 0xe6, 0x71, 0x3a, 0x95, 0xd7, 0x94, 0x71, 0x36, 0x25, 0x10, 0xe3,
-	0x0c, 0x42, 0x0d, 0x11, 0x99, 0x14, 0x3a, 0x80, 0xf2, 0x19, 0xf1, 0x69, 0x66, 0x96, 0x49, 0x2c,
-	0xb6, 0x3f, 0x66, 0x9c, 0xd8, 0xb1, 0x65, 0x44, 0x35, 0x93, 0x31, 0x3f, 0xaa, 0xd0, 0xd8, 0xfe,
-	0x18, 0xf1, 0x86, 0x1e, 0xf5, 0xb8, 0xd4, 0xb4, 0x83, 0xe5, 0x19, 0xfd, 0x07, 0xf5, 0x81, 0x63,
-	0x71, 0x8b, 0x70, 0xe6, 0x0d, 0x9c, 0x29, 0xbd, 0x89, 0xd6, 0xbf, 0x96, 0x15, 0x38, 0x4c, 0x7d,
-	0x97, 0x39, 0x53, 0x1a, 0xe1, 0x42, 0x61, 0x6b, 0x59, 0xd4, 0x80, 0x62, 0x8f, 0xb1, 0xb9, 0x45,
-	0xa5, 0x34, 0x0d, 0x47, 0x51, 0xb2, 0xc4, 0x42, 0xba, 0x44, 0xd4, 0x07, 0x94, 0xdc, 0x92, 0xcc,
-	0xb1, 0x59, 0x94, 0x8b, 0x69, 0xc6, 0x8b, 0x59, 0x1f, 0x30, 0xde, 0xc2, 0x11, 0x9d, 0x12, 0x1d,
-	0x69, 0xa7, 0xd2, 0xef, 0x3a, 0x6d, 0x72, 0xcc, 0x0f, 0x2a, 0xd4, 0xc2, 0xf1, 0xf5, 0x98, 0xc3,
-	0x3d, 0x66, 0xa3, 0xff, 0x57, 0x2c, 0xfb, 0xef, 0xaa, 0x61, 0x22, 0xd0, 0x16, 0xd7, 0x1e, 0xc1,
-	0x5e, 0x22, 0x54, 0xfe, 0x38, 0xb3, 0xd3, 0xdd, 0x56, 0x12, 0x8c, 0x44, 0x50, 0x86, 0x11, 0xce,
-	0x79, 0x5b, 0x09, 0xfd, 0x03, 0x15, 0x19, 0x8d, 0xd9, 0xc0, 0x8d, 0xac, 0x90, 0x26, 0xa4, 0x13,
-	0x45, 0xf0, 0xc4, 0x63, 0x0b, 0xf9, 0xa1, 0x10, 0xf5, 0x6c, 0xca, 0xec, 0xff, 0xec, 0xbb, 0xdd,
-	0x00, 0xd4, 0xf3, 0x28, 0xe1, 0x54, 0xa2, 0x31, 0xbd, 0x0e, 0xa8, 0xcf, 0x75, 0x05, 0xfd, 0x0d,
-	0x7b, 0x2b, 0x79, 0x21, 0xc9, 0xa7, 0xba, 0x7a, 0x76, 0xfc, 0xe9, 0xce, 0x50, 0x6e, 0xef, 0x0c,
-	0xe5, 0xfb, 0x9d, 0xa1, 0xbc, 0xbb, 0x37, 0x72, 0xb7, 0xf7, 0x46, 0xee, 0xdb, 0xbd, 0x91, 0x7b,
-	0xb5, 0x7f, 0x65, 0xf1, 0x59, 0x70, 0xd9, 0x99, 0xb0, 0xc5, 0xa1, 0x6f, 0x93, 0xc9, 0x7c, 0x76,
-	0x7d, 0x18, 0x8e, 0xf0, 0xb2, 0x28, 0xff, 0xbe, 0x8e, 0x7f, 0x04, 0x00, 0x00, 0xff, 0xff, 0xdc,
-	0x87, 0xe2, 0x33, 0xce, 0x06, 0x00, 0x00,
+	// 784 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x8e, 0xe3, 0x44,
+	0x10, 0x8e, 0x1d, 0xe7, 0xaf, 0x32, 0xc9, 0x9a, 0x1a, 0x08, 0xc9, 0x0a, 0xac, 0xe0, 0x03, 0xca,
+	0x29, 0xbb, 0xca, 0x2c, 0x23, 0x8e, 0xec, 0x06, 0xa1, 0x44, 0xda, 0x8c, 0x42, 0x13, 0x40, 0xe2,
+	0x82, 0x7a, 0x9c, 0x66, 0x62, 0xc5, 0x71, 0x7b, 0xed, 0x36, 0x9a, 0xbc, 0x05, 0xe2, 0x59, 0x38,
+	0xf2, 0x00, 0xdc, 0xd8, 0x23, 0x47, 0x34, 0x73, 0xe4, 0xc8, 0x0b, 0xa0, 0x6e, 0xff, 0xe6, 0x07,
+	0xb8, 0x75, 0x55, 0x7d, 0x5f, 0xf5, 0xd7, 0x55, 0x5f, 0x1c, 0xb8, 0xf0, 0xd9, 0x6d, 0xec, 0xd1,
+	0x71, 0x10, 0x72, 0xc1, 0xb1, 0x9e, 0x44, 0xf6, 0x5f, 0x3a, 0xc0, 0x8d, 0x3a, 0x2e, 0x98, 0xa0,
+	0x38, 0x01, 0x63, 0xb5, 0x0f, 0x58, 0x5f, 0x1b, 0x6a, 0xa3, 0xee, 0xc4, 0x1a, 0xa7, 0x9c, 0x02,
+	0x31, 0x5e, 0xb0, 0x28, 0xa2, 0x77, 0x4c, 0xa2, 0x88, 0xc2, 0xe2, 0x15, 0x34, 0x3e, 0x67, 0x82,
+	0xba, 0x5e, 0xd4, 0xd7, 0x87, 0xda, 0xa8, 0x3d, 0x19, 0x9c, 0xd2, 0x52, 0x00, 0xc9, 0x90, 0xf6,
+	0xdf, 0x1a, 0xb4, 0x4b, 0xad, 0xb0, 0x09, 0xc6, 0x0d, 0xf7, 0x99, 0x59, 0xc1, 0x0e, 0xb4, 0x66,
+	0x3c, 0x12, 0x5f, 0xc6, 0x2c, 0xdc, 0x9b, 0x1a, 0x22, 0x74, 0xf3, 0x90, 0xb0, 0xc0, 0xdb, 0x9b,
+	0x3a, 0x3e, 0x85, 0x9e, 0xcc, 0x7d, 0x1d, 0xac, 0xa9, 0x60, 0x37, 0x5c, 0xb8, 0x3f, 0xb8, 0x0e,
+	0x15, 0x2e, 0xf7, 0xcd, 0x2a, 0x0e, 0xe0, 0x3d, 0x59, 0x5b, 0xf0, 0x1f, 0xd9, 0xfa, 0xa0, 0x64,
+	0x64, 0xa5, 0x65, 0xec, 0x3b, 0x9b, 0x83, 0x52, 0x0d, 0xbb, 0x00, 0xb2, 0xf4, 0xed, 0x86, 0xd3,
+	0x9d, 0x6b, 0xd6, 0xf1, 0x12, 0x9e, 0x14, 0x71, 0x72, 0x6d, 0x43, 0x2a, 0x5b, 0x52, 0xb1, 0x99,
+	0x6e, 0x98, 0xb3, 0x35, 0x9b, 0x52, 0x59, 0x1e, 0x26, 0x90, 0x16, 0x7e, 0x08, 0x83, 0xf3, 0xca,
+	0x5e, 0x3a, 0x5b, 0x13, 0xec, 0xdf, 0x35, 0x78, 0xe7, 0x64, 0x28, 0xf8, 0x2e, 0xd4, 0xbe, 0x09,
+	0xfc, 0x79, 0xa0, 0xa6, 0xde, 0x21, 0x49, 0x80, 0x2f, 0xa0, 0x3d, 0x0f, 0x5e, 0xbc, 0xf4, 0xd7,
+	0x4b, 0x1e, 0x0a, 0x39, 0xda, 0xea, 0xa8, 0x3d, 0xc1, 0x6c, 0xb4, 0x45, 0x89, 0x94, 0x61, 0x09,
+	0xeb, 0x3a, 0x67, 0x19, 0xc7, 0xac, 0xeb, 0x12, 0x2b, 0x87, 0xa1, 0x05, 0x40, 0x98, 0x47, 0xf7,
+	0x89, 0x8c, 0xda, 0xb0, 0x3a, 0xea, 0x90, 0x52, 0x06, 0xfb, 0xd0, 0x70, 0x78, 0xec, 0x0b, 0x16,
+	0xf6, 0xab, 0x4a, 0x63, 0x16, 0xda, 0xcf, 0x01, 0x8a, 0xeb, 0xb1, 0x0b, 0x7a, 0xfe, 0x0c, 0x7d,
+	0x1e, 0x20, 0x82, 0x21, 0xf3, 0xca, 0x17, 0x1d, 0xa2, 0xce, 0xf6, 0x67, 0x92, 0x71, 0x5d, 0x62,
+	0xcc, 0x5c, 0xc5, 0x30, 0x88, 0x3e, 0x73, 0x65, 0xfc, 0x9a, 0x2b, 0xbc, 0x41, 0xf4, 0xd7, 0x3c,
+	0xef, 0x50, 0x2d, 0x75, 0xb8, 0xcf, 0x2c, 0xbb, 0x74, 0xfd, 0xbb, 0xff, 0xb6, 0xac, 0x44, 0x9c,
+	0xb1, 0x2c, 0x82, 0xb1, 0x72, 0x77, 0x2c, 0xbd, 0x47, 0x9d, 0x6d, 0xfb, 0xc4, 0x90, 0x92, 0x6c,
+	0x56, 0xb0, 0x05, 0xb5, 0x64, 0xbd, 0x9a, 0xfd, 0x3d, 0x3c, 0x49, 0xfa, 0xce, 0xa8, 0xbf, 0x8e,
+	0x36, 0x74, 0xcb, 0xf0, 0xd3, 0xc2, 0xfd, 0x9a, 0x72, 0xff, 0x91, 0x82, 0x1c, 0x79, 0xfc, 0x13,
+	0x90, 0x22, 0x66, 0x3b, 0xea, 0x28, 0x11, 0x17, 0x44, 0x9d, 0xed, 0x9f, 0x35, 0x30, 0x17, 0xb1,
+	0x27, 0x5c, 0xf9, 0xd0, 0x0c, 0x38, 0x84, 0x36, 0xb9, 0xff, 0x2a, 0x0e, 0x02, 0x1e, 0x0a, 0xb6,
+	0x56, 0xd7, 0x34, 0x49, 0x39, 0x25, 0x11, 0xab, 0x12, 0x42, 0x4f, 0x10, 0xa5, 0x14, 0x3e, 0x85,
+	0xe6, 0x2b, 0x1a, 0xb1, 0xd2, 0x2c, 0xf3, 0x58, 0x6e, 0x7f, 0xc5, 0x05, 0xf5, 0x32, 0xcb, 0xc8,
+	0x6a, 0x29, 0x63, 0xff, 0xaa, 0x43, 0xef, 0xfc, 0x63, 0xe4, 0x1b, 0xa6, 0x2c, 0x14, 0x4a, 0xd3,
+	0x05, 0x51, 0x67, 0xfc, 0x18, 0xba, 0x73, 0xdf, 0x15, 0x2e, 0x15, 0x3c, 0x9c, 0xfb, 0x6b, 0x76,
+	0x9f, 0xae, 0xff, 0x28, 0x2b, 0x71, 0x84, 0x45, 0x01, 0xf7, 0xd7, 0x2c, 0xc5, 0x25, 0xc2, 0x8e,
+	0xb2, 0xd8, 0x83, 0xfa, 0x94, 0xf3, 0xad, 0xcb, 0x94, 0x34, 0x83, 0xa4, 0x51, 0xbe, 0xc4, 0x5a,
+	0xb1, 0x44, 0x9c, 0x01, 0xe6, 0xb7, 0xe4, 0x73, 0xec, 0xd7, 0xd5, 0x62, 0xfa, 0xd9, 0x62, 0x8e,
+	0x07, 0x4c, 0xce, 0x70, 0x64, 0xa7, 0x5c, 0x47, 0xd1, 0xa9, 0xf1, 0x7f, 0x9d, 0x4e, 0x39, 0xf6,
+	0x2f, 0x3a, 0x74, 0x92, 0xf1, 0x4d, 0xb9, 0x2f, 0x42, 0xee, 0xe1, 0x27, 0x07, 0x96, 0xfd, 0xe8,
+	0xd0, 0x30, 0x29, 0xe8, 0x8c, 0x6b, 0x9f, 0xc3, 0x65, 0x2e, 0x54, 0xfd, 0x38, 0xcb, 0xd3, 0x3d,
+	0x57, 0x92, 0x8c, 0x5c, 0x50, 0x89, 0x91, 0xcc, 0xf9, 0x5c, 0x09, 0x3f, 0x80, 0x96, 0x8a, 0x56,
+	0x7c, 0x1e, 0xa4, 0x56, 0x28, 0x12, 0xca, 0x89, 0x32, 0xf8, 0x22, 0xe4, 0x3b, 0xf5, 0xa1, 0x90,
+	0xf5, 0x72, 0xca, 0x9e, 0xfd, 0xdb, 0x67, 0xbd, 0x07, 0x38, 0x0d, 0x19, 0x15, 0x4c, 0xa1, 0x09,
+	0x7b, 0x13, 0xb3, 0x48, 0x98, 0x1a, 0xbe, 0x0f, 0x97, 0x07, 0x79, 0x29, 0x29, 0x62, 0xa6, 0xfe,
+	0xea, 0xea, 0xb7, 0x07, 0x4b, 0x7b, 0xfb, 0x60, 0x69, 0x7f, 0x3e, 0x58, 0xda, 0x4f, 0x8f, 0x56,
+	0xe5, 0xed, 0xa3, 0x55, 0xf9, 0xe3, 0xd1, 0xaa, 0x7c, 0x37, 0xb8, 0x73, 0xc5, 0x26, 0xbe, 0x1d,
+	0x3b, 0x7c, 0xf7, 0x2c, 0xf2, 0xa8, 0xb3, 0xdd, 0xbc, 0x79, 0x96, 0x8c, 0xf0, 0xb6, 0xae, 0xfe,
+	0xdd, 0xae, 0xfe, 0x09, 0x00, 0x00, 0xff, 0xff, 0xb2, 0xb5, 0xba, 0xcc, 0xed, 0x06, 0x00, 0x00,
 }
 
 func (m *NebulaMeta) Marshal() (dAtA []byte, err error) {

+ 1 - 0
nebula.proto

@@ -15,6 +15,7 @@ message NebulaMeta {
     HostWhoamiReply = 7;
     PathCheck = 8;
     PathCheckReply = 9;
+    HostUpdateNotificationAck = 10;
   }
 
   MessageType Type = 1;

+ 7 - 0
noiseutil/boring_test.go

@@ -4,14 +4,21 @@
 package noiseutil
 
 import (
+	"crypto/boring"
 	"encoding/hex"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 )
 
+func TestEncryptLockNeeded(t *testing.T) {
+	assert.True(t, EncryptLockNeeded)
+}
+
 // Ensure NewGCMTLS validates the nonce is non-repeating
 func TestNewGCMTLS(t *testing.T) {
+	assert.True(t, boring.Enabled())
+
 	// Test Case 16 from GCM Spec:
 	//  - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf
 	//  - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418

+ 68 - 0
noiseutil/nist.go

@@ -0,0 +1,68 @@
+package noiseutil
+
+import (
+	"crypto/ecdh"
+	"crypto/rand"
+	"fmt"
+	"io"
+
+	"github.com/flynn/noise"
+)
+
+// DHP256 is the NIST P-256 ECDH function
+var DHP256 noise.DHFunc = newNISTCurve("P256", ecdh.P256(), 32)
+
+type nistCurve struct {
+	name   string
+	curve  ecdh.Curve
+	dhLen  int
+	pubLen int
+}
+
+func newNISTCurve(name string, curve ecdh.Curve, byteLen int) nistCurve {
+	return nistCurve{
+		name:  name,
+		curve: curve,
+		dhLen: byteLen,
+		// Standard uncompressed format, type (1 byte) plus both coordinates
+		pubLen: 1 + 2*byteLen,
+	}
+}
+
+func (c nistCurve) GenerateKeypair(rng io.Reader) (noise.DHKey, error) {
+	if rng == nil {
+		rng = rand.Reader
+	}
+	privkey, err := c.curve.GenerateKey(rng)
+	if err != nil {
+		return noise.DHKey{}, err
+	}
+	pubkey := privkey.PublicKey()
+	return noise.DHKey{Private: privkey.Bytes(), Public: pubkey.Bytes()}, nil
+}
+
+func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) {
+	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+	}
+	ecdhPrivKey, err := c.curve.NewPrivateKey(privkey)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+	}
+
+	return ecdhPrivKey.ECDH(ecdhPubKey)
+}
+
+func (c nistCurve) DHLen() int {
+	// NOTE: Noise Protocol specifies "DHLen" to represent two things:
+	// - The size of the public key
+	// - The return size of the DH() function
+	// But for standard NIST ECDH, the sizes of these are different.
+	// Luckily, the flynn/noise library actually only uses this DHLen()
+	// value to represent the public key size, so that is what we are
+	// returning here. The length of the DH() return bytes are unaffected by
+	// this value here.
+	return c.pubLen
+}
+func (c nistCurve) DHName() string { return c.name }

+ 7 - 8
noiseutil/notboring_test.go

@@ -4,12 +4,11 @@
 package noiseutil
 
 import (
-	// NOTE: We have to force these imports here or boring_test.go fails to
-	// compile correctly. This seems to be a Go bug:
-	//
-	//     $ GOEXPERIMENT=boringcrypto go test ./noiseutil
-	//     # github.com/slackhq/nebula/noiseutil
-	//     boring_test.go:10:2: cannot find package
-
-	_ "github.com/stretchr/testify/assert"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
+
+func TestEncryptLockNeeded(t *testing.T) {
+	assert.False(t, EncryptLockNeeded)
+}

+ 6 - 0
notboring.go

@@ -0,0 +1,6 @@
+//go:build !boringcrypto
+// +build !boringcrypto
+
+package nebula
+
+var boringEnabled = func() bool { return false }

+ 12 - 13
outside.go

@@ -83,7 +83,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 		switch h.Subtype {
 		case header.MessageNone:
-			f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
+			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
+				return
+			}
 		case header.MessageRelay:
 			// The entire body is sent as AD, not encrypted.
 			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
@@ -100,7 +102,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, addr)
+			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.RelayUsed(h.RemoteIndex)
 
 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
 			if !ok {
@@ -118,17 +122,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
-				targetHI, err := f.hostMap.QueryVpnIp(relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
 				if err != nil {
 					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
 					return
 				}
-				// find the target Relay info object
-				targetRelay, ok := targetHI.relayState.QueryRelayForByIp(hostinfo.vpnIp)
-				if !ok {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp}).Info("Failed to find relay in hostinfo")
-					return
-				}
 
 				// If that relay is Established, forward the payload through it
 				if targetRelay.State == Established {
@@ -395,7 +393,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
 	var err error
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -403,20 +401,20 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
 		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
-		return
+		return false
 	}
 
 	err = newPacket(out, true, fwPacket)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
-		return
+		return false
 	}
 
 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			Debugln("dropping out of window packet")
-		return
+		return false
 	}
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
@@ -427,7 +425,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 				WithField("reason", dropReason).
 				Debugln("dropping inbound packet")
 		}
-		return
+		return false
 	}
 
 	f.connectionManager.In(hostinfo.localIndexId)
@@ -435,6 +433,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
+	return true
 }
 
 func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {

+ 2 - 0
overlay/tun.go

@@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
+			c.GetBool("tun.use_system_route_table", false),
 		)
 
 	default:
@@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 			routes,
 			c.GetInt("tun.tx_queue", 500),
 			routines > 1,
+			c.GetBool("tun.use_system_route_table", false),
 		)
 	}
 }

+ 2 - 2
overlay/tun_android.go

@@ -22,7 +22,7 @@ type tun struct {
 	l         *logrus.Logger
 }
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 		return nil, err
@@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
 	}, nil
 }
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 

+ 2 - 2
overlay/tun_darwin.go

@@ -77,7 +77,7 @@ type ifreqMTU struct {
 	pad  [8]byte
 }
 
-func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 		return nil, err
@@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 

+ 2 - 2
overlay/tun_freebsd.go

@@ -38,11 +38,11 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 		return nil, err

+ 2 - 2
overlay/tun_ios.go

@@ -23,11 +23,11 @@ type tun struct {
 	routeTree *cidr.Tree4
 }
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 		return nil, err

+ 109 - 15
overlay/tun_linux.go

@@ -4,11 +4,13 @@
 package overlay
 
 import (
+	"bytes"
 	"fmt"
 	"io"
 	"net"
 	"os"
 	"strings"
+	"sync/atomic"
 	"unsafe"
 
 	"github.com/sirupsen/logrus"
@@ -26,9 +28,13 @@ type tun struct {
 	MaxMTU     int
 	DefaultMTU int
 	TXQueueLen int
-	Routes     []Route
-	routeTree  *cidr.Tree4
-	l          *logrus.Logger
+
+	Routes          []Route
+	routeTree       atomic.Pointer[cidr.Tree4]
+	routeChan       chan struct{}
+	useSystemRoutes bool
+
+	l *logrus.Logger
 }
 
 type ifReq struct {
@@ -63,7 +69,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
 	routeTree, err := makeRouteTree(l, routes, true)
 	if err != nil {
 		return nil, err
@@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
 
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
-	return &tun{
+	t := &tun{
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		Device:          "tun0",
@@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
-		routeTree:       routeTree,
+		useSystemRoutes: useSystemRoutes,
 		l:               l,
-	}, nil
+	}
+	t.routeTree.Store(routeTree)
+	return t, nil
 }
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
@@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		return nil, err
 	}
 
-	return &tun{
+	t := &tun{
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		Device:          name,
@@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
-		routeTree:       routeTree,
+		useSystemRoutes: useSystemRoutes,
 		l:               l,
-	}, nil
+	}
+	t.routeTree.Store(routeTree)
+	return t, nil
 }
 
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
@@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
+	r := t.routeTree.Load().MostSpecificContains(ip)
 	if r != nil {
 		return r.(iputil.VpnIp)
 	}
@@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) {
 	}
 }
 
-func (t tun) deviceBytes() (o [16]byte) {
+func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 		o[i] = byte(c)
 	}
 	return
 }
 
-func (t tun) Activate() error {
+func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 
+	if t.useSystemRoutes {
+		t.watchRoutes()
+	}
+
 	var addr, mask [4]byte
 
 	copy(addr[:], t.cidr.IP.To4())
@@ -318,7 +332,7 @@ func (t *tun) Name() string {
 	return t.Device
 }
 
-func (t tun) advMSS(r Route) int {
+func (t *tun) advMSS(r Route) int {
 	mtu := r.MTU
 	if r.MTU == 0 {
 		mtu = t.DefaultMTU
@@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int {
 	}
 	return 0
 }
+
+func (t *tun) watchRoutes() {
+	rch := make(chan netlink.RouteUpdate)
+	doneChan := make(chan struct{})
+
+	if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
+		t.l.WithError(err).Errorf("failed to subscribe to system route changes")
+		return
+	}
+
+	t.routeChan = doneChan
+
+	go func() {
+		for {
+			select {
+			case r := <-rch:
+				t.updateRoutes(r)
+			case <-doneChan:
+				// netlink.RouteSubscriber will close the rch for us
+				return
+			}
+		}
+	}()
+}
+
+func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+	if r.Gw == nil {
+		// Not a gateway route, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
+		return
+	}
+
+	if !t.cidr.Contains(r.Gw) {
+		// Gateway isn't in our overlay network, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+		return
+	}
+
+	if x := r.Dst.IP.To4(); x == nil {
+		// Nebula only handles ipv4 on the overlay currently
+		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
+		return
+	}
+
+	newTree := cidr.NewTree4()
+	if r.Type == unix.RTM_NEWROUTE {
+		for _, oldR := range t.routeTree.Load().List() {
+			newTree.AddCIDR(oldR.CIDR, oldR.Value)
+		}
+
+		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
+		newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+
+	} else {
+		gw := iputil.Ip2VpnIp(r.Gw)
+		for _, oldR := range t.routeTree.Load().List() {
+			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
+				// This is the record to delete
+				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
+				continue
+			}
+
+			newTree.AddCIDR(oldR.CIDR, oldR.Value)
+		}
+	}
+
+	t.routeTree.Store(newTree)
+}
+
+func (t *tun) Close() error {
+	if t.routeChan != nil {
+		close(t.routeChan)
+	}
+
+	if t.ReadWriteCloser != nil {
+		t.ReadWriteCloser.Close()
+	}
+
+	return nil
+}

+ 7 - 7
overlay/tun_linux_test.go

@@ -7,19 +7,19 @@ import "testing"
 
 var runAdvMSSTests = []struct {
 	name     string
-	tun      tun
+	tun      *tun
 	r        Route
 	expected int
 }{
 	// Standard case, default MTU is the device max MTU
-	{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
-	{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
-	{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
+	{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
+	{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
+	{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
 
 	// Case where we have a route MTU set higher than the default
-	{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
-	{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
-	{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
+	{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
+	{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
+	{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
 }
 
 func TestTunAdvMSS(t *testing.T) {

+ 2 - 2
overlay/tun_tester.go

@@ -25,7 +25,7 @@ type TestTun struct {
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
 	routeTree, err := makeRouteTree(l, routes, false)
 	if err != nil {
 		return nil, err
@@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
 	}, nil
 }
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 

+ 2 - 2
overlay/tun_windows.go

@@ -14,11 +14,11 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
 	useWintun := true
 	if err := checkWinTunExists(); err != nil {
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

+ 1 - 1
punchy.go

@@ -75,7 +75,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 	}
 
 	if initial || c.HasChanged("punchy.target_all_remotes") {
-		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true))
+		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
 		if !initial {
 			p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
 		}

+ 27 - 25
relay_manager.go

@@ -141,27 +141,29 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
 		return
 	}
-	peerRelay.State = Established
-	resp := NebulaControl{
-		Type:                NebulaControl_CreateRelayResponse,
-		ResponderRelayIndex: peerRelay.LocalIndex,
-		InitiatorRelayIndex: peerRelay.RemoteIndex,
-		RelayFromIp:         uint32(peerHostInfo.vpnIp),
-		RelayToIp:           uint32(target),
-	}
-	msg, err := resp.Marshal()
-	if err != nil {
-		rm.l.
-			WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
-	} else {
-		f.SendMessageToVpnIp(header.Control, 0, peerHostInfo.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
-		rm.l.WithFields(logrus.Fields{
-			"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-			"relayTo":             iputil.VpnIp(resp.RelayToIp),
-			"initiatorRelayIndex": resp.InitiatorRelayIndex,
-			"responderRelayIndex": resp.ResponderRelayIndex,
-			"vpnIp":               peerHostInfo.vpnIp}).
-			Info("send CreateRelayResponse")
+	if peerRelay.State == PeerRequested {
+		peerRelay.State = Established
+		resp := NebulaControl{
+			Type:                NebulaControl_CreateRelayResponse,
+			ResponderRelayIndex: peerRelay.LocalIndex,
+			InitiatorRelayIndex: peerRelay.RemoteIndex,
+			RelayFromIp:         uint32(peerHostInfo.vpnIp),
+			RelayToIp:           uint32(target),
+		}
+		msg, err := resp.Marshal()
+		if err != nil {
+			rm.l.
+				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
+		} else {
+			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+			rm.l.WithFields(logrus.Fields{
+				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
+				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				"initiatorRelayIndex": resp.InitiatorRelayIndex,
+				"responderRelayIndex": resp.ResponderRelayIndex,
+				"vpnIp":               peerHostInfo.vpnIp}).
+				Info("send CreateRelayResponse")
+		}
 	}
 }
 
@@ -223,7 +225,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			logMsg.
 				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 		} else {
-			f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
+			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
 				"relayTo":             iputil.VpnIp(resp.RelayToIp),
@@ -278,7 +280,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				logMsg.
 					WithError(err).Error("relayManager Failed to marshal Control message to create relay")
 			} else {
-				f.SendMessageToVpnIp(header.Control, 0, target, msg, make([]byte, 12), make([]byte, mtu))
+				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				rm.l.WithFields(logrus.Fields{
 					"relayFrom":           iputil.VpnIp(req.RelayFromIp),
 					"relayTo":             iputil.VpnIp(req.RelayToIp),
@@ -292,7 +294,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		relay, ok := h.relayState.QueryRelayForByIp(target)
 		if !ok {
 			// Add the relay
-			state := Requested
+			state := PeerRequested
 			if targetRelay != nil && targetRelay.State == Established {
 				state = Established
 			}
@@ -324,7 +326,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 					rm.l.
 						WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 				} else {
-					f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
+					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					rm.l.WithFields(logrus.Fields{
 						"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
 						"relayTo":             iputil.VpnIp(resp.RelayToIp),

+ 166 - 4
remote_list.go

@@ -2,10 +2,16 @@ package nebula
 
 import (
 	"bytes"
+	"context"
 	"net"
+	"net/netip"
 	"sort"
+	"strconv"
 	"sync"
+	"sync/atomic"
+	"time"
 
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 )
@@ -55,6 +61,132 @@ type cacheV6 struct {
 	reported []*Ip6AndPort
 }
 
+type hostnamePort struct {
+	name string
+	port uint16
+}
+
+type hostnamesResults struct {
+	hostnames     []hostnamePort
+	network       string
+	lookupTimeout time.Duration
+	stop          chan struct{}
+	l             *logrus.Logger
+	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
+}
+
+func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
+	r := &hostnamesResults{
+		hostnames:     make([]hostnamePort, len(hostPorts)),
+		network:       network,
+		lookupTimeout: timeout,
+		stop:          make(chan (struct{})),
+		l:             l,
+	}
+
+	// Fastrack IP addresses to ensure they're immediately available for use.
+	// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
+	performBackgroundLookup := false
+	ips := map[netip.AddrPort]struct{}{}
+	for idx, hostPort := range hostPorts {
+
+		rIp, sPort, err := net.SplitHostPort(hostPort)
+		if err != nil {
+			return nil, err
+		}
+
+		iPort, err := strconv.Atoi(sPort)
+		if err != nil {
+			return nil, err
+		}
+
+		r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
+		addr, err := netip.ParseAddr(rIp)
+		if err != nil {
+			// This address is a hostname, not an IP address
+			performBackgroundLookup = true
+			continue
+		}
+
+		// Save the IP address immediately
+		ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
+	}
+	r.ips.Store(&ips)
+
+	// Time for the DNS lookup goroutine
+	if performBackgroundLookup {
+		ticker := time.NewTicker(d)
+		go func() {
+			defer ticker.Stop()
+			for {
+				netipAddrs := map[netip.AddrPort]struct{}{}
+				for _, hostPort := range r.hostnames {
+					timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
+					addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
+					timeoutCancel()
+					if err != nil {
+						l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
+						continue
+					}
+					for _, a := range addrs {
+						netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+					}
+				}
+				origSet := r.ips.Load()
+				different := false
+				for a := range *origSet {
+					if _, ok := netipAddrs[a]; !ok {
+						different = true
+						break
+					}
+				}
+				if !different {
+					for a := range netipAddrs {
+						if _, ok := (*origSet)[a]; !ok {
+							different = true
+							break
+						}
+					}
+				}
+				if different {
+					l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
+					r.ips.Store(&netipAddrs)
+					onUpdate()
+				}
+				select {
+				case <-ctx.Done():
+					return
+				case <-r.stop:
+					return
+				case <-ticker.C:
+					continue
+				}
+			}
+		}()
+	}
+
+	return r, nil
+}
+
+func (hr *hostnamesResults) Cancel() {
+	if hr != nil {
+		hr.stop <- struct{}{}
+	}
+}
+
+func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
+	var retSlice []netip.AddrPort
+	if hr != nil {
+		p := hr.ips.Load()
+		if p != nil {
+			for k := range *p {
+				retSlice = append(retSlice, k)
+			}
+		}
+	}
+	return retSlice
+}
+
 // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
 // It serves as a local cache of query replies, host update notifications, and locally learned addresses
 type RemoteList struct {
@@ -72,6 +204,9 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[iputil.VpnIp]*cache
 
+	hr        *hostnamesResults
+	shouldAdd func(netip.Addr) bool
+
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
 	badRemotes []*udp.Addr
@@ -81,14 +216,21 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList() *RemoteList {
+func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 	return &RemoteList{
-		addrs:  make([]*udp.Addr, 0),
-		relays: make([]*iputil.VpnIp, 0),
-		cache:  make(map[iputil.VpnIp]*cache),
+		addrs:     make([]*udp.Addr, 0),
+		relays:    make([]*iputil.VpnIp, 0),
+		cache:     make(map[iputil.VpnIp]*cache),
+		shouldAdd: shouldAdd,
 	}
 }
 
+func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
+	// Cancel any existing hostnamesResults DNS goroutine to release resources
+	r.hr.Cancel()
+	r.hr = hr
+}
+
 // Len locks and reports the size of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
@@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
 		}
 	}
 
+	dnsAddrs := r.hr.GetIPs()
+	for _, addr := range dnsAddrs {
+		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+			switch {
+			case addr.Addr().Is4():
+				v4 := addr.Addr().As4()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v4[:],
+					Port: addr.Port(),
+				})
+			case addr.Addr().Is6():
+				v6 := addr.Addr().As16()
+				addrs = append(addrs, &udp.Addr{
+					IP:   v6[:],
+					Port: addr.Port(),
+				})
+			}
+		}
+	}
+
 	r.addrs = addrs
 	r.relays = relays
 

+ 3 - 3
remote_list_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestRemoteList_Rebuild(t *testing.T) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,
@@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 }
 
 func BenchmarkFullRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,
@@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
 }
 
 func BenchmarkSortRebuild(b *testing.B) {
-	rl := NewRemoteList()
+	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 		0,
 		0,

+ 4 - 2
stats.go

@@ -7,6 +7,7 @@ import (
 	"net"
 	"net/http"
 	"runtime"
+	"strconv"
 	"time"
 
 	graphite "github.com/cyberdelia/go-metrics-graphite"
@@ -105,8 +106,9 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildV
 		Name:      "info",
 		Help:      "Version information for the Nebula binary",
 		ConstLabels: prometheus.Labels{
-			"version":   buildVersion,
-			"goversion": runtime.Version(),
+			"version":      buildVersion,
+			"goversion":    runtime.Version(),
+			"boringcrypto": strconv.FormatBool(boringEnabled()),
 		},
 	})
 	pr.MustRegister(g)