Pārlūkot izejas kodu

Merge remote-tracking branch 'origin/master' into multiport

Wade Simmons 1 gadu atpakaļ
vecāks
revīzija
f2aef0d6eb
74 mainītis faili ar 2543 papildinājumiem un 1405 dzēšanām
  1. 22 0
      .github/dependabot.yml
  2. 4 13
      .github/workflows/gofmt.yml
  3. 35 216
      .github/workflows/release.yml
  4. 4 13
      .github/workflows/smoke.yml
  5. 1 1
      .github/workflows/smoke/build-relay.sh
  6. 1 1
      .github/workflows/smoke/build.sh
  7. 25 25
      .github/workflows/smoke/smoke-relay.sh
  8. 46 46
      .github/workflows/smoke/smoke.sh
  9. 12 39
      .github/workflows/test.yml
  10. 16 1
      CHANGELOG.md
  11. 23 3
      Makefile
  12. 1 1
      README.md
  13. 0 163
      cert.go
  14. 3 0
      cert/cert.go
  15. 3 0
      cert/crypto.go
  16. 2 7
      cmd/nebula-service/main.go
  17. 3 7
      cmd/nebula/main.go
  18. 42 0
      cmd/nebula/notify_linux.go
  19. 10 0
      cmd/nebula/notify_notlinux.go
  20. 11 1
      config/config.go
  21. 1 1
      config/config_test.go
  22. 8 21
      connection_manager.go
  23. 44 36
      connection_manager_test.go
  24. 11 11
      connection_state.go
  25. 44 39
      control.go
  26. 5 5
      control_test.go
  27. 11 21
      control_tester.go
  28. 2 0
      dist/arch/nebula.service
  29. 2 0
      dist/fedora/nebula.service
  30. 2 2
      dns_server.go
  31. 106 0
      e2e/handshakes_test.go
  32. 1 1
      e2e/helpers_test.go
  33. 16 2
      examples/config.yml
  34. 2 0
      examples/service_scripts/nebula.service
  35. 14 13
      go.mod
  36. 29 31
      go.sum
  37. 1 1
      handshake.go
  38. 40 47
      handshake_ix.go
  39. 190 73
      handshake_manager.go
  40. 10 21
      handshake_manager_test.go
  41. 67 147
      hostmap.go
  42. 14 14
      hostmap_test.go
  43. 23 86
      inside.go
  44. 54 53
      interface.go
  45. 36 25
      lighthouse.go
  46. 41 2
      lighthouse_test.go
  47. 36 47
      main.go
  48. 8 11
      outside.go
  49. 1 9
      overlay/tun_darwin.go
  50. 118 20
      overlay/tun_freebsd.go
  51. 0 8
      overlay/tun_linux.go
  52. 162 0
      overlay/tun_netbsd.go
  53. 14 0
      overlay/tun_notwin.go
  54. 174 0
      overlay/tun_openbsd.go
  55. 14 1
      overlay/tun_tester.go
  56. 9 2
      overlay/tun_wintun_windows.go
  57. 248 0
      pki.go
  58. 13 6
      relay_manager.go
  59. 11 21
      remote_list.go
  60. 33 33
      ssh.go
  61. 31 0
      udp/conn.go
  62. 6 1
      udp/udp_android.go
  63. 47 0
      udp/udp_bsd.go
  64. 13 3
      udp/udp_darwin.go
  65. 14 12
      udp/udp_generic.go
  66. 25 20
      udp/udp_linux.go
  67. 1 1
      udp/udp_linux_32.go
  68. 1 1
      udp/udp_linux_64.go
  69. 6 1
      udp/udp_netbsd.go
  70. 403 0
      udp/udp_rio_windows.go
  71. 31 12
      udp/udp_tester.go
  72. 20 3
      udp/udp_windows.go
  73. 24 4
      util/error.go
  74. 42 0
      util/error_test.go

+ 22 - 0
.github/dependabot.yml

@@ -0,0 +1,22 @@
+version: 2
+updates:
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+
+  - package-ecosystem: "gomod"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+    groups:
+      golang-x-dependencies:
+        patterns:
+          - "golang.org/x/*"
+      zx2c4-dependencies:
+        patterns:
+          - "golang.zx2c4.com/*"
+      protobuf-dependencies:
+        patterns:
+          - "github.com/golang/protobuf"
+          - "google.golang.org/protobuf"

+ 4 - 13
.github/workflows/gofmt.yml

@@ -14,21 +14,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.20
-      uses: actions/setup-go@v2
-      with:
-        go-version: "1.20"
-      id: go
-
-    - name: Check out code into the Go module directory
-      uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
-    - uses: actions/cache@v2
+    - uses: actions/setup-go@v4
       with:
-        path: ~/go/pkg/mod
-        key: ${{ runner.os }}-gofmt1.20-${{ hashFiles('**/go.sum') }}
-        restore-keys: |
-          ${{ runner.os }}-gofmt1.20-
+        go-version-file: 'go.mod'
+        check-latest: true
 
     - name: Install goimports
       run: |

+ 35 - 216
.github/workflows/release.yml

@@ -7,25 +7,24 @@ name: Create release and upload binaries
 
 jobs:
   build-linux:
-    name: Build Linux All
+    name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - name: Set up Go 1.20
-        uses: actions/setup-go@v2
-        with:
-          go-version: "1.20"
+      - uses: actions/checkout@v4
 
-      - name: Checkout code
-        uses: actions/checkout@v2
+      - uses: actions/setup-go@v4
+        with:
+          go-version-file: 'go.mod'
+          check-latest: true
 
       - name: Build
         run: |
-          make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd
+          make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd release-openbsd release-netbsd
           mkdir release
           mv build/*.tar.gz release
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           name: linux-latest
           path: release
@@ -34,13 +33,12 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - name: Set up Go 1.20
-        uses: actions/setup-go@v2
-        with:
-          go-version: "1.20"
+      - uses: actions/checkout@v4
 
-      - name: Checkout code
-        uses: actions/checkout@v2
+      - uses: actions/setup-go@v4
+        with:
+          go-version-file: 'go.mod'
+          check-latest: true
 
       - name: Build
         run: |
@@ -57,7 +55,7 @@ jobs:
           mv dist\windows\wintun build\dist\windows\
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           name: windows-latest
           path: build
@@ -68,17 +66,16 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-11
     steps:
-      - name: Set up Go 1.20
-        uses: actions/setup-go@v2
-        with:
-          go-version: "1.20"
+      - uses: actions/checkout@v4
 
-      - name: Checkout code
-        uses: actions/checkout@v2
+      - uses: actions/setup-go@v4
+        with:
+          go-version-file: 'go.mod'
+          check-latest: true
 
       - name: Import certificates
         if: env.HAS_SIGNING_CREDS == 'true'
-        uses: Apple-Actions/import-codesign-certs@v1
+        uses: Apple-Actions/import-codesign-certs@v2
         with:
           p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
           p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -107,7 +104,7 @@ jobs:
           fi
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           name: darwin-latest
           path: ./release/*
@@ -117,12 +114,16 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
+      - uses: actions/checkout@v4
+
       - name: Download artifacts
-        uses: actions/download-artifact@v2
+        uses: actions/download-artifact@v3
+        with:
+          path: artifacts
 
       - name: Zip Windows
         run: |
-          cd windows-latest
+          cd artifacts/windows-latest
           cp windows-amd64/* .
           zip -r nebula-windows-amd64.zip nebula.exe nebula-cert.exe dist
           cp windows-arm64/* .
@@ -130,6 +131,7 @@ jobs:
 
       - name: Create sha256sum
         run: |
+          cd artifacts
           for dir in linux-latest darwin-latest windows-latest
           do
             (
@@ -159,195 +161,12 @@ jobs:
 
       - name: Create Release
         id: create_release
-        uses: actions/create-release@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          tag_name: ${{ github.ref }}
-          release_name: Release ${{ github.ref }}
-          draft: false
-          prerelease: false
-
-      ##
-      ## Upload assets (I wish we could just upload the whole folder at once...
-      ##
-
-      - name: Upload SHASUM256.txt
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./SHASUM256.txt
-          asset_name: SHASUM256.txt
-          asset_content_type: text/plain
-
-      - name: Upload darwin zip
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./darwin-latest/nebula-darwin.zip
-          asset_name: nebula-darwin.zip
-          asset_content_type: application/zip
-
-      - name: Upload windows-amd64
-        uses: actions/[email protected]
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./windows-latest/nebula-windows-amd64.zip
-          asset_name: nebula-windows-amd64.zip
-          asset_content_type: application/zip
-
-      - name: Upload windows-arm64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./windows-latest/nebula-windows-arm64.zip
-          asset_name: nebula-windows-arm64.zip
-          asset_content_type: application/zip
-
-      - name: Upload linux-amd64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-amd64.tar.gz
-          asset_name: nebula-linux-amd64.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-386
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-386.tar.gz
-          asset_name: nebula-linux-386.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-ppc64le
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-ppc64le.tar.gz
-          asset_name: nebula-linux-ppc64le.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-arm-5
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-arm-5.tar.gz
-          asset_name: nebula-linux-arm-5.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-arm-6
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-arm-6.tar.gz
-          asset_name: nebula-linux-arm-6.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-arm-7
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-arm-7.tar.gz
-          asset_name: nebula-linux-arm-7.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-arm64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-arm64.tar.gz
-          asset_name: nebula-linux-arm64.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-mips
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-mips.tar.gz
-          asset_name: nebula-linux-mips.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-mipsle
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-mipsle.tar.gz
-          asset_name: nebula-linux-mipsle.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-mips64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-mips64.tar.gz
-          asset_name: nebula-linux-mips64.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-mips64le
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-mips64le.tar.gz
-          asset_name: nebula-linux-mips64le.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-mips-softfloat
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-mips-softfloat.tar.gz
-          asset_name: nebula-linux-mips-softfloat.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload linux-riscv64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-linux-riscv64.tar.gz
-          asset_name: nebula-linux-riscv64.tar.gz
-          asset_content_type: application/gzip
-
-      - name: Upload freebsd-amd64
-        uses: actions/[email protected]
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.create_release.outputs.upload_url }}
-          asset_path: ./linux-latest/nebula-freebsd-amd64.tar.gz
-          asset_name: nebula-freebsd-amd64.tar.gz
-          asset_content_type: application/gzip
+        run: |
+          cd artifacts
+          gh release create \
+            --verify-tag \
+            --title "Release ${{ github.ref_name }}" \
+            "${{ github.ref_name }}" \
+            SHASUM256.txt *-latest/*.zip *-latest/*.tar.gz

+ 4 - 13
.github/workflows/smoke.yml

@@ -18,21 +18,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.20
-      uses: actions/setup-go@v2
-      with:
-        go-version: "1.20"
-      id: go
-
-    - name: Check out code into the Go module directory
-      uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
-    - uses: actions/cache@v2
+    - uses: actions/setup-go@v4
       with:
-        path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
-        restore-keys: |
-          ${{ runner.os }}-go1.20-
+        go-version-file: 'go.mod'
+        check-latest: true
 
     - name: build
       run: make bin-docker

+ 1 - 1
.github/workflows/smoke/build-relay.sh

@@ -41,4 +41,4 @@ EOF
     ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
 )
 
-sudo docker build -t nebula:smoke-relay .
+docker build -t nebula:smoke-relay .

+ 1 - 1
.github/workflows/smoke/build.sh

@@ -36,4 +36,4 @@ mkdir ./build
     ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
 )
 
-sudo docker build -t "nebula:${NAME:-smoke}" .
+docker build -t "nebula:${NAME:-smoke}" .

+ 25 - 25
.github/workflows/smoke/smoke-relay.sh

@@ -14,24 +14,24 @@ cleanup() {
     set +e
     if [ "$(jobs -r)" ]
     then
-        sudo docker kill lighthouse1 host2 host3 host4
+        docker kill lighthouse1 host2 host3 host4
     fi
 }
 
 trap cleanup EXIT
 
-sudo docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
-sudo docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
-sudo docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
-sudo docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
+docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
+docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
+docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
+docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
 
-sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
+docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
-sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
+docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
+docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 1
-sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
+docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
 sleep 1
 
 set +x
@@ -39,43 +39,43 @@ echo
 echo " *** Testing ping from lighthouse1"
 echo
 set -x
-sudo docker exec lighthouse1 ping -c1 192.168.100.2
-sudo docker exec lighthouse1 ping -c1 192.168.100.3
-sudo docker exec lighthouse1 ping -c1 192.168.100.4
+docker exec lighthouse1 ping -c1 192.168.100.2
+docker exec lighthouse1 ping -c1 192.168.100.3
+docker exec lighthouse1 ping -c1 192.168.100.4
 
 set +x
 echo
 echo " *** Testing ping from host2"
 echo
 set -x
-sudo docker exec host2 ping -c1 192.168.100.1
+docker exec host2 ping -c1 192.168.100.1
 # Should fail because no relay configured in this direction
-! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
-! sudo docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
 
 set +x
 echo
 echo " *** Testing ping from host3"
 echo
 set -x
-sudo docker exec host3 ping -c1 192.168.100.1
-sudo docker exec host3 ping -c1 192.168.100.2
-sudo docker exec host3 ping -c1 192.168.100.4
+docker exec host3 ping -c1 192.168.100.1
+docker exec host3 ping -c1 192.168.100.2
+docker exec host3 ping -c1 192.168.100.4
 
 set +x
 echo
 echo " *** Testing ping from host4"
 echo
 set -x
-sudo docker exec host4 ping -c1 192.168.100.1
+docker exec host4 ping -c1 192.168.100.1
 # Should fail because relays not allowed
-! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
-sudo docker exec host4 ping -c1 192.168.100.3
+! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
+docker exec host4 ping -c1 192.168.100.3
 
-sudo docker exec host4 sh -c 'kill 1'
-sudo docker exec host3 sh -c 'kill 1'
-sudo docker exec host2 sh -c 'kill 1'
-sudo docker exec lighthouse1 sh -c 'kill 1'
+docker exec host4 sh -c 'kill 1'
+docker exec host3 sh -c 'kill 1'
+docker exec host2 sh -c 'kill 1'
+docker exec lighthouse1 sh -c 'kill 1'
 sleep 1
 
 if [ "$(jobs -r)" ]

+ 46 - 46
.github/workflows/smoke/smoke.sh

@@ -14,7 +14,7 @@ cleanup() {
     set +e
     if [ "$(jobs -r)" ]
     then
-        sudo docker kill lighthouse1 host2 host3 host4
+        docker kill lighthouse1 host2 host3 host4
     fi
 }
 
@@ -22,51 +22,51 @@ trap cleanup EXIT
 
 CONTAINER="nebula:${NAME:-smoke}"
 
-sudo docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
-sudo docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
-sudo docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
-sudo docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
+docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
+docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
+docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
+docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
 
-sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
+docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
-sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
+docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
+docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 1
-sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
+docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
 sleep 1
 
 # grab tcpdump pcaps for debugging
-sudo docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
-sudo docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
-sudo docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
-sudo docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
-sudo docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
-sudo docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
-sudo docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
-sudo docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
-
-sudo docker exec host2 ncat -nklv 0.0.0.0 2000 &
-sudo docker exec host3 ncat -nklv 0.0.0.0 2000 &
-sudo docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
-sudo docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
+docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
+docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
+docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
+docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
+docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
+docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
+docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
+docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
+
+docker exec host2 ncat -nklv 0.0.0.0 2000 &
+docker exec host3 ncat -nklv 0.0.0.0 2000 &
+docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
+docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
 
 set +x
 echo
 echo " *** Testing ping from lighthouse1"
 echo
 set -x
-sudo docker exec lighthouse1 ping -c1 192.168.100.2
-sudo docker exec lighthouse1 ping -c1 192.168.100.3
+docker exec lighthouse1 ping -c1 192.168.100.2
+docker exec lighthouse1 ping -c1 192.168.100.3
 
 set +x
 echo
 echo " *** Testing ping from host2"
 echo
 set -x
-sudo docker exec host2 ping -c1 192.168.100.1
+docker exec host2 ping -c1 192.168.100.1
 # Should fail because not allowed by host3 inbound firewall
-! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
 set +x
 echo
@@ -74,34 +74,34 @@ echo " *** Testing ncat from host2"
 echo
 set -x
 # Should fail because not allowed by host3 inbound firewall
-! sudo docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
-! sudo docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
+! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
+! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
 
 set +x
 echo
 echo " *** Testing ping from host3"
 echo
 set -x
-sudo docker exec host3 ping -c1 192.168.100.1
-sudo docker exec host3 ping -c1 192.168.100.2
+docker exec host3 ping -c1 192.168.100.1
+docker exec host3 ping -c1 192.168.100.2
 
 set +x
 echo
 echo " *** Testing ncat from host3"
 echo
 set -x
-sudo docker exec host3 ncat -nzv -w5 192.168.100.2 2000
-sudo docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
+docker exec host3 ncat -nzv -w5 192.168.100.2 2000
+docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
 
 set +x
 echo
 echo " *** Testing ping from host4"
 echo
 set -x
-sudo docker exec host4 ping -c1 192.168.100.1
+docker exec host4 ping -c1 192.168.100.1
 # Should fail because not allowed by host4 outbound firewall
-! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
-! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
+! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
 
 set +x
 echo
@@ -109,10 +109,10 @@ echo " *** Testing ncat from host4"
 echo
 set -x
 # Should fail because not allowed by host4 outbound firewall
-! sudo docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
-! sudo docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
-! sudo docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
-! sudo docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
+! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
+! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
+! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
+! docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
 
 set +x
 echo
@@ -120,15 +120,15 @@ echo " *** Testing conntrack"
 echo
 set -x
 # host2 can ping host3 now that host3 pinged it first
-sudo docker exec host2 ping -c1 192.168.100.3
+docker exec host2 ping -c1 192.168.100.3
 # host4 can ping host2 once conntrack established
-sudo docker exec host2 ping -c1 192.168.100.4
-sudo docker exec host4 ping -c1 192.168.100.2
+docker exec host2 ping -c1 192.168.100.4
+docker exec host4 ping -c1 192.168.100.2
 
-sudo docker exec host4 sh -c 'kill 1'
-sudo docker exec host3 sh -c 'kill 1'
-sudo docker exec host2 sh -c 'kill 1'
-sudo docker exec lighthouse1 sh -c 'kill 1'
+docker exec host4 sh -c 'kill 1'
+docker exec host3 sh -c 'kill 1'
+docker exec host2 sh -c 'kill 1'
+docker exec lighthouse1 sh -c 'kill 1'
 sleep 1
 
 if [ "$(jobs -r)" ]

+ 12 - 39
.github/workflows/test.yml

@@ -18,21 +18,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.20
-      uses: actions/setup-go@v2
-      with:
-        go-version: "1.20"
-      id: go
-
-    - name: Check out code into the Go module directory
-      uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
-    - uses: actions/cache@v2
+    - uses: actions/setup-go@v4
       with:
-        path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
-        restore-keys: |
-          ${{ runner.os }}-go1.20-
+        go-version-file: 'go.mod'
+        check-latest: true
 
     - name: Build
       run: make all
@@ -57,21 +48,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - name: Set up Go 1.20
-      uses: actions/setup-go@v2
-      with:
-        go-version: "1.20"
-      id: go
-
-    - name: Check out code into the Go module directory
-      uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
-    - uses: actions/cache@v2
+    - uses: actions/setup-go@v4
       with:
-        path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
-        restore-keys: |
-          ${{ runner.os }}-go1.20-
+        go-version-file: 'go.mod'
+        check-latest: true
 
     - name: Build
       run: make bin-boringcrypto
@@ -90,21 +72,12 @@ jobs:
         os: [windows-latest, macos-11]
     steps:
 
-    - name: Set up Go 1.20
-      uses: actions/setup-go@v2
-      with:
-        go-version: "1.20"
-      id: go
-
-    - name: Check out code into the Go module directory
-      uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
-    - uses: actions/cache@v2
+    - uses: actions/setup-go@v4
       with:
-        path: ~/go/pkg/mod
-        key: ${{ runner.os }}-go1.20-${{ hashFiles('**/go.sum') }}
-        restore-keys: |
-          ${{ runner.os }}-go1.20-
+        go-version-file: 'go.mod'
+        check-latest: true
 
     - name: Build nebula
       run: go build ./cmd/nebula

+ 16 - 1
CHANGELOG.md

@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.7.2] - 2023-06-01
+
+### Fixed
+
+- Fix a freeze during config reload if the `static_host_map` config was changed. (#886)
+
+## [1.7.1] - 2023-05-18
+
+### Fixed
+
+- Fix IPv4 addresses returned by `static_host_map` DNS lookup queries being
+  treated as IPv6 addresses. (#877)
+
 ## [1.7.0] - 2023-05-17
 
 ### Added
@@ -475,7 +488,9 @@ created.)
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.0...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD
+[1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2
+[1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1
 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0
 [1.6.1]: https://github.com/slackhq/nebula/releases/tag/v1.6.1
 [1.6.0]: https://github.com/slackhq/nebula/releases/tag/v1.6.0

+ 23 - 3
Makefile

@@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT)
 	GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
 	NEBULA_CMD_SUFFIX = .exe
 	NULL_FILE = nul
+	# RIO on windows does pointer stuff that makes go vet angry
+	VET_FLAGS = -unsafeptr=false
 else
 	GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
 	GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
@@ -44,10 +46,21 @@ ALL_LINUX = linux-amd64 \
 	linux-mips-softfloat \
 	linux-riscv64
 
+ALL_FREEBSD = freebsd-amd64 \
+	freebsd-arm64
+
+ALL_OPENBSD = openbsd-amd64 \
+	openbsd-arm64
+
+ALL_NETBSD = netbsd-amd64 \
+ 	netbsd-arm64
+
 ALL = $(ALL_LINUX) \
+	$(ALL_FREEBSD) \
+	$(ALL_OPENBSD) \
+	$(ALL_NETBSD) \
 	darwin-amd64 \
 	darwin-arm64 \
-	freebsd-amd64 \
 	windows-amd64 \
 	windows-arm64
 
@@ -75,7 +88,11 @@ release: $(ALL:%=build/nebula-%.tar.gz)
 
 release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
 
-release-freebsd: build/nebula-freebsd-amd64.tar.gz
+release-freebsd: $(ALL_FREEBSD:%=build/nebula-%.tar.gz)
+
+release-openbsd: $(ALL_OPENBSD:%=build/nebula-%.tar.gz)
+
+release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
 
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 
@@ -93,6 +110,9 @@ bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
 bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
 	mv $? .
 
+bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
+	mv $? .
+
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 	mv $? .
 
@@ -137,7 +157,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
 	cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
 
 vet:
-	go vet -v ./...
+	go vet $(VET_FLAGS) -v ./...
 
 test:
 	go test -v ./...

+ 1 - 1
README.md

@@ -108,7 +108,7 @@ For each host, copy the nebula binary to the host, along with `config.yml` from
 
 ## Building Nebula from source
 
-Download go and clone this repo. Change to the nebula directory.
+Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory.
 
 To build nebula for all platforms:
 `make all`

+ 0 - 163
cert.go

@@ -1,163 +0,0 @@
-package nebula
-
-import (
-	"errors"
-	"fmt"
-	"io/ioutil"
-	"strings"
-	"time"
-
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/config"
-)
-
-type CertState struct {
-	certificate         *cert.NebulaCertificate
-	rawCertificate      []byte
-	rawCertificateNoKey []byte
-	publicKey           []byte
-	privateKey          []byte
-}
-
-func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
-	// Marshal the certificate to ensure it is valid
-	rawCertificate, err := certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
-	}
-
-	publicKey := certificate.Details.PublicKey
-	cs := &CertState{
-		rawCertificate: rawCertificate,
-		certificate:    certificate, // PublicKey has been set to nil above
-		privateKey:     privateKey,
-		publicKey:      publicKey,
-	}
-
-	cs.certificate.Details.PublicKey = nil
-	rawCertNoKey, err := cs.certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
-	}
-	cs.rawCertificateNoKey = rawCertNoKey
-	// put public key back
-	cs.certificate.Details.PublicKey = cs.publicKey
-	return cs, nil
-}
-
-func NewCertStateFromConfig(c *config.C) (*CertState, error) {
-	var pemPrivateKey []byte
-	var err error
-
-	privPathOrPEM := c.GetString("pki.key", "")
-
-	if privPathOrPEM == "" {
-		return nil, errors.New("no pki.key path or PEM data provided")
-	}
-
-	if strings.Contains(privPathOrPEM, "-----BEGIN") {
-		pemPrivateKey = []byte(privPathOrPEM)
-		privPathOrPEM = "<inline>"
-	} else {
-		pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
-		}
-	}
-
-	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
-	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
-	}
-
-	var rawCert []byte
-
-	pubPathOrPEM := c.GetString("pki.cert", "")
-
-	if pubPathOrPEM == "" {
-		return nil, errors.New("no pki.cert path or PEM data provided")
-	}
-
-	if strings.Contains(pubPathOrPEM, "-----BEGIN") {
-		rawCert = []byte(pubPathOrPEM)
-		pubPathOrPEM = "<inline>"
-	} else {
-		rawCert, err = ioutil.ReadFile(pubPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
-		}
-	}
-
-	nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
-	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
-	}
-
-	if nebulaCert.Expired(time.Now()) {
-		return nil, fmt.Errorf("nebula certificate for this host is expired")
-	}
-
-	if len(nebulaCert.Details.Ips) == 0 {
-		return nil, fmt.Errorf("no IPs encoded in certificate")
-	}
-
-	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
-		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
-	}
-
-	return NewCertState(nebulaCert, rawKey)
-}
-
-func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
-	var rawCA []byte
-	var err error
-
-	caPathOrPEM := c.GetString("pki.ca", "")
-	if caPathOrPEM == "" {
-		return nil, errors.New("no pki.ca path or PEM data provided")
-	}
-
-	if strings.Contains(caPathOrPEM, "-----BEGIN") {
-		rawCA = []byte(caPathOrPEM)
-
-	} else {
-		rawCA, err = ioutil.ReadFile(caPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
-		}
-	}
-
-	CAs, err := cert.NewCAPoolFromBytes(rawCA)
-	if errors.Is(err, cert.ErrExpired) {
-		var expired int
-		for _, cert := range CAs.CAs {
-			if cert.Expired(time.Now()) {
-				expired++
-				l.WithField("cert", cert).Warn("expired certificate present in CA pool")
-			}
-		}
-
-		if expired >= len(CAs.CAs) {
-			return nil, errors.New("no valid CA certificates present")
-		}
-
-	} else if err != nil {
-		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
-	}
-
-	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
-		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		CAs.BlocklistFingerprint(fp)
-	}
-
-	// Support deprecated config for at least one minor release to allow for migrations
-	//TODO: remove in 2022 or later
-	for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
-		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
-		CAs.BlocklistFingerprint(fp)
-	}
-
-	return CAs, nil
-}

+ 3 - 0
cert/cert.go

@@ -272,6 +272,9 @@ func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte
 		},
 		Ciphertext: ciphertext,
 	})
+	if err != nil {
+		return nil, err
+	}
 
 	switch curve {
 	case Curve_CURVE25519:

+ 3 - 0
cert/crypto.go

@@ -77,6 +77,9 @@ func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte)
 	}
 
 	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
 
 	nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize())
 	if err != nil {

+ 2 - 7
cmd/nebula-service/main.go

@@ -59,13 +59,8 @@ func main() {
 	}
 
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
-
-	switch v := err.(type) {
-	case util.ContextualError:
-		v.Log(l)
-		os.Exit(1)
-	case error:
-		l.WithError(err).Error("Failed to start")
+	if err != nil {
+		util.LogWithContextIfNeeded("Failed to start", err, l)
 		os.Exit(1)
 	}
 

+ 3 - 7
cmd/nebula/main.go

@@ -53,18 +53,14 @@ func main() {
 	}
 
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
-
-	switch v := err.(type) {
-	case util.ContextualError:
-		v.Log(l)
-		os.Exit(1)
-	case error:
-		l.WithError(err).Error("Failed to start")
+	if err != nil {
+		util.LogWithContextIfNeeded("Failed to start", err, l)
 		os.Exit(1)
 	}
 
 	if !*configTest {
 		ctrl.Start()
+		notifyReady(l)
 		ctrl.ShutdownBlock()
 	}
 

+ 42 - 0
cmd/nebula/notify_linux.go

@@ -0,0 +1,42 @@
+package main
+
+import (
+	"net"
+	"os"
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+// SdNotifyReady tells systemd the service is ready and dependent services can now be started
+// https://www.freedesktop.org/software/systemd/man/sd_notify.html
+// https://www.freedesktop.org/software/systemd/man/systemd.service.html
+const SdNotifyReady = "READY=1"
+
+func notifyReady(l *logrus.Logger) {
+	sockName := os.Getenv("NOTIFY_SOCKET")
+	if sockName == "" {
+		l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal")
+		return
+	}
+
+	conn, err := net.DialTimeout("unixgram", sockName, time.Second)
+	if err != nil {
+		l.WithError(err).Error("failed to connect to systemd notification socket")
+		return
+	}
+	defer conn.Close()
+
+	err = conn.SetWriteDeadline(time.Now().Add(time.Second))
+	if err != nil {
+		l.WithError(err).Error("failed to set the write deadline for the systemd notification socket")
+		return
+	}
+
+	if _, err = conn.Write([]byte(SdNotifyReady)); err != nil {
+		l.WithError(err).Error("failed to signal the systemd notification socket")
+		return
+	}
+
+	l.Debugln("notified systemd the service is ready")
+}

+ 10 - 0
cmd/nebula/notify_notlinux.go

@@ -0,0 +1,10 @@
+//go:build !linux
+// +build !linux
+
+package main
+
+import "github.com/sirupsen/logrus"
+
+func notifyReady(_ *logrus.Logger) {
+	// No init service to notify
+}

+ 11 - 1
config/config.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io/ioutil"
+	"math"
 	"os"
 	"os/signal"
 	"path/filepath"
@@ -15,7 +16,7 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/imdario/mergo"
+	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
 	"gopkg.in/yaml.v2"
 )
@@ -236,6 +237,15 @@ func (c *C) GetInt(k string, d int) int {
 	return v
 }
 
+// GetUint32 will get the uint32 for k or return the default d if not found or invalid
+func (c *C) GetUint32(k string, d uint32) uint32 {
+	r := c.GetInt(k, int(d))
+	if uint64(r) > uint64(math.MaxUint32) {
+		return d
+	}
+	return uint32(r)
+}
+
 // GetBool will get the bool for k or return the default d if not found or invalid
 func (c *C) GetBool(k string, d bool) bool {
 	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))

+ 1 - 1
config/config_test.go

@@ -7,7 +7,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/imdario/mergo"
+	"dario.cat/mergo"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"

+ 8 - 21
connection_manager.go

@@ -231,7 +231,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = newhostinfo.vpnIp
+				relayFrom = n.intf.myVpnIp
 				relayTo = existing.PeerIp
 			case ForwardingType:
 				relayFrom = existing.PeerIp
@@ -256,7 +256,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = newhostinfo.vpnIp
+				relayFrom = n.intf.myVpnIp
 				relayTo = r.PeerIp
 			case ForwardingType:
 				relayFrom = r.PeerIp
@@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 		return false
 	}
 
-	certState := n.intf.certState.Load()
-	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
+	certState := n.intf.pki.GetCertState()
+	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
+	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
 	if valid {
 		return false
 	}
@@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.certState.Load()
-	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
+	certState := n.intf.pki.GetCertState()
+	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
 		return
 	}
 
@@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
-	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
-	if !newHostinfo.HandshakeReady {
-		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
-	}
-
-	//If this is a static host, we don't need to wait for the HostQueryReply
-	//We can trigger the handshake right now
-	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
-		select {
-		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
-		default:
-		}
-	}
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
 }

+ 44 - 36
connection_manager_test.go

@@ -42,25 +42,26 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
-		outside:          &udp.Conn{},
+		outside:          &udp.NoopConn{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
+		pki:              &PKI{},
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -78,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
@@ -121,25 +122,26 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
-		outside:          &udp.Conn{},
+		outside:          &udp.NoopConn{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
+		pki:              &PKI{},
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -157,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
@@ -207,7 +209,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
@@ -220,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			PublicKey: pubCA,
 		},
 	}
-	caCert.Sign(cert.Curve_CURVE25519, privCA)
+
+	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
 	ncp := &cert.NebulaCAPool{
 		CAs: cert.NewCAPool().CAs,
 	}
@@ -239,28 +242,29 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			Issuer:    "ca",
 		},
 	}
-	peerCert.Sign(cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
 
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
 	ifce := &Interface{
 		hostMap:           hostMap,
 		inside:            &test.NoopTun{},
-		outside:           &udp.Conn{},
+		outside:           &udp.NoopConn{},
 		firewall:          &Firewall{},
 		lightHouse:        lh,
-		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
+		handshakeManager:  NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                 l,
 		disconnectInvalid: true,
-		caPool:            ncp,
+		pki:               &PKI{},
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
+	ifce.pki.caPool.Store(ncp)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -268,12 +272,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	punchy := NewPunchyFromConfig(l, config.NewC(l))
 	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	ifce.connectionManager = nc
-	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
-	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		peerCert:  &peerCert,
-		H:         &noise.HandshakeState{},
+
+	hostinfo := &HostInfo{
+		vpnIp: vpnIp,
+		ConnectionState: &ConnectionState{
+			myCert:   &cert.NebulaCertificate{},
+			peerCert: &peerCert,
+			H:        &noise.HandshakeState{},
+		},
 	}
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// Move ahead 45s.
 	// Check if to disconnect with invalid certificate.

+ 11 - 11
connection_state.go

@@ -18,35 +18,35 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
-	certState      *CertState
+	myCert         *cert.NebulaCertificate
 	peerCert       *cert.NebulaCertificate
 	initiator      bool
 	messageCounter atomic.Uint64
 	window         *Bits
-	queueLock      sync.Mutex
 	writeLock      sync.Mutex
 	ready          bool
 }
 
-func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
 	var dhFunc noise.DHFunc
-	curCertState := f.certState.Load()
-
-	switch curCertState.certificate.Details.Curve {
+	switch certState.Certificate.Details.Curve {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 		dhFunc = noiseutil.DHP256
 	default:
-		l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
+		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
 		return nil
 	}
-	cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
-	if f.cipher == "chachapoly" {
+
+	var cs noise.CipherSuite
+	if cipher == "chachapoly" {
 		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	} else {
+		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
+	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
 
 	b := NewBits(ReplayWindow)
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@@ -72,7 +72,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
 		initiator: initiator,
 		window:    b,
 		ready:     false,
-		certState: curCertState,
+		myCert:    certState.Certificate,
 	}
 
 	return ci

+ 44 - 39
control.go

@@ -17,13 +17,23 @@ import (
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
 
+type controlEach func(h *HostInfo)
+
+type controlHostLister interface {
+	QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+	ForEachIndex(each controlEach)
+	ForEachVpnIp(each controlEach)
+	GetPreferredRanges() []*net.IPNet
+}
+
 type Control struct {
-	f          *Interface
-	l          *logrus.Logger
-	cancel     context.CancelFunc
-	sshStart   func()
-	statsStart func()
-	dnsStart   func()
+	f               *Interface
+	l               *logrus.Logger
+	cancel          context.CancelFunc
+	sshStart        func()
+	statsStart      func()
+	dnsStart        func()
+	lighthouseStart func()
 }
 
 type ControlHostInfo struct {
@@ -54,12 +64,15 @@ func (c *Control) Start() {
 	if c.dnsStart != nil {
 		go c.dnsStart()
 	}
+	if c.lighthouseStart != nil {
+		c.lighthouseStart()
+	}
 
 	// Start reading packets.
 	c.f.run()
 }
 
-// Stop signals nebula to shutdown, returns after the shutdown is complete
+// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
 func (c *Control) Stop() {
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
 	// being created while we're shutting them all down.
@@ -89,7 +102,7 @@ func (c *Control) RebindUDPServer() {
 	_ = c.f.outside.Rebind()
 
 	// Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0
-	c.f.lightHouse.SendUpdate(c.f)
+	c.f.lightHouse.SendUpdate()
 
 	// Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes
 	c.f.rebindCount++
@@ -98,7 +111,7 @@ func (c *Control) RebindUDPServer() {
 // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
 func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
-		return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
+		return listHostMapHosts(c.f.handshakeManager)
 	} else {
 		return listHostMapHosts(c.f.hostMap)
 	}
@@ -107,7 +120,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
 func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
-		return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
+		return listHostMapIndexes(c.f.handshakeManager)
 	} else {
 		return listHostMapIndexes(c.f.hostMap)
 	}
@@ -115,15 +128,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
 func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
-	var hm *HostMap
+	var hl controlHostLister
 	if pending {
-		hm = c.f.handshakeManager.pendingHostMap
+		hl = c.f.handshakeManager
 	} else {
-		hm = c.f.hostMap
+		hl = c.f.hostMap
 	}
 
-	h, err := hm.QueryVpnIp(vpnIp)
-	if err != nil {
+	h := hl.QueryVpnIp(vpnIp)
+	if h == nil {
 		return nil
 	}
 
@@ -133,8 +146,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
-	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return nil
 	}
 
@@ -145,8 +158,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
-	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return false
 	}
 
@@ -241,28 +254,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	return chi
 }
 
-func listHostMapHosts(hm *HostMap) []ControlHostInfo {
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Hosts))
-	i := 0
-	for _, v := range hm.Hosts {
-		hosts[i] = copyHostInfo(v, hm.preferredRanges)
-		i++
-	}
-	hm.RUnlock()
-
+func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
+	hosts := make([]ControlHostInfo, 0)
+	pr := hl.GetPreferredRanges()
+	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+		hosts = append(hosts, copyHostInfo(hostinfo, pr))
+	})
 	return hosts
 }
 
-func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Indexes))
-	i := 0
-	for _, v := range hm.Indexes {
-		hosts[i] = copyHostInfo(v, hm.preferredRanges)
-		i++
-	}
-	hm.RUnlock()
-
+func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
+	hosts := make([]ControlHostInfo, 0)
+	pr := hl.GetPreferredRanges()
+	hl.ForEachIndex(func(hostinfo *HostInfo) {
+		hosts = append(hosts, copyHostInfo(hostinfo, pr))
+	})
 	return hosts
 }

+ 5 - 5
control_test.go

@@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
+	hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
 	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
 	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
@@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	remotes := NewRemoteList(nil)
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
-	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
+	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 			relayForByIp:  map[iputil.VpnIp]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
-	})
+	}, &Interface{})
 
-	hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
+	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 			relayForByIp:  map[iputil.VpnIp]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
-	})
+	}, &Interface{})
 
 	c := Control{
 		f: &Interface{

+ 11 - 21
control_tester.go

@@ -21,7 +21,7 @@ import (
 func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
 	h := &header.H{}
 	for {
-		p := c.f.outside.Get(true)
+		p := c.f.outside.(*udp.TesterConn).Get(true)
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 		}
@@ -37,7 +37,7 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message
 func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
 	h := &header.H{}
 	for {
-		p := c.f.outside.Get(true)
+		p := c.f.outside.(*udp.TesterConn).Get(true)
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 		}
@@ -90,11 +90,11 @@ func (c *Control) GetFromTun(block bool) []byte {
 
 // GetFromUDP will pull a udp packet off the udp side of nebula
 func (c *Control) GetFromUDP(block bool) *udp.Packet {
-	return c.f.outside.Get(block)
+	return c.f.outside.(*udp.TesterConn).Get(block)
 }
 
 func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
-	return c.f.outside.TxPackets
+	return c.f.outside.(*udp.TesterConn).TxPackets
 }
 
 func (c *Control) GetTunTxChan() <-chan []byte {
@@ -103,7 +103,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
 
 // InjectUDPPacket will inject a packet into the udp side of nebula
 func (c *Control) InjectUDPPacket(p *udp.Packet) {
-	c.f.outside.Send(p)
+	c.f.outside.(*udp.TesterConn).Send(p)
 }
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
@@ -143,16 +143,16 @@ func (c *Control) GetVpnIp() iputil.VpnIp {
 }
 
 func (c *Control) GetUDPAddr() string {
-	return c.f.outside.Addr.String()
+	return c.f.outside.(*udp.TesterConn).Addr.String()
 }
 
 func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
-	if !ok {
+	hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+	if hostinfo == nil {
 		return false
 	}
 
-	c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+	c.f.handshakeManager.DeleteHostInfo(hostinfo)
 	return true
 }
 
@@ -161,19 +161,9 @@ func (c *Control) GetHostmap() *HostMap {
 }
 
 func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.certState.Load().certificate
+	return c.f.pki.GetCertState().Certificate
 }
 
 func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
-	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
-	ixHandshakeStage0(c.f, vpnIp, hostinfo)
-
-	// If this is a static host, we don't need to wait for the HostQueryReply
-	// We can trigger the handshake right now
-	if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
-		select {
-		case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
-		default:
-		}
-	}
+	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 }

+ 2 - 0
dist/arch/nebula.service

@@ -4,6 +4,8 @@ Wants=basic.target network-online.target nss-lookup.target time-sync.target
 After=basic.target network.target network-online.target
 
 [Service]
+Type=notify
+NotifyAccess=main
 SyslogIdentifier=nebula
 ExecReload=/bin/kill -HUP $MAINPID
 ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml

+ 2 - 0
dist/fedora/nebula.service

@@ -5,6 +5,8 @@ After=basic.target network.target network-online.target
 Before=sshd.service
 
 [Service]
+Type=notify
+NotifyAccess=main
 SyslogIdentifier=nebula
 ExecReload=/bin/kill -HUP $MAINPID
 ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml

+ 2 - 2
dns_server.go

@@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 	}
 	iip := iputil.Ip2VpnIp(ip)
-	hostinfo, err := d.hostMap.QueryVpnIp(iip)
-	if err != nil {
+	hostinfo := d.hostMap.QueryVpnIp(iip)
+	if hostinfo == nil {
 		return ""
 	}
 	q := hostinfo.GetCert()

+ 106 - 0
e2e/handshakes_test.go

@@ -410,6 +410,8 @@ func TestStage1RaceRelays(t *testing.T) {
 	p := r.RouteForAllUntilTxTun(myControl)
 	_ = p
 
+	r.FlushAll()
+
 	myControl.Stop()
 	theirControl.Stop()
 	relayControl.Stop()
@@ -608,6 +610,110 @@ func TestRehandshakingRelays(t *testing.T) {
 	t.Logf("relayControl hostinfos got cleaned up!")
 }
 
+func TestRehandshakingRelaysPrimary(t *testing.T) {
+	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+
+	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
+	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
+	r.Log("Renew relay certificate and spin until me and them sees it")
+	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	relayConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(relayConfig.Settings)
+	assert.NoError(t, err)
+	relayConfig.ReloadConfigString(string(rc))
+
+	for {
+		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between my and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	for {
+		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between their and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	r.Log("Assert the relay tunnel still works")
+	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+	// We should have two hostinfos on all sides
+	for len(myControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("myControl hostinfos got cleaned up!")
+	for len(theirControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("theirControl hostinfos got cleaned up!")
+	for len(relayControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("relayControl hostinfos got cleaned up!")
+}
+
 func TestRehandshaking(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)

+ 1 - 1
e2e/helpers_test.go

@@ -12,9 +12,9 @@ import (
 	"testing"
 	"time"
 
+	"dario.cat/mergo"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
-	"github.com/imdario/mergo"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"

+ 16 - 2
examples/config.yml

@@ -21,6 +21,19 @@ pki:
 static_host_map:
   "192.168.100.1": ["100.64.22.11:4242"]
 
+# The static_map config stanza can be used to configure how the static_host_map behaves.
+#static_map:
+  # cadence determines how frequently DNS is re-queried for updated IP addresses when a static_host_map entry contains
+  # a DNS name.
+  #cadence: 30s
+
+  # network determines the type of IP addresses to ask the DNS server for. The default is "ip4" because nodes typically
+  # do not know their public IPv4 address. Connecting to the Lighthouse via IPv4 allows the Lighthouse to detect the
+  # public address. Other valid options are "ip6" and "ip" (returns both.)
+  #network: ip4
+
+  # lookup_timeout is the DNS query timeout.
+  #lookup_timeout: 250ms
 
 lighthouse:
   # am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes
@@ -158,7 +171,8 @@ punchy:
 # and has been deprecated for "preferred_ranges"
 #preferred_ranges: ["172.16.0.0/24"]
 
-# sshd can expose informational and administrative functions via ssh this is a
+# sshd can expose informational and administrative functions via ssh. This can expose informational and administrative
+# functions, and allows manual tweaking of various network settings when debugging or testing.
 #sshd:
   # Toggles the feature
   #enabled: true
@@ -194,7 +208,7 @@ tun:
   disabled: false
   # Name of the device. If not set, a default will be chosen by the OS.
   # For macOS: if set, must be in the form `utun[0-9]+`.
-  # For FreeBSD: Required to be set, must be in the form `tun[0-9]+`.
+  # For NetBSD: Required to be set, must be in the form `tun[0-9]+`
   dev: nebula1
   # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert
   drop_local_broadcast: false

+ 2 - 0
examples/service_scripts/nebula.service

@@ -5,6 +5,8 @@ After=basic.target network.target network-online.target
 Before=sshd.service
 
 [Service]
+Type=notify
+NotifyAccess=main
 SyslogIdentifier=nebula
 ExecReload=/bin/kill -HUP $MAINPID
 ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml

+ 14 - 13
go.mod

@@ -3,31 +3,32 @@ module github.com/slackhq/nebula
 go 1.20
 
 require (
+	dario.cat/mergo v1.0.0
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.0.0
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
-	github.com/imdario/mergo v0.3.15
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.54
+	github.com/miekg/dns v1.1.56
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.15.1
+	github.com/prometheus/client_golang v1.16.0
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
-	github.com/sirupsen/logrus v1.9.0
+	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
-	github.com/stretchr/testify v1.8.2
+	github.com/stretchr/testify v1.8.4
 	github.com/vishvananda/netlink v1.1.0
-	golang.org/x/crypto v0.8.0
+	golang.org/x/crypto v0.14.0
 	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
-	golang.org/x/net v0.9.0
-	golang.org/x/sys v0.8.0
-	golang.org/x/term v0.8.0
+	golang.org/x/net v0.17.0
+	golang.org/x/sys v0.13.0
+	golang.org/x/term v0.13.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
+	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.30.0
+	google.golang.org/protobuf v1.31.0
 	gopkg.in/yaml.v2 v2.4.0
 )
 
@@ -40,10 +41,10 @@ require (
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.4.0 // indirect
 	github.com/prometheus/common v0.42.0 // indirect
-	github.com/prometheus/procfs v0.9.0 // indirect
+	github.com/prometheus/procfs v0.10.1 // indirect
 	github.com/rogpeppe/go-internal v1.10.0 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
-	golang.org/x/mod v0.10.0 // indirect
-	golang.org/x/tools v0.8.0 // indirect
+	golang.org/x/mod v0.12.0 // indirect
+	golang.org/x/tools v0.13.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 29 - 31
go.sum

@@ -1,4 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
+dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -54,8 +56,6 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
-github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
-github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
 github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
 github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -78,8 +78,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
 github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
-github.com/miekg/dns v1.1.54 h1:5jon9mWcb0sFJGpnI99tOMhCPyJ+RPVz5b63MQG0VWI=
-github.com/miekg/dns v1.1.54/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
+github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE=
+github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -97,8 +97,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI=
-github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk=
+github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
+github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -113,8 +113,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI=
-github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY=
+github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
+github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -122,24 +122,20 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
 github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
-github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
-github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
-github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
 github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
 github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
@@ -152,16 +148,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
-golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
+golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
+golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
-golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
+golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -172,8 +168,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
-golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
+golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
+golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -181,7 +177,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
+golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -198,11 +194,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
-golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
+golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
-golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
+golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
+golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -211,14 +207,16 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
-golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
+golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
+golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
+golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
 golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
 golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
 google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
@@ -230,8 +228,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
-google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
-google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
+google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

+ 1 - 1
handshake.go

@@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe
 		case 1:
 			ixHandshakeStage1(f, addr, via, packet, h)
 		case 2:
-			newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
+			newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
 			tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
 			if tearDown && newHostinfo != nil {
 				f.handshakeManager.DeleteHostInfo(newHostinfo)

+ 40 - 47
handshake_ix.go

@@ -13,27 +13,22 @@ import (
 
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
-	// This queries the lighthouse if we don't know a remote for the host
-	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
-	// more quickly, effect is a quicker handshake.
-	if hostinfo.remote == nil {
-		f.lightHouse.QueryServer(vpnIp, f)
-	}
-
-	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
+func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
+	err := f.handshakeManager.allocateIndex(hostinfo)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
-		return
+		return false
 	}
 
-	ci := hostinfo.ConnectionState
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
+	hostinfo.ConnectionState = ci
 
 	hsProto := &NebulaHandshakeDetails{
 		InitiatorIndex: hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
-		Cert:           ci.certState.rawCertificateNoKey,
+		Cert:           certState.RawCertificateNoKey,
 	}
 
 	if f.multiPort.Tx || f.multiPort.Rx {
@@ -53,9 +48,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hsBytes, err = hs.Marshal()
 
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
-		return
+		return false
 	}
 
 	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
@@ -63,9 +58,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
-		return
+		return false
 	}
 
 	// We are sending handshake packet 1, so we don't expect to receive
@@ -75,10 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hostinfo.HandshakePacket[0] = msg
 	hostinfo.HandshakeReady = true
 	hostinfo.handshakeStart = time.Now()
+	return true
 }
 
 func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
-	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
@@ -100,7 +97,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		return
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@@ -190,7 +187,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = ci.certState.rawCertificateNoKey
+	hs.Details.Cert = certState.RawCertificateNoKey
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
@@ -467,7 +464,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		}
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -490,34 +487,30 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 			Info("Incorrect host responded to handshake")
 
 		// Release our old handshake from pending, it should not continue
-		f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		//TODO: this adds it to the timer wheel in a way that aggressively retries
-		newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
-		newHostInfo.Lock()
-
-		// Block the current used address
-		newHostInfo.remotes = hostinfo.remotes
-		newHostInfo.remotes.BlockRemote(addr)
-
-		// Get the correct remote list for the host we did handshake with
-		hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-
-		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
-			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
-			Info("Blocked addresses for handshakes")
-
-		// Swap the packet store to benefit the original intended recipient
-		hostinfo.ConnectionState.queueLock.Lock()
-		newHostInfo.packetStore = hostinfo.packetStore
-		hostinfo.packetStore = []*cachedPacket{}
-		hostinfo.ConnectionState.queueLock.Unlock()
-
-		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-		hostinfo.vpnIp = vpnIp
-		f.sendCloseTunnel(hostinfo)
-		newHostInfo.Unlock()
+		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
+			//TODO: this doesn't know if it's being added or is being used for caching a packet
+			// Block the current used address
+			newHostInfo.remotes = hostinfo.remotes
+			newHostInfo.remotes.BlockRemote(addr)
+
+			// Get the correct remote list for the host we did handshake with
+			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+
+			f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+				WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+				Info("Blocked addresses for handshakes")
+
+			// Swap the packet store to benefit the original intended recipient
+			newHostInfo.packetStore = hostinfo.packetStore
+			hostinfo.packetStore = []*cachedPacket{}
+
+			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnIp = vpnIp
+			f.sendCloseTunnel(hostinfo)
+		})
 
 		return true
 	}

+ 190 - 73
handshake_manager.go

@@ -7,6 +7,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"net"
+	"sync"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
@@ -42,15 +43,21 @@ type HandshakeConfig struct {
 }
 
 type HandshakeManager struct {
-	pendingHostMap         *HostMap
+	// Mutex for interacting with the vpnIps and indexes maps
+	sync.RWMutex
+
+	vpnIps  map[iputil.VpnIp]*HostInfo
+	indexes map[uint32]*HostInfo
+
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
-	outside                *udp.Conn
+	outside                udp.Conn
 	config                 HandshakeConfig
 	OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
+	f                      *Interface
 	l                      *logrus.Logger
 
 	multiPort MultiPortConfig
@@ -60,9 +67,10 @@ type HandshakeManager struct {
 	trigger chan iputil.VpnIp
 }
 
-func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
+func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
+		vpnIps:                 map[iputil.VpnIp]*HostInfo{},
+		indexes:                map[uint32]*HostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
@@ -76,7 +84,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 
@@ -85,27 +93,27 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 		case <-ctx.Done():
 			return
 		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, f, true)
+			c.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now, f)
+			c.NextOutboundHandshakeTimerTick(now)
 		}
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
 		vpnIp, has := c.OutboundHandshakeTimer.Purge()
 		if !has {
 			break
 		}
-		c.handleOutbound(vpnIp, f, false)
+		c.handleOutbound(vpnIp, false)
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
-	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+	hostinfo := c.QueryVpnIp(vpnIp)
+	if hostinfo == nil {
 		return
 	}
 	hostinfo.Lock()
@@ -114,31 +122,34 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
 	if hostinfo.HandshakeComplete {
 		// Ensure we don't exist in the pending hostmap anymore since we have completed
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
-		return
-	}
-
-	// Check if we have a handshake packet to transmit yet
-	if !hostinfo.HandshakeReady {
-		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
-		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		c.DeleteHostInfo(hostinfo)
 		return
 	}
 
 	// If we are out of time, clean up
 	if hostinfo.HandshakeCounter >= c.config.retries {
-		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
+		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("remoteIndex", hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
 			Info("Handshake timed out")
 		c.metricTimedOut.Inc(1)
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		c.DeleteHostInfo(hostinfo)
 		return
 	}
 
+	// Increment the counter to increase our delay, linear backoff
+	hostinfo.HandshakeCounter++
+
+	// Check if we have a handshake packet to transmit yet
+	if !hostinfo.HandshakeReady {
+		if !ixHandshakeStage0(c.f, hostinfo) {
+			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+			return
+		}
+	}
+
 	// Get a remotes object if we don't already have one.
 	// This is mainly to protect us as this should never be the case
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
@@ -147,7 +158,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
+	remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
 	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
@@ -166,15 +177,15 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIp, f)
+		c.lightHouse.QueryServer(vpnIp, c.f)
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
 	var sentMultiport bool
-	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+	hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
 		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+		err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
 			hostinfo.logger(c.l).WithField("udpAddr", addr).
 				WithField("initiatorIndex", hostinfo.localIndexId).
@@ -230,10 +241,10 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 			if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
 				continue
 			}
-			relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay)
-			if err != nil || relayHostInfo.remote == nil {
-				hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				f.Handshake(*relay)
+			relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
+			if relayHostInfo == nil || relayHostInfo.remote == nil {
+				hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
+				c.f.Handshake(*relay)
 				continue
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
@@ -241,7 +252,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 				switch existingRelay.State {
 				case Established:
 					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+					c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
 					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					// Re-send the CreateRelay request, in case the previous one was lost.
@@ -258,7 +269,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							Error("Failed to marshal Control message to create relay")
 					} else {
 						// This must send over the hostinfo, not over hm.Hosts[ip]
-						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -293,7 +304,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -306,23 +317,78 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		}
 	}
 
-	// Increment the counter to increase our delay, linear backoff
-	hostinfo.HandshakeCounter++
-
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
 		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 }
 
-func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
-	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
+// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
+// The 2nd argument will be true if the hostinfo is ready to transmit traffic
+func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) {
+	// Check the main hostmap and maintain a read lock if our host is not there
+	hm.mainHostMap.RLock()
+	if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
+		hm.mainHostMap.RUnlock()
+		// Do not attempt promotion if you are a lighthouse
+		if !hm.lightHouse.amLighthouse {
+			h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
+		}
+		return h, true
+	}
+
+	defer hm.mainHostMap.RUnlock()
+	return hm.StartHandshake(vpnIp, cacheCb), false
+}
+
+// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
+func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo {
+	hm.Lock()
+
+	if hostinfo, ok := hm.vpnIps[vpnIp]; ok {
+		// We are already trying to handshake with this vpn ip
+		if cacheCb != nil {
+			cacheCb(hostinfo)
+		}
+		hm.Unlock()
+		return hostinfo
+	}
+
+	hostinfo := &HostInfo{
+		vpnIp:           vpnIp,
+		HandshakePacket: make(map[uint8][]byte, 0),
+		relayState: RelayState{
+			relays:        map[iputil.VpnIp]struct{}{},
+			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relayForByIdx: map[uint32]*Relay{},
+		},
+	}
+
+	hm.vpnIps[vpnIp] = hostinfo
+	hm.metricInitiated.Inc(1)
+	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
 
-	if created {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
-		c.metricInitiated.Inc(1)
+	if cacheCb != nil {
+		cacheCb(hostinfo)
 	}
 
+	// If this is a static host, we don't need to wait for the HostQueryReply
+	// We can trigger the handshake right now
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	if !doTrigger {
+		// Add any calculated remotes, and trigger early handshake if one found
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+	}
+
+	if doTrigger {
+		select {
+		case hm.trigger <- vpnIp:
+		default:
+		}
+	}
+
+	hm.Unlock()
+	hm.lightHouse.QueryServer(vpnIp, hm.f)
 	return hostinfo
 }
 
@@ -344,10 +410,10 @@ var (
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
+	c.Lock()
+	defer c.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
@@ -376,7 +442,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found = c.indexes[hostinfo.localIndexId]
 	if found && existingIndex != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
@@ -398,47 +464,47 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap. An existing hostinfo is returned if there was one.
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
+func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	if found && existingRemoteIndex != nil {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
+		hostinfo.logger(hm.l).
 			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 	}
 
 	// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
-	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.unlockedDeleteHostInfo(hostinfo)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 }
 
-// AddIndexHostInfo generates a unique localIndexId for this HostInfo
+// allocateIndex generates a unique localIndexId for this HostInfo
 // and adds it to the pendingHostMap. Will error if we are unable to generate
 // a unique localIndexId
-func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
-	c.mainHostMap.RLock()
-	defer c.mainHostMap.RUnlock()
+func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
+	hm.mainHostMap.RLock()
+	defer hm.mainHostMap.RUnlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 	for i := 0; i < 32; i++ {
-		index, err := generateIndex(c.l)
+		index, err := generateIndex(hm.l)
 		if err != nil {
 			return err
 		}
 
-		_, inPending := c.pendingHostMap.Indexes[index]
-		_, inMain := c.mainHostMap.Indexes[index]
+		_, inPending := hm.indexes[index]
+		_, inMain := hm.mainHostMap.Indexes[index]
 
 		if !inMain && !inPending {
 			h.localIndexId = index
-			c.pendingHostMap.Indexes[index] = h
+			hm.indexes[index] = h
 			return nil
 		}
 	}
@@ -446,22 +512,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 }
 
-func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
-	c.pendingHostMap.addRemoteIndexHostInfo(index, h)
+func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	c.Lock()
+	defer c.Unlock()
+	c.unlockedDeleteHostInfo(hostinfo)
 }
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	//l.Debugln("Deleting pending hostinfo :", hostinfo)
-	c.pendingHostMap.DeleteHostInfo(hostinfo)
+func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	delete(c.vpnIps, hostinfo.vpnIp)
+	if len(c.vpnIps) == 0 {
+		c.vpnIps = map[iputil.VpnIp]*HostInfo{}
+	}
+
+	delete(c.indexes, hostinfo.localIndexId)
+	if len(c.vpnIps) == 0 {
+		c.indexes = map[uint32]*HostInfo{}
+	}
+
+	if c.l.Level >= logrus.DebugLevel {
+		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
+			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			Debug("Pending hostmap hostInfo deleted")
+	}
+}
+
+func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	c.RLock()
+	defer c.RUnlock()
+	return c.vpnIps[vpnIp]
+}
+
+func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
+	c.RLock()
+	defer c.RUnlock()
+	return c.indexes[index]
 }
 
-func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
-	return c.pendingHostMap.QueryIndex(index)
+func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
+	return c.mainHostMap.preferredRanges
+}
+
+func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
+	c.RLock()
+	defer c.RUnlock()
+
+	for _, v := range c.vpnIps {
+		f(v)
+	}
+}
+
+func (c *HandshakeManager) ForEachIndex(f controlEach) {
+	c.RLock()
+	defer c.RUnlock()
+
+	for _, v := range c.indexes {
+		f(v)
+	}
 }
 
 func (c *HandshakeManager) EmitStats() {
-	c.pendingHostMap.EmitStats("pending")
-	c.mainHostMap.EmitStats("main")
+	c.RLock()
+	hostLen := len(c.vpnIps)
+	indexLen := len(c.indexes)
+	c.RUnlock()
+
+	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
+	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
+	c.mainHostMap.EmitStats()
 }
 
 // Utility functions below

+ 10 - 21
handshake_manager_test.go

@@ -14,31 +14,20 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
-	mw := &mockEncWriter{}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, vpncidr, preferredRanges)
 	lh := newTestLighthouse()
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
-	blah.NextOutboundHandshakeTimerTick(now, mw)
+	blah.NextOutboundHandshakeTimerTick(now)
 
-	var initCalled bool
-	initFunc := func(*HostInfo) {
-		initCalled = true
-	}
-
-	i := blah.AddVpnIp(ip, initFunc)
-	assert.True(t, initCalled)
-
-	initCalled = false
-	i2 := blah.AddVpnIp(ip, initFunc)
-	assert.False(t, initCalled)
+	i := blah.StartHandshake(ip, nil)
+	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 
 	i.remotes = NewRemoteList(nil)
@@ -48,22 +37,22 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.Len(t, mainHM.Hosts, 0)
 
 	// Confirm they are in the pending index list
-	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+	assert.Contains(t, blah.vpnIps, ip)
 
 	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
 	for i := 1; i <= DefaultHandshakeRetries+1; i++ {
 		now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
-		blah.NextOutboundHandshakeTimerTick(now, mw)
+		blah.NextOutboundHandshakeTimerTick(now)
 	}
 
 	// Confirm they are still in the pending index list
-	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+	assert.Contains(t, blah.vpnIps, ip)
 
 	// Tick 1 more time, a minute will certainly flush it out
-	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
+	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute))
 
 	// Confirm they have been removed
-	assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
+	assert.NotContains(t, blah.vpnIps, ip)
 }
 
 func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {

+ 67 - 147
hostmap.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"errors"
-	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -18,8 +17,9 @@ import (
 )
 
 // const ProbeLen = 100
-const PromoteEvery = 1000
-const ReQueryEvery = 5000
+const defaultPromoteEvery = 1000       // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
+const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo with the lighthouse
+const defaultReQueryWait = time.Minute // Minimum amount of time to wait before re-querying a hostinfo with the lighthouse. Evaluated every defaultReQueryEvery
 const MaxRemotes = 10
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
@@ -52,7 +52,6 @@ type Relay struct {
 
 type HostMap struct {
 	sync.RWMutex    //Because we concurrently read and write to our maps
-	name            string
 	Indexes         map[uint32]*HostInfo
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
@@ -205,13 +204,13 @@ type HostInfo struct {
 	multiportTx          bool
 	multiportRx          bool
 	ConnectionState      *ConnectionState
-	handshakeStart       time.Time        //todo: this an entry in the handshake manager
-	HandshakeReady       bool             //todo: being in the manager means you are ready
-	HandshakeCounter     int              //todo: another handshake manager entry
-	HandshakeLastRemotes []*udp.Addr      //todo: another handshake manager entry, which remotes we sent to last time
-	HandshakeComplete    bool             //todo: this should go away in favor of ConnectionState.ready
-	HandshakePacket      map[uint8][]byte //todo: this is other handshake manager entry
-	packetStore          []*cachedPacket  //todo: this is other handshake manager entry
+	handshakeStart       time.Time   //todo: this an entry in the handshake manager
+	HandshakeReady       bool        //todo: being in the manager means you are ready
+	HandshakeCounter     int         //todo: another handshake manager entry
+	HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
+	HandshakeComplete    bool        //todo: this should go away in favor of ConnectionState.ready
+	HandshakePacket      map[uint8][]byte
+	packetStore          []*cachedPacket //todo: this is other handshake manager entry
 	remoteIndexId        uint32
 	localIndexId         uint32
 	vpnIp                iputil.VpnIp
@@ -219,6 +218,10 @@ type HostInfo struct {
 	remoteCidr           *cidr.Tree4
 	relayState           RelayState
 
+	// nextLHQuery is the earliest we can ask the lighthouse for new information.
+	// This is used to limit lighthouse re-queries in chatty clients
+	nextLHQuery atomic.Int64
+
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// with a handshake
@@ -257,13 +260,12 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
+func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
 	h := map[iputil.VpnIp]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
 	relays := map[uint32]*HostInfo{}
 	m := HostMap{
-		name:            name,
 		Indexes:         i,
 		Relays:          relays,
 		RemoteIndexes:   r,
@@ -275,8 +277,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 	return &m
 }
 
-// UpdateStats takes a name and reports host and index counts to the stats collection system
-func (hm *HostMap) EmitStats(name string) {
+// EmitStats reports host, index, and relay counts to the stats collection system
+func (hm *HostMap) EmitStats() {
 	hm.RLock()
 	hostLen := len(hm.Hosts)
 	indexLen := len(hm.Indexes)
@@ -284,10 +286,10 @@ func (hm *HostMap) EmitStats(name string) {
 	relaysLen := len(hm.Relays)
 	hm.RUnlock()
 
-	metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen))
+	metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen))
+	metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen))
+	metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen))
+	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 
 func (hm *HostMap) RemoveRelay(localIdx uint32) {
@@ -301,88 +303,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) {
 	hm.Unlock()
 }
 
-func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
-	hm.RLock()
-	if i, ok := hm.Hosts[vpnIp]; ok {
-		index := i.localIndexId
-		hm.RUnlock()
-		return index, nil
-	}
-	hm.RUnlock()
-	return 0, errors.New("vpn IP not found")
-}
-
-func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
-	hm.Lock()
-	hm.Hosts[ip] = hostinfo
-	hm.Unlock()
-}
-
-func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
-	hm.RLock()
-	if h, ok := hm.Hosts[vpnIp]; !ok {
-		hm.RUnlock()
-		h = &HostInfo{
-			vpnIp:           vpnIp,
-			HandshakePacket: make(map[uint8][]byte, 0),
-			relayState: RelayState{
-				relays:        map[iputil.VpnIp]struct{}{},
-				relayForByIp:  map[iputil.VpnIp]*Relay{},
-				relayForByIdx: map[uint32]*Relay{},
-			},
-		}
-		if init != nil {
-			init(h)
-		}
-		hm.Lock()
-		hm.Hosts[vpnIp] = h
-		hm.Unlock()
-		return h, true
-	} else {
-		hm.RUnlock()
-		return h, false
-	}
-}
-
-// Only used by pendingHostMap when the remote index is not initially known
-func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
-	hm.Lock()
-	h.remoteIndexId = index
-	hm.RemoteIndexes[index] = h
-	hm.Unlock()
-
-	if hm.l.Level > logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
-			Debug("Hostmap remoteIndex added")
-	}
-}
-
-// DeleteReverseIndex is used to clean up on recv_error
-// This function should only ever be called on the pending hostmap
-func (hm *HostMap) DeleteReverseIndex(index uint32) {
-	hm.Lock()
-	hostinfo, ok := hm.RemoteIndexes[index]
-	if ok {
-		delete(hm.Indexes, hostinfo.localIndexId)
-		delete(hm.RemoteIndexes, index)
-
-		// Check if we have an entry under hostId that matches the same hostinfo
-		// instance. Clean it up as well if we do (they might not match in pendingHostmap)
-		var hostinfo2 *HostInfo
-		hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
-		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.vpnIp)
-		}
-	}
-	hm.Unlock()
-
-	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
-			Debug("Hostmap remote index deleted")
-	}
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
@@ -395,12 +315,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	return final
 }
 
-func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
-	hm.Lock()
-	defer hm.Unlock()
-	delete(hm.RemoteIndexes, localIdx)
-}
-
 func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 	hm.Lock()
 	defer hm.Unlock()
@@ -478,7 +392,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
 			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
@@ -488,55 +402,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
+func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Indexes[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, errors.New("unable to find index")
+		return nil
 	}
 }
 
-// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query.
-// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held.
-func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
-	hm.RLock()
-	if h, ok := hm.Indexes[index]; ok {
-		hm.RUnlock()
-		return h, h.prev == nil, nil
-	} else {
-		hm.RUnlock()
-		return nil, false, errors.New("unable to find index")
-	}
-}
-func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) {
+func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
 	//TODO: we probably just want to return bool instead of error, or at least a static error
 	hm.RLock()
 	if h, ok := hm.Relays[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, errors.New("unable to find index")
+		return nil
 	}
 }
 
-func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
+func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.RemoteIndexes[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
+		return nil
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
+func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 	return hm.queryVpnIp(vpnIp, nil)
 }
 
@@ -558,13 +458,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
-// `PromoteEvery` calls to this function for a given host.
-func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
-	return hm.queryVpnIp(vpnIp, ifce)
-}
-
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
+func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -572,12 +466,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
 		if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
 			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
 		}
-		return h, nil
+		return h
 
 	}
 
 	hm.RUnlock()
-	return nil, errors.New("unable to find host")
+	return nil
 }
 
 // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
@@ -600,7 +494,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
+		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 	}
@@ -616,11 +510,33 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
+func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+	return hm.preferredRanges
+}
+
+func (hm *HostMap) ForEachVpnIp(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	for _, v := range hm.Hosts {
+		f(v)
+	}
+}
+
+func (hm *HostMap) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	for _, v := range hm.Indexes {
+		f(v)
+	}
+}
+
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
-	if c%PromoteEvery == 0 {
+	if c%ifce.tryPromoteEvery.Load() == 0 {
 		// The lock here is currently protecting i.remote access
 		i.RLock()
 		remote := i.remote
@@ -648,12 +564,18 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 	}
 
 	// Re query our lighthouses for new remotes occasionally
-	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
+	if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil {
+		now := time.Now().UnixNano()
+		if now < i.nextLHQuery.Load() {
+			return
+		}
+
+		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
 		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
 	}
 }
 
-func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
 	//TODO: return the error so we can log with more context
 	if len(i.packetStore) < 100 {
 		tempPacket := make([]byte, len(packet))
@@ -682,7 +604,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
 	//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
 	//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
 
-	i.ConnectionState.queueLock.Lock()
 	i.HandshakeComplete = true
 	//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
 	// Clamping it to 2 gets us out of the woods for now
@@ -704,7 +625,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
 	i.remotes.ResetBlockedRemotes()
 	i.packetStore = make([]*cachedPacket, 0)
 	i.ConnectionState.ready = true
-	i.ConnectionState.queueLock.Unlock()
 }
 
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {

+ 14 - 14
hostmap_test.go

@@ -11,7 +11,7 @@ import (
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	hm := NewHostMap(
-		l, "test",
+		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
@@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim, _ := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(1)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -77,7 +77,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	hm := NewHostMap(
-		l, "test",
+		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
@@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	// h6 should be deleted
 	assert.Nil(t, h6.next)
 	assert.Nil(t, h6.prev)
-	_, err := hm.QueryIndex(h6.localIndexId)
-	assert.Error(t, err)
+	h := hm.QueryIndex(h6.localIndexId)
+	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim, _ := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(1)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -190,7 +190,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Nil(t, prim)
 }

+ 23 - 86
inside.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
@@ -45,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) {
+		h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+	})
+
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
@@ -55,23 +57,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		}
 		return
 	}
-	ci := hostinfo.ConnectionState
-
-	if !ci.ready {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		ci.queueLock.Lock()
-		if !ci.ready {
-			hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
-			ci.queueLock.Unlock()
-			return
-		}
-		ci.queueLock.Unlock()
+
+	if !ready {
+		return
 	}
 
-	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
+	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q, fwPacket)
+		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q, fwPacket)
 
 	} else {
 		f.rejectInside(packet, out, q)
@@ -110,71 +103,20 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 }
 
 func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
-	f.getOrHandshake(vpnIp)
+	f.getOrHandshake(vpnIp, nil)
 }
 
-// getOrHandshake returns nil if the vpnIp is not routable
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
+// getOrHandshake returns nil if the vpnIp is not routable.
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) {
 	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
 		if vpnIp == 0 {
-			return nil
+			return nil, false
 		}
 	}
-	hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 
-	//if err != nil || hostinfo.ConnectionState == nil {
-	if err != nil {
-		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
-		if err != nil {
-			hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
-		}
-	}
-	ci := hostinfo.ConnectionState
-
-	if ci != nil && ci.eKey != nil && ci.ready {
-		return hostinfo
-	}
-
-	// Handshake is not ready, we need to grab the lock now before we start the handshake process
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
-	// Double check, now that we have the lock
-	ci = hostinfo.ConnectionState
-	if ci != nil && ci.eKey != nil && ci.ready {
-		return hostinfo
-	}
-
-	// If we have already created the handshake packet, we don't want to call the function at all.
-	if !hostinfo.HandshakeReady {
-		ixHandshakeStage0(f, vpnIp, hostinfo)
-		// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
-		//xx_handshakeStage0(f, ip, hostinfo)
-
-		// If this is a static host, we don't need to wait for the HostQueryReply
-		// We can trigger the handshake right now
-		_, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp]
-		if !doTrigger {
-			// Add any calculated remotes, and trigger early handshake if one found
-			doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp)
-		}
-
-		if doTrigger {
-			select {
-			case f.handshakeManager.trigger <- vpnIp:
-			default:
-			}
-		}
-	}
-
-	return hostinfo
-}
-
-// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
-// will create the initial Noise ConnectionState
-func (f *Interface) initHostInfo(hostinfo *HostInfo) {
-	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
+	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -186,7 +128,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
+	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).
@@ -201,7 +143,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
-	hostInfo := f.getOrHandshake(vpnIp)
+	hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) {
+		h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
+	})
+
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnIp", vpnIp).
@@ -210,16 +155,8 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu
 		return
 	}
 
-	if !hostInfo.ConnectionState.ready {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		hostInfo.ConnectionState.queueLock.Lock()
-		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
-			hostInfo.ConnectionState.queueLock.Unlock()
-			return
-		}
-		hostInfo.ConnectionState.queueLock.Unlock()
+	if !ready {
+		return
 	}
 
 	f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out)
@@ -239,7 +176,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0, nil)
 }
 
-// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done
+// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
 // to the payload for the ultimate target host, making this a useful method for sending
 // handshake messages to peers through relay tunnels.
 // via is the HostInfo through which the message is relayed.

+ 54 - 53
interface.go

@@ -13,7 +13,6 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
@@ -26,9 +25,9 @@ const mtu = 9001
 
 type InterfaceConfig struct {
 	HostMap                 *HostMap
-	Outside                 *udp.Conn
+	Outside                 udp.Conn
 	Inside                  overlay.Device
-	certState               *CertState
+	pki                     *PKI
 	Cipher                  string
 	Firewall                *Firewall
 	ServeDns                bool
@@ -41,20 +40,23 @@ type InterfaceConfig struct {
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
-	caPool                  *cert.NebulaCAPool
 	disconnectInvalid       bool
 	relayManager            *relayManager
 	punchy                  *Punchy
 
+	tryPromoteEvery uint32
+	reQueryEvery    uint32
+	reQueryWait     time.Duration
+
 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
 }
 
 type Interface struct {
 	hostMap            *HostMap
-	outside            *udp.Conn
+	outside            udp.Conn
 	inside             overlay.Device
-	certState          atomic.Pointer[CertState]
+	pki                *PKI
 	cipher             string
 	firewall           *Firewall
 	connectionManager  *connectionManager
@@ -67,11 +69,14 @@ type Interface struct {
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	routines           int
-	caPool             *cert.NebulaCAPool
 	disconnectInvalid  bool
 	closed             atomic.Bool
 	relayManager       *relayManager
 
+	tryPromoteEvery atomic.Uint32
+	reQueryEvery    atomic.Uint32
+	reQueryWait     atomic.Int64
+
 	sendRecvErrorConfig sendRecvErrorConfig
 
 	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
@@ -80,7 +85,7 @@ type Interface struct {
 
 	conntrackCacheTimeout time.Duration
 
-	writers []*udp.Conn
+	writers []udp.Conn
 	readers []io.ReadWriteCloser
 	udpRaw  *udp.RawConn
 
@@ -156,15 +161,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	if c.Inside == nil {
 		return nil, errors.New("no inside interface (tun)")
 	}
-	if c.certState == nil {
+	if c.pki == nil {
 		return nil, errors.New("no certificate state")
 	}
 	if c.Firewall == nil {
 		return nil, errors.New("no firewall rules")
 	}
 
-	myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
+	certificate := c.pki.GetCertState().Certificate
+	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
 	ifce := &Interface{
+		pki:                c.pki,
 		hostMap:            c.HostMap,
 		outside:            c.Outside,
 		inside:             c.Inside,
@@ -174,14 +181,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
-		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
+		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		routines:           c.routines,
 		version:            c.version,
-		writers:            make([]*udp.Conn, c.routines),
+		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		caPool:             c.caPool,
 		disconnectInvalid:  c.disconnectInvalid,
 		myVpnIp:            myVpnIp,
 		relayManager:       c.relayManager,
@@ -198,7 +204,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
-	ifce.certState.Store(c.certState)
+	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
+	ifce.reQueryEvery.Store(c.reQueryEvery)
+	ifce.reQueryWait.Store(int64(c.reQueryWait))
+
 	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
 
 	return ifce, nil
@@ -257,7 +266,7 @@ func (f *Interface) run() {
 func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 
-	var li *udp.Conn
+	var li udp.Conn
 	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 		li = f.writers[i]
@@ -297,49 +306,14 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 }
 
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
-	c.RegisterReloadCallback(f.reloadCA)
-	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
+	c.RegisterReloadCallback(f.reloadMisc)
 	for _, udpConn := range f.writers {
 		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
 }
 
-func (f *Interface) reloadCA(c *config.C) {
-	// reload and check regardless
-	// todo: need mutex?
-	newCAs, err := loadCAFromConfig(f.l, c)
-	if err != nil {
-		f.l.WithError(err).Error("Could not refresh trusted CA certificates")
-		return
-	}
-
-	f.caPool = newCAs
-	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
-}
-
-func (f *Interface) reloadCertKey(c *config.C) {
-	// reload and check in all cases
-	cs, err := NewCertStateFromConfig(c)
-	if err != nil {
-		f.l.WithError(err).Error("Could not refresh client cert")
-		return
-	}
-
-	// did IP in cert change? if so, don't set
-	currentCert := f.certState.Load().certificate
-	oldIPs := currentCert.Details.Ips
-	newIPs := cs.certificate.Details.Ips
-	if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
-		f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
-		return
-	}
-
-	f.certState.Store(cs)
-	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
-}
-
 func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
@@ -347,7 +321,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
@@ -403,6 +377,26 @@ func (f *Interface) reloadSendRecvError(c *config.C) {
 	}
 }
 
+func (f *Interface) reloadMisc(c *config.C) {
+	if c.HasChanged("counters.try_promote") {
+		n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
+		f.tryPromoteEvery.Store(n)
+		f.l.Info("counters.try_promote has changed")
+	}
+
+	if c.HasChanged("counters.requery_every_packets") {
+		n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
+		f.reQueryEvery.Store(n)
+		f.l.Info("counters.requery_every_packets has changed")
+	}
+
+	if c.HasChanged("timers.requery_wait_duration") {
+		n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
+		f.reQueryWait.Store(int64(n))
+		f.l.Info("timers.requery_wait_duration has changed")
+	}
+}
+
 func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	ticker := time.NewTicker(i)
 	defer ticker.Stop()
@@ -427,7 +421,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 				}
 				rawStats()
 			}
-			certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
+			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
 		}
 	}
 }
@@ -435,6 +429,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 
+	for _, u := range f.writers {
+		err := u.Close()
+		if err != nil {
+			f.l.WithError(err).Error("Error while closing udp socket")
+		}
+	}
+
 	// Release the tun device
 	return f.inside.Close()
 }

+ 36 - 25
lighthouse.go

@@ -39,7 +39,7 @@ type LightHouse struct {
 	myVpnIp      iputil.VpnIp
 	myVpnZeros   iputil.VpnIp
 	myVpnNet     *net.IPNet
-	punchConn    *udp.Conn
+	punchConn    udp.Conn
 	punchy       *Punchy
 
 	// Local cache of answers from light houses
@@ -64,11 +64,10 @@ type LightHouse struct {
 	staticList  atomic.Pointer[map[iputil.VpnIp]struct{}]
 	lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
 
-	interval        atomic.Int64
-	updateCancel    context.CancelFunc
-	updateParentCtx context.Context
-	updateUdp       EncWriter
-	nebulaPort      uint32 // 32 bits because protobuf does not have a uint16
+	interval     atomic.Int64
+	updateCancel context.CancelFunc
+	ifce         EncWriter
+	nebulaPort   uint32 // 32 bits because protobuf does not have a uint16
 
 	advertiseAddrs atomic.Pointer[[]netIpAndPort]
 
@@ -84,7 +83,7 @@ type LightHouse struct {
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
@@ -133,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 	c.RegisterReloadCallback(func(c *config.C) {
 		err := h.reload(c, false)
 		switch v := err.(type) {
-		case util.ContextualError:
+		case *util.ContextualError:
 			v.Log(l)
 		case error:
 			l.WithError(err).Error("failed to reload lighthouse")
@@ -217,7 +216,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 				lh.updateCancel()
 			}
 
-			lh.LhUpdateWorker(lh.updateParentCtx, lh.updateUdp)
+			lh.StartUpdateWorker()
 		}
 	}
 
@@ -262,6 +261,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 
 	//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
 	if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
+		// Clean up. Entries still in the static_host_map will be re-built.
+		// Entries no longer present must have their (possible) background DNS goroutines stopped.
+		if existingStaticList := lh.staticList.Load(); existingStaticList != nil {
+			lh.RLock()
+			for staticVpnIp := range *existingStaticList {
+				if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil {
+					am.hr.Cancel()
+				}
+			}
+			lh.RUnlock()
+		}
+		// Build a new list based on current config.
 		staticList := make(map[iputil.VpnIp]struct{})
 		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
 		if err != nil {
@@ -742,33 +753,33 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
-func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
-	lh.updateParentCtx = ctx
-	lh.updateUdp = f
-
+func (lh *LightHouse) StartUpdateWorker() {
 	interval := lh.GetUpdateInterval()
 	if lh.amLighthouse || interval == 0 {
 		return
 	}
 
 	clockSource := time.NewTicker(time.Second * time.Duration(interval))
-	updateCtx, cancel := context.WithCancel(ctx)
+	updateCtx, cancel := context.WithCancel(lh.ctx)
 	lh.updateCancel = cancel
-	defer clockSource.Stop()
 
-	for {
-		lh.SendUpdate(f)
+	go func() {
+		defer clockSource.Stop()
 
-		select {
-		case <-updateCtx.Done():
-			return
-		case <-clockSource.C:
-			continue
+		for {
+			lh.SendUpdate()
+
+			select {
+			case <-updateCtx.Done():
+				return
+			case <-clockSource.C:
+				continue
+			}
 		}
-	}
+	}()
 }
 
-func (lh *LightHouse) SendUpdate(f EncWriter) {
+func (lh *LightHouse) SendUpdate() {
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 
@@ -821,7 +832,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	}
 
 	for vpnIp := range lighthouses {
-		f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
+		lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
 	}
 }
 

+ 41 - 2
lighthouse_test.go

@@ -12,6 +12,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v2"
 )
 
 //TODO: Add a test to ensure udpAddr is copied and not reused
@@ -65,6 +66,35 @@ func Test_lhStaticMapping(t *testing.T) {
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
+func TestReloadLighthouseInterval(t *testing.T) {
+	l := test.NewLogger()
+	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	lh1 := "10.128.0.2"
+
+	c := config.NewC(l)
+	c.Settings["lighthouse"] = map[interface{}]interface{}{
+		"hosts":    []interface{}{lh1},
+		"interval": "1s",
+	}
+
+	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	assert.NoError(t, err)
+	lh.ifce = &mockEncWriter{}
+
+	// The first one routine is kicked off by main.go currently, lets make sure that one dies
+	c.ReloadConfigString("lighthouse:\n  interval: 5")
+	assert.Equal(t, int64(5), lh.interval.Load())
+
+	// Subsequent calls are killed off by the LightHouse.Reload function
+	c.ReloadConfigString("lighthouse:\n  interval: 10")
+	assert.Equal(t, int64(10), lh.interval.Load())
+
+	// If this completes then nothing is stealing our reload routine
+	c.ReloadConfigString("lighthouse:\n  interval: 11")
+	assert.Equal(t, int64(11), lh.interval.Load())
+}
+
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
@@ -242,8 +272,17 @@ func TestLighthouse_reload(t *testing.T) {
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
 	assert.NoError(t, err)
 
-	c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}
-	lh.reload(c, false)
+	nc := map[interface{}]interface{}{
+		"static_host_map": map[interface{}]interface{}{
+			"10.128.0.2": []interface{}{"1.1.1.1:4242"},
+		},
+	}
+	rc, err := yaml.Marshal(nc)
+	assert.NoError(t, err)
+	c.ReloadConfigString(string(rc))
+
+	err = lh.reload(c, false)
+	assert.NoError(t, err)
 }
 
 func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {

+ 36 - 47
main.go

@@ -3,7 +3,6 @@ package nebula
 import (
 	"context"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"net"
 	"time"
@@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	err := configLogger(l, c)
 	if err != nil {
-		return nil, util.NewContextualError("Failed to configure the logger", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
 	}
 
 	c.RegisterReloadCallback(func(c *config.C) {
@@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	})
 
-	caPool, err := loadCAFromConfig(l, c)
+	pki, err := NewPKIFromConfig(l, c)
 	if err != nil {
-		//The errors coming out of loadCA are already nicely formatted
-		return nil, util.NewContextualError("Failed to load ca from config", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
-	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
-	cs, err := NewCertStateFromConfig(c)
+	certificate := pki.GetCertState().Certificate
+	fw, err := NewFirewallFromConfig(l, certificate, c)
 	if err != nil {
-		//The errors coming out of NewCertStateFromConfig are already nicely formatted
-		return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
-	}
-	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
-
-	fw, err := NewFirewallFromConfig(l, cs.certificate, c)
-	if err != nil {
-		return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
+		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
 
 	// TODO: make sure mask is 4 bytes
-	tunCidr := cs.certificate.Details.Ips[0]
+	tunCidr := certificate.Details.Ips[0]
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	wireSSHReload(l, ssh, c)
@@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if c.GetBool("sshd.enabled", false) {
 		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
-			return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
+			return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
 		}
 	}
 
@@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
 		if err != nil {
-			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
+			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
 
 		defer func() {
@@ -147,7 +138,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 
 	// set up our UDP listener
-	udpConns := make([]*udp.Conn, routines)
+	udpConns := make([]udp.Conn, routines)
 	port := c.GetInt("listen.port", 0)
 
 	if !configTest {
@@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		} else {
 			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
 			if err != nil {
-				return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
+				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 			}
 		}
 
@@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		for _, rawPreferredRange := range rawPreferredRanges {
 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
 			if err != nil {
-				return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
+				return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
 			}
 			preferredRanges = append(preferredRanges, preferredRange)
 		}
@@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
-			return nil, util.NewContextualError("Failed to parse local_range", nil, err)
+			return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
 		}
 
 		// Check if the entry for local_range was already specified in
@@ -212,7 +203,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
+	hostMap := NewHostMap(l, tunCidr, preferredRanges)
 	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 
 	l.
@@ -220,18 +211,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		WithField("preferredRanges", hostMap.preferredRanges).
 		Info("Main HostMap created")
 
-	/*
-		config.SetDefault("promoter.interval", 10)
-		go hostMap.Promoter(config.GetInt("promoter.interval"))
-	*/
-
 	punchy := NewPunchyFromConfig(l, c)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
-	switch {
-	case errors.As(err, &util.ContextualError{}):
-		return nil, err
-	case err != nil:
-		return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err)
+	if err != nil {
+		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
 
 	var messageMetrics *MessageMetrics
@@ -252,13 +235,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		messageMetrics: messageMetrics,
 	}
 
-	handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
+	handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
 	lightHouse.handshakeTrigger = handshakeManager.trigger
 
-	//TODO: These will be reused for psk
-	//handshakeMACKey := config.GetString("handshake_mac.key", "")
-	//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
-
 	serveDns := false
 	if c.GetBool("lighthouse.serve_dns", false) {
 		if c.GetBool("lighthouse.am_lighthouse", false) {
@@ -270,11 +249,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
 	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
+
 	ifConfig := &InterfaceConfig{
 		HostMap:                 hostMap,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
-		certState:               cs,
+		pki:                     pki,
 		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
@@ -282,12 +262,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		lightHouse:              lightHouse,
 		checkInterval:           time.Second * time.Duration(checkInterval),
 		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
+		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
+		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
+		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
-		caPool:                  caPool,
 		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
 		punchy:                  punchy,
@@ -315,6 +297,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
 		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
+		lightHouse.ifce = ifce
 
 		loadMultiPortConfig := func(c *config.C) {
 			ifce.multiPort.Rx = c.GetBool("tun.multiport.rx_enabled", false)
@@ -350,19 +333,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		c.RegisterReloadCallback(loadMultiPortConfig)
 
 		ifce.RegisterConfigChangeCallbacks(c)
-
 		ifce.reloadSendRecvError(c)
 
-		go handshakeManager.Run(ctx, ifce)
-		go lightHouse.LhUpdateWorker(ctx, ifce)
+		handshakeManager.f = ifce
+		go handshakeManager.Run(ctx)
 	}
 
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
 	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
-
 	if err != nil {
-		return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
 	}
 
 	if configTest {
@@ -372,7 +353,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
-	attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
+	attachCommands(l, c, ssh, ifce)
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
 	var dnsStart func()
@@ -381,5 +362,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		dnsStart = dnsMain(l, hostMap, c)
 	}
 
-	return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
+	return &Control{
+		ifce,
+		l,
+		cancel,
+		sshStart,
+		statsStart,
+		dnsStart,
+		lightHouse.StartUpdateWorker,
+	}, nil
 }

+ 8 - 11
outside.go

@@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 	var hostinfo *HostInfo
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
-		hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
 	} else {
-		hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex)
+		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
 	}
 
 	var ci *ConnectionState
@@ -417,7 +417,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
 
-	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
+	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
 		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
 		if f.l.Level >= logrus.DebugLevel {
@@ -462,12 +462,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 			Debug("Recv error received")
 	}
 
-	// First, clean up in the pending hostmap
-	f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
-
-	hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
-	if err != nil {
-		f.l.Debugln(err, ": ", h.RemoteIndex)
+	hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
+	if hostinfo == nil {
+		f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
 		return
 	}
 
@@ -477,14 +474,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 	if !hostinfo.RecvErrorExceeded() {
 		return
 	}
+
 	if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return
 	}
 
 	f.closeTunnel(hostinfo)
-	// We also delete it from pending hostmap to allow for
-	// fast reconnect.
+	// We also delete it from pending hostmap to allow for fast reconnect.
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 }
 

+ 1 - 9
overlay/tun_darwin.go

@@ -47,14 +47,6 @@ type ifReq struct {
 	pad   [8]byte
 }
 
-func ioctl(a1, a2, a3 uintptr) error {
-	_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
-	if errno != 0 {
-		return errno
-	}
-	return nil
-}
-
 var sockaddrCtlSize uintptr = 32
 
 const (
@@ -194,10 +186,10 @@ func (t *tun) Activate() error {
 		unix.SOCK_DGRAM,
 		unix.IPPROTO_IP,
 	)
-
 	if err != nil {
 		return err
 	}
+	defer unix.Close(s)
 
 	fd := uintptr(s)
 

+ 118 - 20
overlay/tun_freebsd.go

@@ -4,21 +4,44 @@
 package overlay
 
 import (
+	"bytes"
+	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"net"
 	"os"
 	"os/exec"
-	"regexp"
 	"strconv"
-	"strings"
+	"syscall"
+	"unsafe"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/iputil"
 )
 
-var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
const (
	// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
	// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
	FIODGNAME = 0x80106678
)

// fiodgnameArg mirrors the kernel argument struct for the FIODGNAME ioctl,
// which retrieves the name of a device from its file descriptor.
type fiodgnameArg struct {
	length int32
	pad    [4]byte
	buf    unsafe.Pointer
}

// ifreqRename mirrors struct ifreq as used with SIOCSIFNAME to rename an
// interface; Data points at the NUL padded new name.
type ifreqRename struct {
	Name [16]byte
	Data uintptr
}

// ifreqDestroy mirrors struct ifreq as used with SIOCIFDESTROY.
type ifreqDestroy struct {
	Name [16]byte
	pad  [16]byte
}
 
 type tun struct {
 	Device    string
@@ -33,8 +56,23 @@ type tun struct {
 
 func (t *tun) Close() error {
 	if t.ReadWriteCloser != nil {
-		return t.ReadWriteCloser.Close()
+		if err := t.ReadWriteCloser.Close(); err != nil {
+			return err
+		}
+
+		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		ifreq := ifreqDestroy{Name: t.deviceBytes()}
+
+		// Destroy the interface
+		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
+		return err
 	}
+
 	return nil
 }
 
@@ -43,34 +81,87 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
 }
 
 func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
-	routeTree, err := makeRouteTree(l, routes, false)
+	// Try to open existing tun device
+	var file *os.File
+	var err error
+	if deviceName != "" {
+		file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+	}
+	if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
+		// If the device doesn't already exist, request a new one and rename it
+		file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	rawConn, err := file.SyscallConn()
 	if err != nil {
+		return nil, fmt.Errorf("SyscallConn: %v", err)
+	}
+
+	var name [16]byte
+	var ctrlErr error
+	rawConn.Control(func(fd uintptr) {
+		// Read the name of the interface
+		arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
+		ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
+	})
+	if ctrlErr != nil {
 		return nil, err
 	}
 
-	if strings.HasPrefix(deviceName, "/dev/") {
-		deviceName = strings.TrimPrefix(deviceName, "/dev/")
+	ifName := string(bytes.TrimRight(name[:], "\x00"))
+	if deviceName == "" {
+		deviceName = ifName
 	}
-	if !deviceNameRE.MatchString(deviceName) {
-		return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
+
+	// If the name doesn't match the desired interface name, rename it now
+	if ifName != deviceName {
+		s, err := syscall.Socket(
+			syscall.AF_INET,
+			syscall.SOCK_DGRAM,
+			syscall.IPPROTO_IP,
+		)
+		if err != nil {
+			return nil, err
+		}
+		defer syscall.Close(s)
+
+		fd := uintptr(s)
+
+		var fromName [16]byte
+		var toName [16]byte
+		copy(fromName[:], ifName)
+		copy(toName[:], deviceName)
+
+		ifrr := ifreqRename{
+			Name: fromName,
+			Data: uintptr(unsafe.Pointer(&toName)),
+		}
+
+		// Set the device name
+		ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
 	}
+
+	routeTree, err := makeRouteTree(l, routes, false)
+	if err != nil {
+		return nil, err
+	}
+
 	return &tun{
-		Device:    deviceName,
-		cidr:      cidr,
-		MTU:       defaultMTU,
-		Routes:    routes,
-		routeTree: routeTree,
-		l:         l,
+		ReadWriteCloser: file,
+		Device:          deviceName,
+		cidr:            cidr,
+		MTU:             defaultMTU,
+		Routes:          routes,
+		routeTree:       routeTree,
+		l:               l,
 	}, nil
 }
 
 func (t *tun) Activate() error {
 	var err error
-	t.ReadWriteCloser, err = os.OpenFile("/dev/"+t.Device, os.O_RDWR, 0)
-	if err != nil {
-		return fmt.Errorf("activate failed: %v", err)
-	}
-
 	// TODO use syscalls instead of exec.Command
 	t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
 	if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
@@ -120,3 +211,10 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
 }
+
+func (t *tun) deviceBytes() (o [16]byte) {
+	for i, c := range t.Device {
+		o[i] = byte(c)
+	}
+	return
+}

+ 0 - 8
overlay/tun_linux.go

@@ -43,14 +43,6 @@ type ifReq struct {
 	pad   [8]byte
 }
 
-func ioctl(a1, a2, a3 uintptr) error {
-	_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
-	if errno != 0 {
-		return errno
-	}
-	return nil
-}
-
 type ifreqAddr struct {
 	Name [16]byte
 	Addr unix.RawSockaddrInet4

+ 162 - 0
overlay/tun_netbsd.go

@@ -0,0 +1,162 @@
+//go:build !e2e_testing
+// +build !e2e_testing
+
+package overlay
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"os/exec"
+	"regexp"
+	"strconv"
+	"syscall"
+	"unsafe"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
+)
+
// ifreqDestroy mirrors the kernel's struct ifreq layout used by the
// SIOCIFDESTROY ioctl; only the interface name field is meaningful here.
type ifreqDestroy struct {
	Name [16]byte
	pad  [16]byte
}

// tun is the NetBSD implementation of the overlay device. It wraps the
// opened /dev/tunN character device plus the route data used to answer
// RouteFor queries.
type tun struct {
	Device    string         // interface name, e.g. "tun0"
	cidr      *net.IPNet     // VPN address/network assigned to the device
	MTU       int
	Routes    []Route
	routeTree *cidr.Tree4 // most-specific-match tree for unsafe routes
	l         *logrus.Logger

	io.ReadWriteCloser
}
+
+func (t *tun) Close() error {
+	if t.ReadWriteCloser != nil {
+		if err := t.ReadWriteCloser.Close(); err != nil {
+			return err
+		}
+
+		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		ifreq := ifreqDestroy{Name: t.deviceBytes()}
+
+		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
+
+		return err
+	}
+	return nil
+}
+
// newTunFromFd is unsupported on NetBSD; the device must be opened by path.
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}

// deviceNameRE restricts device names to the tunN form expected under /dev.
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
+
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
+	// Try to open tun device
+	var file *os.File
+	var err error
+	if deviceName == "" {
+		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
+	}
+	if !deviceNameRE.MatchString(deviceName) {
+		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
+	}
+	file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+
+	if err != nil {
+		return nil, err
+	}
+
+	routeTree, err := makeRouteTree(l, routes, false)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return &tun{
+		ReadWriteCloser: file,
+		Device:          deviceName,
+		cidr:            cidr,
+		MTU:             defaultMTU,
+		Routes:          routes,
+		routeTree:       routeTree,
+		l:               l,
+	}, nil
+}
+
+func (t *tun) Activate() error {
+	var err error
+
+	// TODO use syscalls instead of exec.Command
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	}
+
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'route add': %s", err)
+	}
+
+	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	}
+	// Unsafe path routes
+	for _, r := range t.Routes {
+		if r.Via == nil || !r.Install {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
+		cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
+		t.l.Debug("command: ", cmd.String())
+		if err = cmd.Run(); err != nil {
+			return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
+		}
+	}
+
+	return nil
+}
+
+func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.routeTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
// Cidr returns the VPN network assigned to this device.
func (t *tun) Cidr() *net.IPNet {
	return t.cidr
}

// Name returns the interface name, e.g. "tun0".
func (t *tun) Name() string {
	return t.Device
}

// NewMultiQueueReader is not supported on NetBSD.
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
	return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

// deviceBytes copies the interface name into a fixed 16 byte array for use
// in ifreq-style ioctl arguments.
func (t *tun) deviceBytes() (o [16]byte) {
	for i, c := range t.Device {
		o[i] = byte(c)
	}
	return
}

+ 14 - 0
overlay/tun_notwin.go

@@ -0,0 +1,14 @@
+//go:build !windows
+// +build !windows
+
+package overlay
+
+import "syscall"
+
// ioctl is a thin wrapper around the raw ioctl(2) syscall that converts a
// non-zero errno into a Go error value.
func ioctl(fd, request, argp uintptr) error {
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, request, argp); errno != 0 {
		return errno
	}
	return nil
}

+ 174 - 0
overlay/tun_openbsd.go

@@ -0,0 +1,174 @@
+//go:build !e2e_testing
+// +build !e2e_testing
+
+package overlay
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"os/exec"
+	"regexp"
+	"strconv"
+	"syscall"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
+)
+
// tun is the OpenBSD implementation of the overlay device. OpenBSD tun
// devices carry a 4 byte NULL L2 (address family) header on every packet,
// which the Read/Write methods below strip and prepend.
type tun struct {
	Device    string         // interface name, e.g. "tun0"
	cidr      *net.IPNet     // VPN address/network assigned to the device
	MTU       int
	Routes    []Route
	routeTree *cidr.Tree4 // most-specific-match tree for unsafe routes
	l         *logrus.Logger

	io.ReadWriteCloser

	// cache out buffer since we need to prepend 4 bytes for tun metadata
	out []byte
}
+
+func (t *tun) Close() error {
+	if t.ReadWriteCloser != nil {
+		return t.ReadWriteCloser.Close()
+	}
+
+	return nil
+}
+
// newTunFromFd is unsupported on OpenBSD; the device must be opened by path.
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}

// deviceNameRE restricts device names to the tunN form expected under /dev.
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
+
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
+	if deviceName == "" {
+		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
+	}
+
+	if !deviceNameRE.MatchString(deviceName) {
+		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
+	}
+
+	file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+	if err != nil {
+		return nil, err
+	}
+
+	routeTree, err := makeRouteTree(l, routes, false)
+	if err != nil {
+		return nil, err
+	}
+
+	return &tun{
+		ReadWriteCloser: file,
+		Device:          deviceName,
+		cidr:            cidr,
+		MTU:             defaultMTU,
+		Routes:          routes,
+		routeTree:       routeTree,
+		l:               l,
+	}, nil
+}
+
+func (t *tun) Activate() error {
+	var err error
+	// TODO use syscalls instead of exec.Command
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	}
+
+	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	}
+
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
+	t.l.Debug("command: ", cmd.String())
+	if err = cmd.Run(); err != nil {
+		return fmt.Errorf("failed to run 'route add': %s", err)
+	}
+
+	// Unsafe path routes
+	for _, r := range t.Routes {
+		if r.Via == nil || !r.Install {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
+		cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
+		t.l.Debug("command: ", cmd.String())
+		if err = cmd.Run(); err != nil {
+			return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
+		}
+	}
+
+	return nil
+}
+
+func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.routeTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
// Cidr returns the VPN network assigned to this device.
func (t *tun) Cidr() *net.IPNet {
	return t.cidr
}

// Name returns the interface name, e.g. "tun0".
func (t *tun) Name() string {
	return t.Device
}
+
+func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
+}
+
+func (t *tun) Read(to []byte) (int, error) {
+	buf := make([]byte, len(to)+4)
+
+	n, err := t.ReadWriteCloser.Read(buf)
+
+	copy(to, buf[4:])
+	return n - 4, err
+}
+
+// Write is only valid for single threaded use
+func (t *tun) Write(from []byte) (int, error) {
+	buf := t.out
+	if cap(buf) < len(from)+4 {
+		buf = make([]byte, len(from)+4)
+		t.out = buf
+	}
+	buf = buf[:len(from)+4]
+
+	if len(from) == 0 {
+		return 0, syscall.EIO
+	}
+
+	// Determine the IP Family for the NULL L2 Header
+	ipVer := from[0] >> 4
+	if ipVer == 4 {
+		buf[3] = syscall.AF_INET
+	} else if ipVer == 6 {
+		buf[3] = syscall.AF_INET6
+	} else {
+		return 0, fmt.Errorf("unable to determine IP version from packet")
+	}
+
+	copy(buf[4:], from)
+
+	n, err := t.ReadWriteCloser.Write(buf)
+	return n - 4, err
+}

+ 14 - 1
overlay/tun_tester.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net"
 	"os"
+	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cidr"
@@ -21,6 +22,7 @@ type TestTun struct {
 	routeTree *cidr.Tree4
 	l         *logrus.Logger
 
+	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
@@ -50,6 +52,10 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
 // These are unencrypted ip layer frames destined for another nebula node.
 // packets should exit the udp side, capture them with udpConn.Get
 func (t *TestTun) Send(packet []byte) {
+	if t.closed.Load() {
+		return
+	}
+
 	if t.l.Level >= logrus.DebugLevel {
 		t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
 	}
@@ -98,6 +104,10 @@ func (t *TestTun) Name() string {
 }
 
 func (t *TestTun) Write(b []byte) (n int, err error) {
+	if t.closed.Load() {
+		return 0, io.ErrClosedPipe
+	}
+
 	packet := make([]byte, len(b), len(b))
 	copy(packet, b)
 	t.TxPackets <- packet
@@ -105,7 +115,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
 }
 
 func (t *TestTun) Close() error {
-	close(t.rxPackets)
+	if t.closed.CompareAndSwap(false, true) {
+		close(t.rxPackets)
+		close(t.TxPackets)
+	}
 	return nil
 }
 

+ 9 - 2
overlay/tun_wintun_windows.go

@@ -54,9 +54,16 @@ func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 	}
 
-	tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+	var tunDevice wintun.Device
+	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
 	if err != nil {
-		return nil, fmt.Errorf("create TUN device failed: %w", err)
+		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
+		// Trying a second time resolves the issue.
+		l.WithError(err).Debug("Failed to create wintun device, retrying")
+		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+		if err != nil {
+			return nil, fmt.Errorf("create TUN device failed: %w", err)
+		}
 	}
 
 	routeTree, err := makeRouteTree(l, routes, false)

+ 248 - 0
pki.go

@@ -0,0 +1,248 @@
+package nebula
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
+)
+
// PKI holds the current client certificate state and the trusted CA pool.
// Both live behind atomic pointers so they can be hot-swapped on config
// reload while readers access them lock free.
type PKI struct {
	cs     atomic.Pointer[CertState]
	caPool atomic.Pointer[cert.NebulaCAPool]
	l      *logrus.Logger
}

// CertState bundles this node's certificate with its key material and the
// pre-marshaled representations used during handshakes.
type CertState struct {
	Certificate         *cert.NebulaCertificate
	RawCertificate      []byte
	RawCertificateNoKey []byte // marshaled cert with the public key stripped
	PublicKey           []byte
	PrivateKey          []byte
}
+
+func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
+	pki := &PKI{l: l}
+	err := pki.reload(c, true)
+	if err != nil {
+		return nil, err
+	}
+
+	c.RegisterReloadCallback(func(c *config.C) {
+		rErr := pki.reload(c, false)
+		if rErr != nil {
+			util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
+		}
+	})
+
+	return pki, nil
+}
+
// GetCertState returns the current client certificate state. Safe for
// concurrent use; reloads swap the pointer atomically.
func (p *PKI) GetCertState() *CertState {
	return p.cs.Load()
}

// GetCAPool returns the current trusted CA pool. Safe for concurrent use.
func (p *PKI) GetCAPool() *cert.NebulaCAPool {
	return p.caPool.Load()
}
+
+func (p *PKI) reload(c *config.C, initial bool) error {
+	err := p.reloadCert(c, initial)
+	if err != nil {
+		if initial {
+			return err
+		}
+		err.Log(p.l)
+	}
+
+	err = p.reloadCAPool(c)
+	if err != nil {
+		if initial {
+			return err
+		}
+		err.Log(p.l)
+	}
+
+	return nil
+}
+
+func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
+	cs, err := newCertStateFromConfig(c)
+	if err != nil {
+		return util.NewContextualError("Could not load client cert", nil, err)
+	}
+
+	if !initial {
+		// did IP in cert change? if so, don't set
+		currentCert := p.cs.Load().Certificate
+		oldIPs := currentCert.Details.Ips
+		newIPs := cs.Certificate.Details.Ips
+		if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
+			return util.NewContextualError(
+				"IP in new cert was different from old",
+				m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
+				nil,
+			)
+		}
+	}
+
+	p.cs.Store(cs)
+	if initial {
+		p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
+	} else {
+		p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
+	}
+	return nil
+}
+
+func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
+	caPool, err := loadCAPoolFromConfig(p.l, c)
+	if err != nil {
+		return util.NewContextualError("Failed to load ca from config", nil, err)
+	}
+
+	p.caPool.Store(caPool)
+	p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
+	return nil
+}
+
// newCertState wraps a parsed certificate and its private key, caching both
// the marshaled form with the public key and the marshaled form without it
// (the latter is what is transmitted during handshakes).
func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
	// Marshal the certificate to ensure it is valid
	rawCertificate, err := certificate.Marshal()
	if err != nil {
		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
	}

	publicKey := certificate.Details.PublicKey
	cs := &CertState{
		RawCertificate: rawCertificate,
		Certificate:    certificate,
		PrivateKey:     privateKey,
		PublicKey:      publicKey,
	}

	// NOTE(review): the public key is temporarily nil'd out so Marshal emits a
	// cert without it; the certificate must not be shared with other goroutines
	// until this function returns and the key is restored below.
	cs.Certificate.Details.PublicKey = nil
	rawCertNoKey, err := cs.Certificate.Marshal()
	if err != nil {
		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
	}
	cs.RawCertificateNoKey = rawCertNoKey
	// put public key back
	cs.Certificate.Details.PublicKey = cs.PublicKey
	return cs, nil
}
+
+func newCertStateFromConfig(c *config.C) (*CertState, error) {
+	var pemPrivateKey []byte
+	var err error
+
+	privPathOrPEM := c.GetString("pki.key", "")
+	if privPathOrPEM == "" {
+		return nil, errors.New("no pki.key path or PEM data provided")
+	}
+
+	if strings.Contains(privPathOrPEM, "-----BEGIN") {
+		pemPrivateKey = []byte(privPathOrPEM)
+		privPathOrPEM = "<inline>"
+
+	} else {
+		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+		}
+	}
+
+	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
+	if err != nil {
+		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+	}
+
+	var rawCert []byte
+
+	pubPathOrPEM := c.GetString("pki.cert", "")
+	if pubPathOrPEM == "" {
+		return nil, errors.New("no pki.cert path or PEM data provided")
+	}
+
+	if strings.Contains(pubPathOrPEM, "-----BEGIN") {
+		rawCert = []byte(pubPathOrPEM)
+		pubPathOrPEM = "<inline>"
+
+	} else {
+		rawCert, err = os.ReadFile(pubPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
+		}
+	}
+
+	nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
+	if err != nil {
+		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
+	}
+
+	if nebulaCert.Expired(time.Now()) {
+		return nil, fmt.Errorf("nebula certificate for this host is expired")
+	}
+
+	if len(nebulaCert.Details.Ips) == 0 {
+		return nil, fmt.Errorf("no IPs encoded in certificate")
+	}
+
+	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
+		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+	}
+
+	return newCertState(nebulaCert, rawKey)
+}
+
+func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
+	var rawCA []byte
+	var err error
+
+	caPathOrPEM := c.GetString("pki.ca", "")
+	if caPathOrPEM == "" {
+		return nil, errors.New("no pki.ca path or PEM data provided")
+	}
+
+	if strings.Contains(caPathOrPEM, "-----BEGIN") {
+		rawCA = []byte(caPathOrPEM)
+
+	} else {
+		rawCA, err = os.ReadFile(caPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
+		}
+	}
+
+	caPool, err := cert.NewCAPoolFromBytes(rawCA)
+	if errors.Is(err, cert.ErrExpired) {
+		var expired int
+		for _, crt := range caPool.CAs {
+			if crt.Expired(time.Now()) {
+				expired++
+				l.WithField("cert", crt).Warn("expired certificate present in CA pool")
+			}
+		}
+
+		if expired >= len(caPool.CAs) {
+			return nil, errors.New("no valid CA certificates present")
+		}
+
+	} else if err != nil {
+		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
+	}
+
+	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
+		l.WithField("fingerprint", fp).Info("Blocklisting cert")
+		caPool.BlocklistFingerprint(fp)
+	}
+
+	return caPool, nil
+}

+ 13 - 6
relay_manager.go

@@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		return
 	}
 	// I'm the middle man. Let the initiator know that the I've established the relay they requested.
-	peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp)
-	if err != nil {
-		rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
+	peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
+	if peerHostInfo == nil {
+		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
 		return
 	}
 	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
@@ -179,6 +179,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		"vpnIp":               h.vpnIp})
 
 	logMsg.Info("handleCreateRelayRequest")
+	// Is the source of the relay me? This should never happen, but did happen due to
+	// an issue migrating relays over to newly re-handshaked host info objects.
+	if from == f.myVpnIp {
+		logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+		return
+	}
 	// Is the target of the relay me?
 	if target == f.myVpnIp {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
@@ -240,11 +246,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		if !rm.GetAmRelay() {
 			return
 		}
-		peer, err := rm.hostmap.QueryVpnIp(target)
-		if err != nil {
+		peer := rm.hostmap.QueryVpnIp(target)
+		if peer == nil {
 			// Try to establish a connection to this host. If we get a future relay request,
 			// we'll be ready!
-			f.getOrHandshake(target)
+			f.Handshake(target)
 			return
 		}
 		if peer.remote == nil {
@@ -253,6 +259,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		}
 		sendCreateRequest := false
 		var index uint32
+		var err error
 		targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
 		if ok {
 			index = targetRelay.LocalIndex

+ 11 - 21
remote_list.go

@@ -70,7 +70,7 @@ type hostnamesResults struct {
 	hostnames     []hostnamePort
 	network       string
 	lookupTimeout time.Duration
-	stop          chan struct{}
+	cancelFn      func()
 	l             *logrus.Logger
 	ips           atomic.Pointer[map[netip.AddrPort]struct{}]
 }
@@ -80,7 +80,6 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 		hostnames:     make([]hostnamePort, len(hostPorts)),
 		network:       network,
 		lookupTimeout: timeout,
-		stop:          make(chan (struct{})),
 		l:             l,
 	}
 
@@ -115,6 +114,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 
 	// Time for the DNS lookup goroutine
 	if performBackgroundLookup {
+		newCtx, cancel := context.WithCancel(ctx)
+		r.cancelFn = cancel
 		ticker := time.NewTicker(d)
 		go func() {
 			defer ticker.Stop()
@@ -154,9 +155,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 					onUpdate()
 				}
 				select {
-				case <-ctx.Done():
-					return
-				case <-r.stop:
+				case <-newCtx.Done():
 					return
 				case <-ticker.C:
 					continue
@@ -169,8 +168,8 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 }
 
 func (hr *hostnamesResults) Cancel() {
-	if hr != nil {
-		hr.stop <- struct{}{}
+	if hr != nil && hr.cancelFn != nil {
+		hr.cancelFn()
 	}
 }
 
@@ -582,20 +581,11 @@ func (r *RemoteList) unlockedCollect() {
 	dnsAddrs := r.hr.GetIPs()
 	for _, addr := range dnsAddrs {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
-			switch {
-			case addr.Addr().Is4():
-				v4 := addr.Addr().As4()
-				addrs = append(addrs, &udp.Addr{
-					IP:   v4[:],
-					Port: addr.Port(),
-				})
-			case addr.Addr().Is6():
-				v6 := addr.Addr().As16()
-				addrs = append(addrs, &udp.Addr{
-					IP:   v6[:],
-					Port: addr.Port(),
-				})
-			}
+			v6 := addr.Addr().As16()
+			addrs = append(addrs, &udp.Addr{
+				IP:   v6[:],
+				Port: addr.Port(),
+			})
 		}
 	}
 

+ 33 - 33
ssh.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"flag"
 	"fmt"
 	"io/ioutil"
@@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 	return runner, nil
 }
 
-func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
+func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-hostmap",
 		ShortDescription: "List all known previously connected hosts",
@@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListHostMap(hostMap, fs, w)
+			return sshListHostMap(f.hostMap, fs, w)
 		},
 	})
 
@@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListHostMap(pendingHostMap, fs, w)
+			return sshListHostMap(f.handshakeManager, fs, w)
 		},
 	})
 
@@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListLighthouseMap(lightHouse, fs, w)
+			return sshListLighthouseMap(f.lightHouse, fs, w)
 		},
 	})
 
@@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 		Name:             "version",
 		ShortDescription: "Prints the currently running version of nebula",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshVersion(ifce, fs, a, w)
+			return sshVersion(f, fs, a, w)
 		},
 	})
 
@@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintCert(ifce, fs, a, w)
+			return sshPrintCert(f, fs, a, w)
 		},
 	})
 
@@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintTunnel(ifce, fs, a, w)
+			return sshPrintTunnel(f, fs, a, w)
 		},
 	})
 
@@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintRelays(ifce, fs, a, w)
+			return sshPrintRelays(f, fs, a, w)
 		},
 	})
 
@@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshChangeRemote(ifce, fs, a, w)
+			return sshChangeRemote(f, fs, a, w)
 		},
 	})
 
@@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshCloseTunnel(ifce, fs, a, w)
+			return sshCloseTunnel(f, fs, a, w)
 		},
 	})
 
@@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshCreateTunnel(ifce, fs, a, w)
+			return sshCreateTunnel(f, fs, a, w)
 		},
 	})
 
@@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 		ShortDescription: "Query the lighthouses for the provided vpn ip",
 		Help:             "This command is asynchronous. Only currently known udp ips will be printed.",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshQueryLighthouse(ifce, fs, a, w)
+			return sshQueryLighthouse(f, fs, a, w)
 		},
 	})
 }
 
-func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error {
+func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
 		//TODO: error
@@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 	var hm []ControlHostInfo
 	if fs.ByIndex {
-		hm = listHostMapIndexes(hostMap)
+		hm = listHostMapIndexes(hl)
 	} else {
-		hm = listHostMapHosts(hostMap)
+		hm = listHostMapHosts(hl)
 	}
 
 	sort.Slice(hm, func(i, j int) bool {
@@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 
@@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 
-	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
+	hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
@@ -606,11 +607,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
+	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 	}
-	ifce.getOrHandshake(vpnIp)
 
 	return w.WriteLine("Created")
 }
@@ -645,8 +645,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 
@@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 		return nil
 	}
 
-	cert := ifce.certState.Load().certificate
+	cert := ifce.pki.GetCertState().Certificate
 	if len(a) > 0 {
 		parsedIp := net.ParseIP(a[0])
 		if parsedIp == nil {
@@ -765,8 +765,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
-		hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-		if err != nil {
+		hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+		if hostInfo == nil {
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		}
 
@@ -851,9 +851,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	for k, v := range relays {
 		ro := RelayOutput{NebulaIp: v.vpnIp}
 		co.Relays = append(co.Relays, &ro)
-		relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp)
-		if err != nil {
-			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err})
+		relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
+		if relayHI == nil {
+			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
 			continue
 		}
 		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
@@ -889,8 +889,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 				}
 			}
-			relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp)
-			if err == nil {
+			relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
+			if relayedHI != nil {
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 			}
 
@@ -925,8 +925,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 

+ 31 - 0
udp/conn.go

@@ -1,6 +1,7 @@
 package udp
 
 import (
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 )
@@ -18,3 +19,33 @@ type EncReader func(
 	q int,
 	localCache firewall.ConntrackCache,
 )
+
+type Conn interface {
+	Rebind() error
+	LocalAddr() (*Addr, error)
+	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
+	WriteTo(b []byte, addr *Addr) error
+	ReloadConfig(c *config.C)
+	Close() error
+}
+
+type NoopConn struct{}
+
+func (NoopConn) Rebind() error {
+	return nil
+}
+func (NoopConn) LocalAddr() (*Addr, error) {
+	return nil, nil
+}
+func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
+	return
+}
+func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
+	return nil
+}
+func (NoopConn) ReloadConfig(_ *config.C) {
+	return
+}
+func (NoopConn) Close() error {
+	return nil
+}

+ 6 - 1
udp/udp_android.go

@@ -8,9 +8,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {
@@ -34,6 +39,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *Conn) Rebind() error {
+func (u *GenericConn) Rebind() error {
 	return nil
 }

+ 47 - 0
udp/udp_bsd.go

@@ -0,0 +1,47 @@
+//go:build (openbsd || freebsd) && !e2e_testing
+// +build openbsd freebsd
+// +build !e2e_testing
+
+package udp
+
+// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
+
+import (
+	"fmt"
+	"net"
+	"syscall"
+
+	"github.com/sirupsen/logrus"
+	"golang.org/x/sys/unix"
+)
+
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
+func NewListenConfig(multi bool) net.ListenConfig {
+	return net.ListenConfig{
+		Control: func(network, address string, c syscall.RawConn) error {
+			if multi {
+				var controlErr error
+				err := c.Control(func(fd uintptr) {
+					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
+						controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
+						return
+					}
+				})
+				if err != nil {
+					return err
+				}
+				if controlErr != nil {
+					return controlErr
+				}
+			}
+			return nil
+		},
+	}
+}
+
+func (u *GenericConn) Rebind() error {
+	return nil
+}

+ 13 - 3
udp/udp_darwin.go

@@ -10,9 +10,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {
@@ -37,11 +42,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *Conn) Rebind() error {
-	file, err := u.File()
+func (u *GenericConn) Rebind() error {
+	rc, err := u.UDPConn.SyscallConn()
 	if err != nil {
 		return err
 	}
 
-	return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+	return rc.Control(func(fd uintptr) {
+		err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+		if err != nil {
+			u.l.WithError(err).Error("Failed to rebind udp socket")
+		}
+	})
 }

+ 14 - 12
udp/udp_generic.go

@@ -18,30 +18,32 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-type Conn struct {
+type GenericConn struct {
 	*net.UDPConn
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
+var _ Conn = &GenericConn{}
+
+func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	if err != nil {
 		return nil, err
 	}
 	if uc, ok := pc.(*net.UDPConn); ok {
-		return &Conn{UDPConn: uc, l: l}, nil
+		return &GenericConn{UDPConn: uc, l: l}, nil
 	}
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
 
-func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
-	_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
+func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
+	_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
 	return err
 }
 
-func (uc *Conn) LocalAddr() (*Addr, error) {
-	a := uc.UDPConn.LocalAddr()
+func (u *GenericConn) LocalAddr() (*Addr, error) {
+	a := u.UDPConn.LocalAddr()
 
 	switch v := a.(type) {
 	case *net.UDPAddr:
@@ -55,11 +57,11 @@ func (uc *Conn) LocalAddr() (*Addr, error) {
 	}
 }
 
-func (u *Conn) ReloadConfig(c *config.C) {
+func (u *GenericConn) ReloadConfig(c *config.C) {
 	// TODO
 }
 
-func NewUDPStatsEmitter(udpConns []*Conn) func() {
+func NewUDPStatsEmitter(udpConns []Conn) func() {
 	// No UDP stats for non-linux
 	return func() {}
 }
@@ -68,7 +70,7 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 	plaintext := make([]byte, MTU)
 	buffer := make([]byte, MTU)
 	h := &header.H{}
@@ -80,8 +82,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to read packets")
-			continue
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
 		}
 
 		udpAddr.IP = rua.IP

+ 25 - 20
udp/udp_linux.go

@@ -20,7 +20,7 @@ import (
 
 //TODO: make it support reload as best you can!
 
-type Conn struct {
+type StdConn struct {
 	sysFd int
 	l     *logrus.Logger
 	batch int
@@ -45,7 +45,7 @@ const (
 
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
 	syscall.ForkLock.RLock()
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	if err == nil {
@@ -77,30 +77,30 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 
-	return &Conn{sysFd: fd, l: l, batch: batch}, err
+	return &StdConn{sysFd: fd, l: l, batch: batch}, err
 }
 
-func (u *Conn) Rebind() error {
+func (u *StdConn) Rebind() error {
 	return nil
 }
 
-func (u *Conn) SetRecvBuffer(n int) error {
+func (u *StdConn) SetRecvBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
 }
 
-func (u *Conn) SetSendBuffer(n int) error {
+func (u *StdConn) SetSendBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
 }
 
-func (u *Conn) GetRecvBuffer() (int, error) {
+func (u *StdConn) GetRecvBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
 }
 
-func (u *Conn) GetSendBuffer() (int, error) {
+func (u *StdConn) GetSendBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 }
 
-func (u *Conn) LocalAddr() (*Addr, error) {
+func (u *StdConn) LocalAddr() (*Addr, error) {
 	sa, err := unix.Getsockname(u.sysFd)
 	if err != nil {
 		return nil, err
@@ -119,7 +119,7 @@ func (u *Conn) LocalAddr() (*Addr, error) {
 	return addr, nil
 }
 
-func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
@@ -137,8 +137,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 	for {
 		n, err := read(msgs)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to read packets")
-			continue
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
 		}
 
 		//metric.Update(int64(n))
@@ -150,7 +150,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 	}
 }
 
-func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
+func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
 	for {
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMSG,
@@ -171,7 +171,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
 	}
 }
 
-func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
+func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
 	for {
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMMSG,
@@ -191,7 +191,7 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
 	}
 }
 
-func (u *Conn) WriteTo(b []byte, addr *Addr) error {
+func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
 
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
@@ -221,7 +221,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
 	}
 }
 
-func (u *Conn) ReloadConfig(c *config.C) {
+func (u *StdConn) ReloadConfig(c *config.C) {
 	b := c.GetInt("listen.read_buffer", 0)
 	if b > 0 {
 		err := u.SetRecvBuffer(b)
@@ -253,7 +253,7 @@ func (u *Conn) ReloadConfig(c *config.C) {
 	}
 }
 
-func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
+func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	var vallen uint32 = 4 * _SK_MEMINFO_VARS
 	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
 	if err != 0 {
@@ -262,11 +262,16 @@ func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	return nil
 }
 
-func NewUDPStatsEmitter(udpConns []*Conn) func() {
+func (u *StdConn) Close() error {
+	//TODO: this will not interrupt the read loop
+	return syscall.Close(u.sysFd)
+}
+
+func NewUDPStatsEmitter(udpConns []Conn) func() {
 	// Check if our kernel supports SO_MEMINFO before registering the gauges
 	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
 	var meminfo _SK_MEMINFO
-	if err := udpConns[0].getMemInfo(&meminfo); err == nil {
+	if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
 		udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
 		for i := range udpConns {
 			udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
@@ -285,7 +290,7 @@ func NewUDPStatsEmitter(udpConns []*Conn) func() {
 
 	return func() {
 		for i, gauges := range udpGauges {
-			if err := udpConns[i].getMemInfo(&meminfo); err == nil {
+			if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
 				for j := 0; j < _SK_MEMINFO_VARS; j++ {
 					gauges[j].Update(int64(meminfo[j]))
 				}

+ 1 - 1
udp/udp_linux_32.go

@@ -30,7 +30,7 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)

+ 1 - 1
udp/udp_linux_64.go

@@ -33,7 +33,7 @@ type rawMessage struct {
 	Pad0 [4]byte
 }
 
-func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)

+ 6 - 1
udp/udp_freebsd.go → udp/udp_netbsd.go

@@ -10,9 +10,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {
@@ -36,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *Conn) Rebind() error {
+func (u *GenericConn) Rebind() error {
 	return nil
 }

+ 403 - 0
udp/udp_rio_windows.go

@@ -0,0 +1,403 @@
+//go:build !e2e_testing
+// +build !e2e_testing
+
+// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
+
+package udp
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"unsafe"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+// Assert we meet the standard conn interface
+var _ Conn = &RIOConn{}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+const (
+	packetsPerRing = 1024
+	bytesPerPacket = 2048 - 32
+	receiveSpins   = 15
+)
+
+type ringPacket struct {
+	addr windows.RawSockaddrInet6
+	data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+	packets    uintptr
+	head, tail uint32
+	id         winrio.BufferId
+	iocp       windows.Handle
+	isFull     bool
+	cq         winrio.Cq
+	mu         sync.Mutex
+	overlapped windows.Overlapped
+}
+
+type RIOConn struct {
+	isOpen  atomic.Bool
+	l       *logrus.Logger
+	sock    windows.Handle
+	rx, tx  ringBuffer
+	rq      winrio.Rq
+	results [packetsPerRing]winrio.Result
+}
+
+func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+	if !winrio.Initialize() {
+		return nil, errors.New("could not initialize winrio")
+	}
+
+	u := &RIOConn{l: l}
+
+	addr := [16]byte{}
+	copy(addr[:], ip.To16())
+	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+	if err != nil {
+		return nil, fmt.Errorf("bind: %w", err)
+	}
+
+	for i := 0; i < packetsPerRing; i++ {
+		err = u.insertReceiveRequest()
+		if err != nil {
+			return nil, fmt.Errorf("init rx ring: %w", err)
+		}
+	}
+
+	u.isOpen.Store(true)
+	return u, nil
+}
+
+func (u *RIOConn) bind(sa windows.Sockaddr) error {
+	var err error
+	u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+	if err != nil {
+		return err
+	}
+
+	// Enable v4 for this socket
+	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+
+	err = u.rx.Open()
+	if err != nil {
+		return err
+	}
+
+	err = u.tx.Open()
+	if err != nil {
+		return err
+	}
+
+	u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
+	if err != nil {
+		return err
+	}
+
+	err = windows.Bind(u.sock, sa)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{IP: make([]byte, 16)}
+	nb := make([]byte, 12, 12)
+
+	for {
+		// Just read one packet at a time
+		n, rua, err := u.receive(buffer)
+		if err != nil {
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
+		}
+
+		udpAddr.IP = rua.Addr[:]
+		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
+		p[0] = byte(rua.Port >> 8)
+		p[1] = byte(rua.Port)
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+	}
+}
+
+func (u *RIOConn) insertReceiveRequest() error {
+	packet := u.rx.Push()
+	dataBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
+		Length: uint32(len(packet.data)),
+	}
+	addressBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
+	if !u.isOpen.Load() {
+		return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+	}
+
+	u.rx.mu.Lock()
+	defer u.rx.mu.Unlock()
+
+	var err error
+	var count uint32
+	var results [1]winrio.Result
+
+retry:
+	count = 0
+	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+		if tries > 0 {
+			if !u.isOpen.Load() {
+				return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+			}
+			procyield(1)
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+	}
+
+	if count == 0 {
+		err = winrio.Notify(u.rx.cq)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+
+		if !u.isOpen.Load() {
+			return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+		if count == 0 {
+			return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
+
+		}
+	}
+
+	u.rx.Return(1)
+	err = u.insertReceiveRequest()
+	if err != nil {
+		return 0, windows.RawSockaddrInet6{}, err
+	}
+
+	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
+	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
+	// attacker bandwidth, just like the rest of the receive path.
+	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
+		goto retry
+	}
+
+	if results[0].Status != 0 {
+		return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
+	}
+
+	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+	ep := packet.addr
+	n := copy(buf, packet.data[:results[0].BytesTransferred])
+	return n, ep, nil
+}
+
+func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+	if !u.isOpen.Load() {
+		return net.ErrClosed
+	}
+
+	if len(buf) > bytesPerPacket {
+		return io.ErrShortBuffer
+	}
+
+	u.tx.mu.Lock()
+	defer u.tx.mu.Unlock()
+
+	count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
+	if count == 0 && u.tx.isFull {
+		err := winrio.Notify(u.tx.cq)
+		if err != nil {
+			return err
+		}
+
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return err
+		}
+
+		if !u.isOpen.Load() {
+			return net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
+		if count == 0 {
+			return io.ErrNoProgress
+		}
+	}
+
+	if count > 0 {
+		u.tx.Return(count)
+	}
+
+	packet := u.tx.Push()
+	packet.addr.Family = windows.AF_INET6
+	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
+	p[0] = byte(addr.Port >> 8)
+	p[1] = byte(addr.Port)
+	copy(packet.addr.Addr[:], addr.IP.To16())
+	copy(packet.data[:], buf)
+
+	dataBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
+		Length: uint32(len(buf)),
+	}
+
+	addressBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (u *RIOConn) LocalAddr() (*Addr, error) {
+	sa, err := windows.Getsockname(u.sock)
+	if err != nil {
+		return nil, err
+	}
+
+	v6 := sa.(*windows.SockaddrInet6)
+	return &Addr{
+		IP:   v6.Addr[:],
+		Port: uint16(v6.Port),
+	}, nil
+}
+
+func (u *RIOConn) Rebind() error {
+	return nil
+}
+
+func (u *RIOConn) ReloadConfig(*config.C) {}
+
+func (u *RIOConn) Close() error {
+	if !u.isOpen.CompareAndSwap(true, false) {
+		return nil
+	}
+
+	windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
+	windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
+
+	u.rx.CloseAndZero()
+	u.tx.CloseAndZero()
+	if u.sock != 0 {
+		windows.CloseHandle(u.sock)
+	}
+	return nil
+}
+
+func (ring *ringBuffer) Push() *ringPacket {
+	for ring.isFull {
+		panic("ring is full")
+	}
+	ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+	ring.tail += 1
+	if ring.tail%packetsPerRing == ring.head%packetsPerRing {
+		ring.isFull = true
+	}
+	return ret
+}
+
+func (ring *ringBuffer) Return(count uint32) {
+	if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
+		return
+	}
+	ring.head += count
+	ring.isFull = false
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+	if ring.cq != 0 {
+		winrio.CloseCompletionQueue(ring.cq)
+		ring.cq = 0
+	}
+
+	if ring.iocp != 0 {
+		windows.CloseHandle(ring.iocp)
+		ring.iocp = 0
+	}
+
+	if ring.id != 0 {
+		winrio.DeregisterBuffer(ring.id)
+		ring.id = 0
+	}
+
+	if ring.packets != 0 {
+		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+		ring.packets = 0
+	}
+
+	ring.head = 0
+	ring.tail = 0
+	ring.isFull = false
+}
+
+func (ring *ringBuffer) Open() error {
+	var err error
+	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+	if err != nil {
+		return err
+	}
+
+	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+	if err != nil {
+		return err
+	}
+
+	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+	if err != nil {
+		return err
+	}
+
+	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

+ 31 - 12
udp/udp_tester.go

@@ -5,7 +5,9 @@ package udp
 
 import (
 	"fmt"
+	"io"
 	"net"
+	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -36,17 +38,18 @@ func (u *Packet) Copy() *Packet {
 	return n
 }
 
-type Conn struct {
+type TesterConn struct {
 	Addr *Addr
 
 	RxPackets chan *Packet // Packets to receive into nebula
 	TxPackets chan *Packet // Packets transmitted outside by nebula
 
-	l *logrus.Logger
+	closed atomic.Bool
+	l      *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) {
-	return &Conn{
+func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
+	return &TesterConn{
 		Addr:      &Addr{ip, uint16(port)},
 		RxPackets: make(chan *Packet, 10),
 		TxPackets: make(chan *Packet, 10),
@@ -57,7 +60,11 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, e
 // Send will place a UdpPacket onto the receive queue for nebula to consume
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
-func (u *Conn) Send(packet *Packet) {
+func (u *TesterConn) Send(packet *Packet) {
+	if u.closed.Load() {
+		return
+	}
+
 	h := &header.H{}
 	if err := h.Parse(packet.Data); err != nil {
 		panic(err)
@@ -74,7 +81,7 @@ func (u *Conn) Send(packet *Packet) {
 // Get will pull a UdpPacket from the transmit queue
 // nebula meant to send this message on the network, it will be encrypted
 // packets were ingested from the tun side (in most cases), you can send them with Tun.Send
-func (u *Conn) Get(block bool) *Packet {
+func (u *TesterConn) Get(block bool) *Packet {
 	if block {
 		return <-u.TxPackets
 	}
@@ -91,7 +98,11 @@ func (u *Conn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (u *Conn) WriteTo(b []byte, addr *Addr) error {
+func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+	if u.closed.Load() {
+		return io.ErrClosedPipe
+	}
+
 	p := &Packet{
 		Data:     make([]byte, len(b), len(b)),
 		FromIp:   make([]byte, 16),
@@ -108,7 +119,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
 	return nil
 }
 
-func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
@@ -126,17 +137,25 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 	}
 }
 
-func (u *Conn) ReloadConfig(*config.C) {}
// ReloadConfig is a no-op: the tester connection has nothing to reconfigure.
func (u *TesterConn) ReloadConfig(*config.C) {}
 
-func NewUDPStatsEmitter(_ []*Conn) func() {
+func NewUDPStatsEmitter(_ []Conn) func() {
 	// No UDP stats for non-linux
 	return func() {}
 }
 
-func (u *Conn) LocalAddr() (*Addr, error) {
// LocalAddr returns the address this tester connection was created with.
func (u *TesterConn) LocalAddr() (*Addr, error) {
	return u.Addr, nil
}
 
-func (u *Conn) Rebind() error {
// Rebind is a no-op for the tester connection; there is no real socket to
// re-bind.
func (u *TesterConn) Rebind() error {
	return nil
}
+
// Close shuts down the tester connection. The CompareAndSwap ensures the
// channels are closed exactly once, on the first call; later calls are no-ops.
func (u *TesterConn) Close() error {
	if u.closed.CompareAndSwap(false, true) {
		close(u.RxPackets)
		close(u.TxPackets)
	}
	// NOTE(review): a concurrent Send/WriteTo that passed its closed.Load()
	// check just before this close could still send on a closed channel and
	// panic; presumably acceptable for a test harness — confirm.
	return nil
}

+ 20 - 3
udp/udp_windows.go

@@ -3,14 +3,31 @@
 
 package udp
 
-// Windows support is primarily implemented in udp_generic, besides NewListenConfig
-
 import (
 	"fmt"
 	"net"
 	"syscall"
+
+	"github.com/sirupsen/logrus"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	if multi {
+		//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
+		// The udp stack would need to be reworked to hide away the implementation differences between
+		// Windows and Linux
+		return nil, fmt.Errorf("multiple udp listeners not supported on windows")
+	}
+
+	rc, err := NewRIOListener(l, ip, port)
+	if err == nil {
+		return rc, nil
+	}
+
+	l.WithError(err).Error("Falling back to standard udp sockets")
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {
@@ -24,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *Conn) Rebind() error {
// Rebind is a no-op for the generic UDP connection on windows.
func (u *GenericConn) Rebind() error {
	return nil
}

+ 24 - 4
util/error.go

@@ -12,18 +12,38 @@ type ContextualError struct {
 	Context   string
 }
 
-func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
-	return ContextualError{Context: msg, Fields: fields, RealError: realError}
// NewContextualError returns a *ContextualError wrapping realError with a
// context message and optional structured fields.
func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
	return &ContextualError{Context: msg, Fields: fields, RealError: realError}
}
 
-func (ce ContextualError) Error() string {
+// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one
+func ContextualizeIfNeeded(msg string, err error) error {
+	switch err.(type) {
+	case *ContextualError:
+		return err
+	default:
+		return NewContextualError(msg, nil, err)
+	}
+}
+
+// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
+func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
+	switch v := err.(type) {
+	case *ContextualError:
+		v.Log(l)
+	default:
+		l.WithError(err).Error(msg)
+	}
+}
+
+func (ce *ContextualError) Error() string {
 	if ce.RealError == nil {
 		return ce.Context
 	}
 	return ce.RealError.Error()
 }
 
-func (ce ContextualError) Unwrap() error {
+func (ce *ContextualError) Unwrap() error {
 	if ce.RealError == nil {
 		return errors.New(ce.Context)
 	}

+ 42 - 0
util/error_test.go

@@ -2,6 +2,7 @@ package util
 
 import (
 	"errors"
+	"fmt"
 	"testing"
 
 	"github.com/sirupsen/logrus"
@@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) {
 	e.Log(l)
 	assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
 }
+
func TestLogWithContextIfNeeded(t *testing.T) {
	l := logrus.New()
	// Plain text, no timestamps/colors, so log lines can be compared exactly.
	l.Formatter = &logrus.TextFormatter{
		DisableTimestamp: true,
		DisableColors:    true,
	}

	// Capture log output for assertions.
	tl := NewTestLogWriter()
	l.Out = tl

	// Test ignoring fallback context
	tl.Reset()
	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
	LogWithContextIfNeeded("This should get thrown away", e, l)
	assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)

	// Test using fallback context
	tl.Reset()
	err := fmt.Errorf("this is a normal error")
	LogWithContextIfNeeded("Fallback context woo", err, l)
	assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
}
+
+func TestContextualizeIfNeeded(t *testing.T) {
+	// Test ignoring fallback context
+	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
+	assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e))
+
+	// Test using fallback context
+	err := fmt.Errorf("this is a normal error")
+	cErr := ContextualizeIfNeeded("Fallback context woo", err)
+
+	switch v := cErr.(type) {
+	case *ContextualError:
+		assert.Equal(t, err, v.RealError)
+	default:
+		t.Error("Error was not wrapped")
+		t.Fail()
+	}
+}