Merge remote-tracking branch 'origin/master' into mutex-debug

Wade Simmons, 1 year ago
commit fdb78044ba
80 files changed, 2034 insertions(+), 1127 deletions(-)
  1. .github/workflows/gofmt.yml (+2 -2)
  2. .github/workflows/release.yml (+7 -7)
  3. .github/workflows/smoke.yml (+3 -3)
  4. .github/workflows/smoke/build-relay.sh (+1 -1)
  5. .github/workflows/smoke/build.sh (+1 -1)
  6. .github/workflows/smoke/smoke-relay.sh (+25 -25)
  7. .github/workflows/smoke/smoke.sh (+46 -46)
  8. .github/workflows/test.yml (+6 -6)
  9. CHANGELOG.md (+49 -1)
  10. Makefile (+1 -0)
  11. allow_list.go (+14 -29)
  12. allow_list_test.go (+1 -1)
  13. calculated_remote.go (+2 -2)
  14. cert/cert.go (+3 -0)
  15. cert/crypto.go (+3 -0)
  16. cidr/tree4.go (+34 -28)
  17. cidr/tree4_test.go (+71 -47)
  18. cidr/tree6.go (+24 -20)
  19. cidr/tree6_test.go (+45 -28)
  20. cmd/nebula-cert/ca.go (+7 -8)
  21. cmd/nebula-cert/ca_test.go (+5 -6)
  22. cmd/nebula-cert/keygen.go (+2 -3)
  23. cmd/nebula-cert/keygen_test.go (+4 -5)
  24. cmd/nebula-cert/print.go (+2 -3)
  25. cmd/nebula-cert/print_test.go (+1 -2)
  26. cmd/nebula-cert/sign.go (+6 -7)
  27. cmd/nebula-cert/sign_test.go (+11 -12)
  28. cmd/nebula-cert/verify.go (+2 -3)
  29. cmd/nebula-cert/verify_test.go (+2 -3)
  30. config/config.go (+5 -2)
  31. config/config_test.go (+7 -8)
  32. connection_manager.go (+6 -19)
  33. connection_manager_test.go (+21 -20)
  34. connection_state.go (+11 -14)
  35. control.go (+10 -2)
  36. control_test.go (+1 -2)
  37. control_tester.go (+1 -11)
  38. e2e/handshakes_test.go (+120 -16)
  39. e2e/helpers.go (+118 -0)
  40. e2e/helpers_test.go (+1 -110)
  41. examples/config.yml (+3 -2)
  42. examples/go_service/main.go (+100 -0)
  43. firewall.go (+36 -14)
  44. firewall_test.go (+8 -4)
  45. go.mod (+16 -13)
  46. go.sum (+35 -28)
  47. handshake.go (+0 -31)
  48. handshake_ix.go (+63 -80)
  49. handshake_manager.go (+241 -125)
  50. handshake_manager_test.go (+16 -18)
  51. hostmap.go (+18 -86)
  52. inside.go (+41 -87)
  53. interface.go (+17 -7)
  54. iputil/packet.go (+34 -7)
  55. iputil/packet_test.go (+73 -0)
  56. lighthouse.go (+4 -5)
  57. main.go (+16 -6)
  58. mutex_debug.go (+22 -6)
  59. outside.go (+4 -5)
  60. overlay/route.go (+2 -2)
  61. overlay/route_test.go (+8 -10)
  62. overlay/tun.go (+24 -8)
  63. overlay/tun_android.go (+3 -7)
  64. overlay/tun_darwin.go (+4 -4)
  65. overlay/tun_freebsd.go (+3 -7)
  66. overlay/tun_ios.go (+3 -7)
  67. overlay/tun_linux.go (+5 -9)
  68. overlay/tun_netbsd.go (+3 -7)
  69. overlay/tun_openbsd.go (+3 -7)
  70. overlay/tun_tester.go (+3 -7)
  71. overlay/tun_water_windows.go (+3 -7)
  72. overlay/tun_wintun_windows.go (+12 -9)
  73. overlay/user.go (+63 -0)
  74. relay_manager.go (+7 -1)
  75. service/listener.go (+36 -0)
  76. service/service.go (+248 -0)
  77. service/service_test.go (+165 -0)
  78. ssh.go (+2 -4)
  79. test/logger.go (+2 -2)
  80. udp/udp_darwin.go (+7 -2)

+ 2 - 2
.github/workflows/gofmt.yml

@@ -14,9 +14,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true

+ 7 - 7
.github/workflows/release.yml

@@ -10,9 +10,9 @@ jobs:
     name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true
@@ -33,9 +33,9 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true
@@ -66,9 +66,9 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-11
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true
@@ -114,7 +114,7 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Download artifacts
         uses: actions/download-artifact@v3

+ 3 - 3
.github/workflows/smoke.yml

@@ -18,15 +18,15 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
 
     - name: build
-      run: make bin-docker
+      run: make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race
 
     - name: setup docker image
       working-directory: ./.github/workflows/smoke

+ 1 - 1
.github/workflows/smoke/build-relay.sh

@@ -41,4 +41,4 @@ EOF
     ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
 )
 
-sudo docker build -t nebula:smoke-relay .
+docker build -t nebula:smoke-relay .

+ 1 - 1
.github/workflows/smoke/build.sh

@@ -36,4 +36,4 @@ mkdir ./build
     ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24"
 )
 
-sudo docker build -t "nebula:${NAME:-smoke}" .
+docker build -t "nebula:${NAME:-smoke}" .

+ 25 - 25
.github/workflows/smoke/smoke-relay.sh

@@ -14,24 +14,24 @@ cleanup() {
     set +e
     if [ "$(jobs -r)" ]
     then
-        sudo docker kill lighthouse1 host2 host3 host4
+        docker kill lighthouse1 host2 host3 host4
     fi
 }
 
 trap cleanup EXIT
 
-sudo docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
-sudo docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
-sudo docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
-sudo docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
+docker run --name lighthouse1 --rm nebula:smoke-relay -config lighthouse1.yml -test
+docker run --name host2 --rm nebula:smoke-relay -config host2.yml -test
+docker run --name host3 --rm nebula:smoke-relay -config host3.yml -test
+docker run --name host4 --rm nebula:smoke-relay -config host4.yml -test
 
-sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
+docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
-sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
+docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
+docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 1
-sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
+docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
 sleep 1
 
 set +x
@@ -39,43 +39,43 @@ echo
 echo " *** Testing ping from lighthouse1"
 echo
 set -x
-sudo docker exec lighthouse1 ping -c1 192.168.100.2
-sudo docker exec lighthouse1 ping -c1 192.168.100.3
-sudo docker exec lighthouse1 ping -c1 192.168.100.4
+docker exec lighthouse1 ping -c1 192.168.100.2
+docker exec lighthouse1 ping -c1 192.168.100.3
+docker exec lighthouse1 ping -c1 192.168.100.4
 
 set +x
 echo
 echo " *** Testing ping from host2"
 echo
 set -x
-sudo docker exec host2 ping -c1 192.168.100.1
+docker exec host2 ping -c1 192.168.100.1
 # Should fail because no relay configured in this direction
-! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
-! sudo docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1
 
 set +x
 echo
 echo " *** Testing ping from host3"
 echo
 set -x
-sudo docker exec host3 ping -c1 192.168.100.1
-sudo docker exec host3 ping -c1 192.168.100.2
-sudo docker exec host3 ping -c1 192.168.100.4
+docker exec host3 ping -c1 192.168.100.1
+docker exec host3 ping -c1 192.168.100.2
+docker exec host3 ping -c1 192.168.100.4
 
 set +x
 echo
 echo " *** Testing ping from host4"
 echo
 set -x
-sudo docker exec host4 ping -c1 192.168.100.1
+docker exec host4 ping -c1 192.168.100.1
 # Should fail because relays not allowed
-! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
-sudo docker exec host4 ping -c1 192.168.100.3
+! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
+docker exec host4 ping -c1 192.168.100.3
 
-sudo docker exec host4 sh -c 'kill 1'
-sudo docker exec host3 sh -c 'kill 1'
-sudo docker exec host2 sh -c 'kill 1'
-sudo docker exec lighthouse1 sh -c 'kill 1'
+docker exec host4 sh -c 'kill 1'
+docker exec host3 sh -c 'kill 1'
+docker exec host2 sh -c 'kill 1'
+docker exec lighthouse1 sh -c 'kill 1'
 sleep 1
 
 if [ "$(jobs -r)" ]

+ 46 - 46
.github/workflows/smoke/smoke.sh

@@ -14,7 +14,7 @@ cleanup() {
     set +e
     if [ "$(jobs -r)" ]
     then
-        sudo docker kill lighthouse1 host2 host3 host4
+        docker kill lighthouse1 host2 host3 host4
     fi
 }
 
@@ -22,51 +22,51 @@ trap cleanup EXIT
 
 CONTAINER="nebula:${NAME:-smoke}"
 
-sudo docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
-sudo docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
-sudo docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
-sudo docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
+docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
+docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
+docker run --name host3 --rm "$CONTAINER" -config host3.yml -test
+docker run --name host4 --rm "$CONTAINER" -config host4.yml -test
 
-sudo docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
+docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
-sudo docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
+docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-sudo docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
+docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 1
-sudo docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
+docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/  [host4]  /' &
 sleep 1
 
 # grab tcpdump pcaps for debugging
-sudo docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
-sudo docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
-sudo docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
-sudo docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
-sudo docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
-sudo docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
-sudo docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
-sudo docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
-
-sudo docker exec host2 ncat -nklv 0.0.0.0 2000 &
-sudo docker exec host3 ncat -nklv 0.0.0.0 2000 &
-sudo docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
-sudo docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
+docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
+docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
+docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
+docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
+docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
+docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
+docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
+docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
+
+docker exec host2 ncat -nklv 0.0.0.0 2000 &
+docker exec host3 ncat -nklv 0.0.0.0 2000 &
+docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
+docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
 
 set +x
 echo
 echo " *** Testing ping from lighthouse1"
 echo
 set -x
-sudo docker exec lighthouse1 ping -c1 192.168.100.2
-sudo docker exec lighthouse1 ping -c1 192.168.100.3
+docker exec lighthouse1 ping -c1 192.168.100.2
+docker exec lighthouse1 ping -c1 192.168.100.3
 
 set +x
 echo
 echo " *** Testing ping from host2"
 echo
 set -x
-sudo docker exec host2 ping -c1 192.168.100.1
+docker exec host2 ping -c1 192.168.100.1
 # Should fail because not allowed by host3 inbound firewall
-! sudo docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
 set +x
 echo
@@ -74,34 +74,34 @@ echo " *** Testing ncat from host2"
 echo
 set -x
 # Should fail because not allowed by host3 inbound firewall
-! sudo docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
-! sudo docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
+! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
+! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
 
 set +x
 echo
 echo " *** Testing ping from host3"
 echo
 set -x
-sudo docker exec host3 ping -c1 192.168.100.1
-sudo docker exec host3 ping -c1 192.168.100.2
+docker exec host3 ping -c1 192.168.100.1
+docker exec host3 ping -c1 192.168.100.2
 
 set +x
 echo
 echo " *** Testing ncat from host3"
 echo
 set -x
-sudo docker exec host3 ncat -nzv -w5 192.168.100.2 2000
-sudo docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
+docker exec host3 ncat -nzv -w5 192.168.100.2 2000
+docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2
 
 set +x
 echo
 echo " *** Testing ping from host4"
 echo
 set -x
-sudo docker exec host4 ping -c1 192.168.100.1
+docker exec host4 ping -c1 192.168.100.1
 # Should fail because not allowed by host4 outbound firewall
-! sudo docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
-! sudo docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
+! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1
+! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1
 
 set +x
 echo
@@ -109,10 +109,10 @@ echo " *** Testing ncat from host4"
 echo
 set -x
 # Should fail because not allowed by host4 outbound firewall
-! sudo docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
-! sudo docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
-! sudo docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
-! sudo docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
+! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1
+! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1
+! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1
+! docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
 
 set +x
 echo
@@ -120,15 +120,15 @@ echo " *** Testing conntrack"
 echo
 set -x
 # host2 can ping host3 now that host3 pinged it first
-sudo docker exec host2 ping -c1 192.168.100.3
+docker exec host2 ping -c1 192.168.100.3
 # host4 can ping host2 once conntrack established
-sudo docker exec host2 ping -c1 192.168.100.4
-sudo docker exec host4 ping -c1 192.168.100.2
+docker exec host2 ping -c1 192.168.100.4
+docker exec host4 ping -c1 192.168.100.2
 
-sudo docker exec host4 sh -c 'kill 1'
-sudo docker exec host3 sh -c 'kill 1'
-sudo docker exec host2 sh -c 'kill 1'
-sudo docker exec lighthouse1 sh -c 'kill 1'
+docker exec host4 sh -c 'kill 1'
+docker exec host3 sh -c 'kill 1'
+docker exec host2 sh -c 'kill 1'
+docker exec lighthouse1 sh -c 'kill 1'
 sleep 1
 
 if [ "$(jobs -r)" ]

+ 6 - 6
.github/workflows/test.yml

@@ -18,9 +18,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
@@ -48,9 +48,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
@@ -72,9 +72,9 @@ jobs:
         os: [windows-latest, macos-11]
     steps:
 
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true

+ 49 - 1
CHANGELOG.md

@@ -7,6 +7,53 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.8.0] - 2023-12-06
+
+### Deprecated
+
+- The next minor release of Nebula, 1.9.0, will require at least Windows 10 or
+  Windows Server 2016. This is because support for earlier versions was removed
+  in Go 1.21. See https://go.dev/doc/go1.21#windows
+
+### Added
+
+- Linux: Notify systemd of service readiness. This should resolve timing issues
+  with services that depend on Nebula being active. For an example of how to
+  enable this, see: `examples/service_scripts/nebula.service`. (#929)
+
+- Windows: Use Registered IO (RIO) when possible. Testing on a Windows 11
+  machine shows ~50x improvement in throughput. (#905)
+
+- NetBSD, OpenBSD: Added rudimentary support. (#916, #812)
+
+- FreeBSD: Add support for naming tun devices. (#903)
+
+### Changed
+
+- `pki.disconnect_invalid` will now default to true. This means that once a
+  certificate expires, the tunnel will be disconnected. If you use SIGHUP to
+  reload certificates without restarting Nebula, you should ensure all of your
+  clients are on 1.7.0 or newer before you enable this feature. (#859)
+
+- Limit how often a busy tunnel can requery the lighthouse. The new config
+  option `timers.requery_wait_duration` defaults to `60s`. (#940)
+
+- The internal structures for hostmaps were refactored to reduce memory usage
+  and the potential for subtle bugs. (#843, #938, #953, #954, #955)
+
+- Lots of dependency updates.
+
+### Fixed
+
+- Windows: Retry wintun device creation if it fails the first time. (#985)
+
+- Fix issues with firewall reject packets that could cause panics. (#957)
+
+- Fix relay migration during re-handshakes. (#964)
+
+- Various other refactors and fixes. (#935, #952, #972, #961, #996, #1002,
+  #987, #1004, #1030, #1032, ...)
+
 ## [1.7.2] - 2023-06-01
 
 ### Fixed
@@ -488,7 +535,8 @@ created.)
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.0...HEAD
+[1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0
 [1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2
 [1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1
 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0
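
Two of the behavior changes above are config-driven. A minimal sketch of the relevant YAML, using only the key names given in the changelog entries (the values shown are the new 1.8.0 defaults, so these stanzas are needed only to override them):

    pki:
      # Defaults to true as of 1.8.0: tunnels are disconnected once a
      # certificate expires or is otherwise invalid. See the caveat above
      # about 1.7.0+ clients before relying on SIGHUP cert reloads.
      disconnect_invalid: true

    timers:
      # Minimum wait before a busy tunnel may requery the lighthouse.
      requery_wait_duration: 60s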

+ 1 - 0
Makefile

@@ -213,6 +213,7 @@ smoke-relay-docker: bin-docker
 	cd .github/workflows/smoke/ && ./smoke-relay.sh
 
 smoke-docker-race: BUILD_ARGS = -race
+smoke-docker-race: CGO_ENABLED = 1
 smoke-docker-race: smoke-docker
 
 .FORCE:
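
The added CGO_ENABLED line matters because Go's race detector is built on cgo: with cgo disabled, as it typically is for static Docker builds, a -race build fails. Equivalent invocations, using the target defined here and the command the smoke.yml workflow above now runs:

    # race-instrumented smoke test via the dedicated target
    make smoke-docker-race

    # or directly, as .github/workflows/smoke.yml now does
    make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race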

+ 14 - 29
allow_list.go

@@ -12,7 +12,7 @@ import (
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6
+	cidrTree *cidr.Tree6[bool]
 }
 
 type RemoteAllowList struct {
@@ -20,7 +20,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *cidr.Tree6
+	insideAllowLists *cidr.Tree6[*AllowList]
 }
 
 type LocalAllowList struct {
@@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
@@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 }
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
+func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	remoteAllowRanges := cidr.NewTree6()
+	remoteAllowRanges := cidr.NewTree6[*AllowList]()
 
 	rawMap, ok := value.(map[interface{}]interface{})
 	if !ok {
@@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContains(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContains(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
@@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
@@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+	return result
 }
 
 func (al *LocalAllowList) Allow(ip net.IP) bool {
@@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
 
 func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
 	if al.insideAllowLists != nil {
-		inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
-		if inside != nil {
-			return inside.(*AllowList)
+		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		if ok {
+			return inside
 		}
 	}
 	return nil

+ 1 - 1
allow_list_test.go

@@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 func TestAllowList_Allow(t *testing.T) {
 	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
 	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
 	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)

+ 2 - 2
calculated_remote.go

@@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
 	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
 }
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	calculatedRemotes := cidr.NewTree4()
+	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
 
 	rawMap, ok := value.(map[any]any)
 	if !ok {

+ 3 - 0
cert/cert.go

@@ -272,6 +272,9 @@ func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte
 		},
 		Ciphertext: ciphertext,
 	})
+	if err != nil {
+		return nil, err
+	}
 
 	switch curve {
 	case Curve_CURVE25519:

+ 3 - 0
cert/crypto.go

@@ -77,6 +77,9 @@ func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte)
 	}
 
 	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
 
 	nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize())
 	if err != nil {
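
Both cert fixes in this commit are the same class of bug: an error result silently discarded. cipher.NewGCM in particular can fail (for example, when handed a block cipher with a nonstandard block size), leaving gcm nil if the check is skipped. A standalone sketch of the fully checked decrypt path, standard library only and not this package's exact helper:

    package main

    import (
        "crypto/aes"
        "crypto/cipher"
        "fmt"
    )

    func decrypt(key, nonce, ciphertext []byte) ([]byte, error) {
        block, err := aes.NewCipher(key) // fails on a bad key length
        if err != nil {
            return nil, err
        }
        gcm, err := cipher.NewGCM(block) // the check this commit adds
        if err != nil {
            return nil, err
        }
        // fails on tampered or truncated input
        return gcm.Open(nil, nonce, ciphertext, nil)
    }

    func main() {
        fmt.Println(decrypt(make([]byte, 32), make([]byte, 12), nil))
    }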

+ 34 - 28
cidr/tree4.go

@@ -6,35 +6,36 @@ import (
 	"github.com/slackhq/nebula/iputil"
 )
 
-type Node struct {
-	left   *Node
-	right  *Node
-	parent *Node
-	value  interface{}
+type Node[T any] struct {
+	left     *Node[T]
+	right    *Node[T]
+	parent   *Node[T]
+	hasValue bool
+	value    T
 }
 
-type entry struct {
+type entry[T any] struct {
 	CIDR  *net.IPNet
-	Value *interface{}
+	Value T
 }
 
-type Tree4 struct {
-	root *Node
-	list []entry
+type Tree4[T any] struct {
+	root *Node[T]
+	list []entry[T]
 }
 
 const (
 	startbit = iputil.VpnIp(0x80000000)
 )
 
-func NewTree4() *Tree4 {
-	tree := new(Tree4)
-	tree.root = &Node{}
-	tree.list = []entry{}
+func NewTree4[T any]() *Tree4[T] {
+	tree := new(Tree4[T])
+	tree.root = &Node[T]{}
+	tree.list = []entry[T]{}
 	return tree
 }
 
-func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
+func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
 	bit := startbit
 	node := tree.root
 	next := tree.root
@@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 			}
 		}
 
-		tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 		node.value = val
+		node.hasValue = true
 		return
 	}
 
 	// Build up the rest of the tree we don't already have
 	for bit&mask != 0 {
-		next = &Node{}
+		next = &Node[T]{}
 		next.parent = node
 
 		if ip&bit != 0 {
@@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
-	tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+	node.hasValue = true
+	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 }
 
 // Contains finds the first match, which may be the least specific
-func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
-			return node.value
+		if node.hasValue {
+			return true, node.value
 		}
 
 		if ip&bit != 0 {
@@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 
 	}
 
-	return value
+	return false, value
 }
 
 // MostSpecificContains finds the most specific match
-func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
 // Match finds the most specific match
-func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
+// TODO this is exact match
+func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 	lastNode := node
@@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 
 	if bit == 0 && lastNode != nil {
 		value = lastNode.value
+		ok = true
 	}
-	return value
+	return ok, value
 }
 
 // List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4) List() []entry {
+func (tree *Tree4[T]) List() []entry[T] {
 	return tree.list
 }
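
Tree4 is now generic over its value type, and lookups return an explicit found flag instead of a nil-checked interface{}; this is what lets callers such as allow_list.go drop their type switches. A minimal usage sketch against the new signatures (cidr.Parse and iputil.Ip2VpnIp as used in the tests below):

    package main

    import (
        "fmt"
        "net"

        "github.com/slackhq/nebula/cidr"
        "github.com/slackhq/nebula/iputil"
    )

    func main() {
        // Values are plain strings; no casting back from interface{}.
        tree := cidr.NewTree4[string]()
        tree.AddCIDR(cidr.Parse("10.0.0.0/8"), "private")
        tree.AddCIDR(cidr.Parse("10.42.0.0/16"), "lab")

        // The comma-ok form distinguishes "no match" from a zero value.
        ok, v := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("10.42.1.1")))
        fmt.Println(ok, v) // true lab
    }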

+ 71 - 47
cidr/tree4_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
 	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
@@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) {
 	list := tree.List()
 	assert.Len(t, list, 2)
 	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", *list[0].Value)
+	assert.Equal(t, "2", list[0].Value)
 	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", *list[1].Value)
+	assert.Equal(t, "4", list[1].Value)
 }
 
 func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4a", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4a", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_Match(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
 	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1a", "4.1.1.0"},
-		{"1b", "4.1.1.1"},
+		{true, "1a", "4.1.1.0"},
+		{true, "1b", "4.1.1.1"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
@@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
 }
 
 func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")

+ 24 - 20
cidr/tree6.go

@@ -8,20 +8,20 @@ import (
 
 const startbit6 = uint64(1 << 63)
 
-type Tree6 struct {
-	root4 *Node
-	root6 *Node
+type Tree6[T any] struct {
+	root4 *Node[T]
+	root6 *Node[T]
 }
 
-func NewTree6() *Tree6 {
-	tree := new(Tree6)
-	tree.root4 = &Node{}
-	tree.root6 = &Node{}
+func NewTree6[T any]() *Tree6[T] {
+	tree := new(Tree6[T])
+	tree.root4 = &Node[T]{}
+	tree.root6 = &Node[T]{}
 	return tree
 }
 
-func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
-	var node, next *Node
+func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
+	var node, next *Node[T]
 
 	cidrIP, ipv4 := isIPV4(cidr.IP)
 	if ipv4 {
@@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 		// Build up the rest of the tree we don't already have
 		for bit&mask != 0 {
-			next = &Node{}
+			next = &Node[T]{}
 			next.parent = node
 
 			if ip&bit != 0 {
@@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
+	node.hasValue = true
 }
 
 // Finds the most specific match
-func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
-	var node *Node
+func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
+	var node *Node[T]
 
 	wholeIP, ipv4 := isIPV4(ip)
 	if ipv4 {
@@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		bit := startbit
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		}
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root4
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{})
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
 	ip := hi
 	node := tree.root6
 
@@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		bit := startbit6
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		ip = lo
 	}
 
-	return value
+	return ok, value
 }
 
 func isIPV4(ip net.IP) (net.IP, bool) {

+ 45 - 28
cidr/tree6_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
+		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree6()
+	tree = NewTree6[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
 	tree.AddCIDR(Parse("::/0"), "cool6")
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")))
+	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
 }
 
 func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
 	}
 
 	for _, tt := range tests {
@@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
 		hi := binary.BigEndian.Uint64(ip[:8])
 		lo := binary.BigEndian.Uint64(ip[8:])
 
-		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
+		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 }

+ 7 - 8
cmd/nebula-cert/ca.go

@@ -7,7 +7,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"math"
 	"net"
 	"os"
@@ -213,27 +212,27 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return fmt.Errorf("error while signing: %s", err)
 	}
 
+	var b []byte
 	if *cf.encryption {
-		b, err := cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
 		if err != nil {
 			return fmt.Errorf("error while encrypting out-key: %s", err)
 		}
-
-		err = ioutil.WriteFile(*cf.outKeyPath, b, 0600)
 	} else {
-		err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalSigningPrivateKey(curve, rawPriv), 0600)
+		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
 	}
 
+	err = os.WriteFile(*cf.outKeyPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}
 
-	b, err := nc.MarshalToPEM()
+	b, err = nc.MarshalToPEM()
 	if err != nil {
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outCertPath, b, 0600)
+	err = os.WriteFile(*cf.outCertPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-crt: %s", err)
 	}
@@ -244,7 +243,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*cf.outQRPath, b, 0600)
+		err = os.WriteFile(*cf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}
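
The restructuring here is not just a swap of ioutil for os: previously `b, err := ...` inside the if-block declared a new b scoped to that branch, forcing a separate write call in each arm. Hoisting `var b []byte` lets both branches assign the same variable and share one os.WriteFile. A minimal sketch of the pattern, with encryptKey and marshalKey as hypothetical stand-ins for the cert package calls:

    package main

    import "os"

    // Hypothetical stand-ins for cert.EncryptAndMarshalSigningPrivateKey
    // and cert.MarshalSigningPrivateKey.
    func encryptKey(raw []byte) ([]byte, error) { return raw, nil }
    func marshalKey(raw []byte) []byte          { return raw }

    func writeKey(path string, raw []byte, encrypted bool) error {
        var b []byte
        var err error
        if encrypted {
            // Assigns the outer b; `b, err := ...` here would shadow it.
            b, err = encryptKey(raw)
            if err != nil {
                return err
            }
        } else {
            b = marshalKey(raw)
        }
        // A single write path for both branches.
        return os.WriteFile(path, b, 0600)
    }

    func main() { _ = writeKey("out.key", []byte("k"), false) }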

+ 5 - 6
cmd/nebula-cert/ca_test.go

@@ -7,7 +7,6 @@ import (
 	"bytes"
 	"encoding/pem"
 	"errors"
-	"io/ioutil"
 	"os"
 	"strings"
 	"testing"
@@ -107,7 +106,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	os.Remove(keyF.Name())
 
@@ -120,7 +119,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp cert file
-	crtF, err := ioutil.TempFile("", "test.crt")
+	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
 	os.Remove(crtF.Name())
 	os.Remove(keyF.Name())
@@ -134,13 +133,13 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 64)
 
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -166,7 +165,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read encrypted key file and verify default params
-	rb, _ = ioutil.ReadFile(keyF.Name())
+	rb, _ = os.ReadFile(keyF.Name())
 	k, _ := pem.Decode(rb)
 	ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
 	assert.Nil(t, err)

+ 2 - 3
cmd/nebula-cert/keygen.go

@@ -4,7 +4,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 
 	"github.com/slackhq/nebula/cert"
@@ -54,12 +53,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("invalid curve: %s", *cf.curve)
 	}
 
-	err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
+	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}

+ 4 - 5
cmd/nebula-cert/keygen_test.go

@@ -2,7 +2,6 @@ package main
 
 import (
 	"bytes"
-	"io/ioutil"
 	"os"
 	"testing"
 
@@ -54,7 +53,7 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	defer os.Remove(keyF.Name())
 
@@ -67,7 +66,7 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp pub file
-	pubF, err := ioutil.TempFile("", "test.pub")
+	pubF, err := os.CreateTemp("", "test.pub")
 	assert.Nil(t, err)
 	defer os.Remove(pubF.Name())
 
@@ -80,13 +79,13 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
-	rb, _ = ioutil.ReadFile(pubF.Name())
+	rb, _ = os.ReadFile(pubF.Name())
 	lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)

+ 2 - 3
cmd/nebula-cert/print.go

@@ -5,7 +5,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"strings"
 
@@ -41,7 +40,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	rawCert, err := ioutil.ReadFile(*pf.path)
+	rawCert, err := os.ReadFile(*pf.path)
 	if err != nil {
 		return fmt.Errorf("unable to read cert; %s", err)
 	}
@@ -87,7 +86,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*pf.outQRPath, b, 0600)
+		err = os.WriteFile(*pf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}

+ 1 - 2
cmd/nebula-cert/print_test.go

@@ -2,7 +2,6 @@ package main
 
 import (
 	"bytes"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -54,7 +53,7 @@ func Test_printCert(t *testing.T) {
 	// invalid cert at path
 	ob.Reset()
 	eb.Reset()
-	tf, err := ioutil.TempFile("", "print-cert")
+	tf, err := os.CreateTemp("", "print-cert")
 	assert.Nil(t, err)
 	defer os.Remove(tf.Name())
 

+ 6 - 7
cmd/nebula-cert/sign.go

@@ -6,7 +6,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net"
 	"os"
 	"strings"
@@ -73,7 +72,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
-	rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath)
+	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca-key: %s", err)
 	}
@@ -112,7 +111,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 
-	rawCACert, err := ioutil.ReadFile(*sf.caCertPath)
+	rawCACert, err := os.ReadFile(*sf.caCertPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 	}
@@ -178,7 +177,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 	var pub, rawPriv []byte
 	if *sf.inPubPath != "" {
-		rawPub, err := ioutil.ReadFile(*sf.inPubPath)
+		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
@@ -235,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 
-		err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
@@ -246,7 +245,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}
 
-	err = ioutil.WriteFile(*sf.outCertPath, b, 0600)
+	err = os.WriteFile(*sf.outCertPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-crt: %s", err)
 	}
@@ -257,7 +256,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*sf.outQRPath, b, 0600)
+		err = os.WriteFile(*sf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}

+ 11 - 12
cmd/nebula-cert/sign_test.go

@@ -7,7 +7,6 @@ import (
 	"bytes"
 	"crypto/rand"
 	"errors"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -104,7 +103,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal key
 	ob.Reset()
 	eb.Reset()
-	caKeyF, err := ioutil.TempFile("", "sign-cert.key")
+	caKeyF, err := os.CreateTemp("", "sign-cert.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
@@ -128,7 +127,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal cert
 	ob.Reset()
 	eb.Reset()
-	caCrtF, err := ioutil.TempFile("", "sign-cert.crt")
+	caCrtF, err := os.CreateTemp("", "sign-cert.crt")
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 
@@ -159,7 +158,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal pub
 	ob.Reset()
 	eb.Reset()
-	inPubF, err := ioutil.TempFile("", "in.pub")
+	inPubF, err := os.CreateTemp("", "in.pub")
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 
@@ -206,7 +205,7 @@ func Test_signCert(t *testing.T) {
 
 	// mismatched ca key
 	_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
-	caKeyF2, err := ioutil.TempFile("", "sign-cert-2.key")
+	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF2.Name())
 	caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
@@ -227,7 +226,7 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	os.Remove(keyF.Name())
 
@@ -241,7 +240,7 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 
 	// create temp cert file
-	crtF, err := ioutil.TempFile("", "test.crt")
+	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
 	os.Remove(crtF.Name())
 
@@ -254,13 +253,13 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -296,7 +295,7 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// read cert file and check pub key matches in-pub
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -348,11 +347,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	caKeyF, err = ioutil.TempFile("", "sign-cert.key")
+	caKeyF, err = os.CreateTemp("", "sign-cert.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
-	caCrtF, err = ioutil.TempFile("", "sign-cert.crt")
+	caCrtF, err = os.CreateTemp("", "sign-cert.crt")
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 

+ 2 - 3
cmd/nebula-cert/verify.go

@@ -4,7 +4,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"strings"
 	"time"
@@ -40,7 +39,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	rawCACert, err := ioutil.ReadFile(*vf.caPath)
+	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca: %s", err)
 	}
@@ -57,7 +56,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 		}
 	}
 
-	rawCert, err := ioutil.ReadFile(*vf.certPath)
+	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
 		return fmt.Errorf("unable to read crt; %s", err)
 	}

+ 2 - 3
cmd/nebula-cert/verify_test.go

@@ -3,7 +3,6 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -56,7 +55,7 @@ func Test_verify(t *testing.T) {
 	// invalid ca at path
 	ob.Reset()
 	eb.Reset()
-	caFile, err := ioutil.TempFile("", "verify-ca")
+	caFile, err := os.CreateTemp("", "verify-ca")
 	assert.Nil(t, err)
 	defer os.Remove(caFile.Name())
 
@@ -92,7 +91,7 @@ func Test_verify(t *testing.T) {
 	// invalid crt at path
 	ob.Reset()
 	eb.Reset()
-	certFile, err := ioutil.TempFile("", "verify-cert")
+	certFile, err := os.CreateTemp("", "verify-cert")
 	assert.Nil(t, err)
 	defer os.Remove(certFile.Name())
 

+ 5 - 2
config/config.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"math"
 	"os"
 	"os/signal"
@@ -122,6 +121,10 @@ func (c *C) HasChanged(k string) bool {
 // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
 // original path provided to Load. The old settings are shallow copied for change detection after the reload.
 func (c *C) CatchHUP(ctx context.Context) {
+	if c.path == "" {
+		return
+	}
+
 	ch := make(chan os.Signal, 1)
 	signal.Notify(ch, syscall.SIGHUP)
 
@@ -358,7 +361,7 @@ func (c *C) parse() error {
 	var m map[interface{}]interface{}
 
 	for _, path := range c.files {
-		b, err := ioutil.ReadFile(path)
+		b, err := os.ReadFile(path)
 		if err != nil {
 			return err
 		}
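
The new early return in CatchHUP pairs with the service/ package added in this merge: a config assembled in memory has no backing path, so there is nothing to re-read on SIGHUP and the method can return immediately. A sketch of that case, assuming config.NewC and the LoadString helper used elsewhere in this repo:

    package main

    import (
        "context"

        "github.com/sirupsen/logrus"
        "github.com/slackhq/nebula/config"
    )

    func main() {
        c := config.NewC(logrus.New())
        // Loaded from a string, so c.path stays empty.
        if err := c.LoadString("tun:\n  dev: nebula1\n"); err != nil {
            panic(err)
        }
        // With no path, CatchHUP now returns immediately instead of
        // installing a SIGHUP handler that could never reload anything.
        c.CatchHUP(context.Background())
    }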

+ 7 - 8
config/config_test.go

@@ -1,7 +1,6 @@
 package config
 
 import (
-	"io/ioutil"
 	"os"
 	"path/filepath"
 	"testing"
@@ -16,10 +15,10 @@ import (
 
 func TestConfig_Load(t *testing.T) {
 	l := test.NewLogger()
-	dir, err := ioutil.TempDir("", "config-test")
+	dir, err := os.MkdirTemp("", "config-test")
 	// invalid yaml
 	c := NewC(l)
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 	// simple multi config merge
@@ -29,8 +28,8 @@ func TestConfig_Load(t *testing.T) {
 
 	assert.Nil(t, err)
 
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
-	ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
 	assert.Nil(t, c.Load(dir))
 	expected := map[interface{}]interface{}{
 		"outer": map[interface{}]interface{}{
@@ -120,9 +119,9 @@ func TestConfig_HasChanged(t *testing.T) {
 func TestConfig_ReloadConfig(t *testing.T) {
 	l := test.NewLogger()
 	done := make(chan bool, 1)
-	dir, err := ioutil.TempDir("", "config-test")
+	dir, err := os.MkdirTemp("", "config-test")
 	assert.Nil(t, err)
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
 	c := NewC(l)
 	assert.Nil(t, c.Load(dir))
@@ -131,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
 	assert.False(t, c.HasChanged("outer"))
 	assert.False(t, c.HasChanged(""))
 
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
 
 	c.RegisterReloadCallback(func(c *C) {
 		done <- true

+ 6 - 19
connection_manager.go

@@ -231,7 +231,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = newhostinfo.vpnIp
+				relayFrom = n.intf.myVpnIp
 				relayTo = existing.PeerIp
 			case ForwardingType:
 				relayFrom = existing.PeerIp
@@ -256,7 +256,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = newhostinfo.vpnIp
+				relayFrom = n.intf.myVpnIp
 				relayTo = r.PeerIp
 			case ForwardingType:
 				relayFrom = r.PeerIp
@@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	}
 
 	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
+	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@@ -432,7 +432,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	if !n.intf.disconnectInvalid && err != cert.ErrBlockListed {
+	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
 		return false
 	}
@@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
+	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
 		return
 	}
 
@@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
-	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
-	if !newHostinfo.HandshakeReady {
-		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
-	}
-
-	//If this is a static host, we don't need to wait for the HostQueryReply
-	//We can trigger the handshake right now
-	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
-		select {
-		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
-		default:
-		}
-	}
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
 }

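Two consolidations land in this file: the hand-copied handshake kickoff in tryRehandshake collapses into the single handshakeManager.StartHandshake entry point, and disconnect_invalid becomes an atomic.Bool so a config reload can flip it while connection-manager goroutines keep reading it. A minimal sketch of that flag pattern, independent of nebula's types (requires Go 1.19+ for atomic.Bool):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// iface stands in for nebula's Interface: one writer (config reload)
// and many readers (per-tunnel checks) share the flag without a mutex.
type iface struct {
	disconnectInvalid atomic.Bool
}

func (i *iface) reloadConfig(v bool) {
	i.disconnectInvalid.Store(v) // safe concurrent write
}

func (i *iface) shouldDisconnect() bool {
	return i.disconnectInvalid.Load() // safe concurrent read
}

func main() {
	i := &iface{}
	i.reloadConfig(true)
	fmt.Println(i.shouldDisconnect()) // true
}
```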
+ 21 - 20
connection_manager_test.go

@@ -58,7 +58,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		pki:              &PKI{},
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
 	ifce.pki.cs.Store(cs)
@@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
@@ -138,7 +138,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		pki:              &PKI{},
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
 	ifce.pki.cs.Store(cs)
@@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
@@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			PublicKey: pubCA,
 		},
 	}
-	caCert.Sign(cert.Curve_CURVE25519, privCA)
+
+	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
 	ncp := &cert.NebulaCAPool{
 		CAs: cert.NewCAPool().CAs,
 	}
@@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			Issuer:    "ca",
 		},
 	}
-	peerCert.Sign(cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
 
 	cs := &CertState{
 		RawCertificate:      []byte{},
@@ -252,18 +253,18 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 
 	lh := newTestLighthouse()
 	ifce := &Interface{
-		hostMap:           hostMap,
-		inside:            &test.NoopTun{},
-		outside:           &udp.NoopConn{},
-		firewall:          &Firewall{},
-		lightHouse:        lh,
-		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
-		l:                 l,
-		disconnectInvalid: true,
-		pki:               &PKI{},
+		hostMap:          hostMap,
+		inside:           &test.NoopTun{},
+		outside:          &udp.NoopConn{},
+		firewall:         &Firewall{},
+		lightHouse:       lh,
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		l:                l,
+		pki:              &PKI{},
 	}
 	ifce.pki.cs.Store(cs)
 	ifce.pki.caPool.Store(ncp)
+	ifce.disconnectInvalid.Store(true)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	hostinfo := &HostInfo{
 		vpnIp: vpnIp,
 		ConnectionState: &ConnectionState{
-			certState: cs,
-			peerCert:  &peerCert,
-			H:         &noise.HandshakeState{},
+			myCert:   &cert.NebulaCertificate{},
+			peerCert: &peerCert,
+			H:        &noise.HandshakeState{},
 		},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

+ 11 - 14
connection_state.go

@@ -18,35 +18,34 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
-	certState      *CertState
+	myCert         *cert.NebulaCertificate
 	peerCert       *cert.NebulaCertificate
 	initiator      bool
 	messageCounter atomic.Uint64
 	window         *Bits
-	queueLock      sync.Mutex
 	writeLock      sync.Mutex
-	ready          bool
 }
 
-func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
 	var dhFunc noise.DHFunc
-	curCertState := f.pki.GetCertState()
-
-	switch curCertState.Certificate.Details.Curve {
+	switch certState.Certificate.Details.Curve {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 		dhFunc = noiseutil.DHP256
 	default:
-		l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
+		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
 		return nil
 	}
-	cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
-	if f.cipher == "chachapoly" {
+
+	var cs noise.CipherSuite
+	if cipher == "chachapoly" {
 		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	} else {
+		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
+	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
 
 	b := NewBits(ReplayWindow)
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@@ -71,8 +70,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		ready:     false,
-		certState: curCertState,
+		myCert:    certState.Certificate,
 	}
 
 	return ci
@@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"certificate":     cs.peerCert,
 		"initiator":       cs.initiator,
 		"message_counter": cs.messageCounter.Load(),
-		"ready":           cs.ready,
 	})
 }

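ConnectionState now keeps only the certificate it was built with (myCert) instead of the whole *CertState, and the constructor is exported with its inputs passed explicitly rather than read off the Interface, which lets handshake code and tests construct one without a full Interface. A hedged sketch of the new call shape, written as if inside the nebula package with flynn/noise imported; the pattern and psk values are placeholders:

```go
// buildState is a hypothetical helper showing the exported constructor.
// NewConnectionState returns nil if the certificate's curve is unsupported.
func buildState(l *logrus.Logger, certState *CertState) *ConnectionState {
	return NewConnectionState(
		l,
		"chachapoly",      // anything other than "chachapoly" selects AES-GCM
		certState,         // supplies the curve, static keys, and myCert
		true,              // initiator
		noise.HandshakeIX, // placeholder handshake pattern
		nil,               // placeholder psk
		0,                 // placeholder pskStage
	)
}
```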
+ 10 - 2
control.go

@@ -11,6 +11,7 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -29,6 +30,7 @@ type controlHostLister interface {
 type Control struct {
 	f               *Interface
 	l               *logrus.Logger
+	ctx             context.Context
 	cancel          context.CancelFunc
 	sshStart        func()
 	statsStart      func()
@@ -41,7 +43,6 @@ type ControlHostInfo struct {
 	LocalIndex             uint32                  `json:"localIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
 	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
-	CachedPackets          int                     `json:"cachedPackets"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	MessageCounter         uint64                  `json:"messageCounter"`
 	CurrentRemote          *udp.Addr               `json:"currentRemote"`
@@ -72,6 +73,10 @@ func (c *Control) Start() {
 	c.f.run()
 }
 
+func (c *Control) Context() context.Context {
+	return c.ctx
+}
+
 // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
 func (c *Control) Stop() {
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
@@ -227,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	return
 }
 
+func (c *Control) Device() overlay.Device {
+	return c.f.inside
+}
+
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 
 	chi := ControlHostInfo{
@@ -234,7 +243,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
-		CachedPackets:          len(h.packetStore),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
 	}

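The two new accessors are aimed at embedding nebula: Context() exposes a context tied to the control lifecycle (its cancel func already existed for Stop), and Device() hands back the overlay device so a caller such as the new service package can move packets itself. A hedged sketch of how an embedder might use them, assuming ctrl is the *nebula.Control returned by nebula.Main:

```go
package main

import (
	"log"

	"github.com/slackhq/nebula"
)

// watchShutdown is a hypothetical helper: it grabs the overlay device
// backing the data plane, then blocks until the Control's context is
// cancelled (Stop() fires the paired cancel func).
func watchShutdown(ctrl *nebula.Control) {
	dev := ctrl.Device() // overlay.Device backing the data plane
	_ = dev              // e.g. hand it to a userspace network stack

	<-ctrl.Context().Done()
	log.Println("nebula is shutting down")
}
```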
+ 1 - 2
control_test.go

@@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteAddrs:            []*udp.Addr{remote2, remote1},
-		CachedPackets:          0,
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
 		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
@@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet

+ 1 - 11
control_tester.go

@@ -165,15 +165,5 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
 }
 
 func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
-	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
-	ixHandshakeStage0(c.f, vpnIp, hostinfo)
-
-	// If this is a static host, we don't need to wait for the HostQueryReply
-	// We can trigger the handshake right now
-	if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
-		select {
-		case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
-		default:
-		}
-	}
+	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 }

+ 120 - 16
e2e/handshakes_test.go

@@ -20,7 +20,7 @@ import (
 )
 
 func BenchmarkHotPath(b *testing.B) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -44,7 +44,7 @@ func BenchmarkHotPath(b *testing.B) {
 }
 
 func TestGoodHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -95,7 +95,7 @@ func TestGoodHandshake(t *testing.T) {
 }
 
 func TestWrongResponderHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
@@ -164,7 +164,7 @@ func TestStage1Race(t *testing.T) {
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// But will eventually collapse down to a single tunnel
 
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -241,7 +241,7 @@ func TestStage1Race(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -290,7 +290,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -341,7 +341,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 }
 
 func TestRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -372,7 +372,7 @@ func TestRelays(t *testing.T) {
 
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -421,7 +421,7 @@ func TestStage1RaceRelays(t *testing.T) {
 
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -508,7 +508,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 	////TODO: assert hostmaps
 }
 func TestRehandshakingRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -538,7 +538,111 @@ func TestRehandshakingRelays(t *testing.T) {
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	relayConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(relayConfig.Settings)
+	assert.NoError(t, err)
+	relayConfig.ReloadConfigString(string(rc))
+
+	for {
+		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between my and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	for {
+		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between their and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	r.Log("Assert the relay tunnel still works")
+	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+	// We should have two hostinfos on all sides
+	for len(myControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("myControl hostinfos got cleaned up!")
+	for len(theirControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("theirControl hostinfos got cleaned up!")
+	for len(relayControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("relayControl hostinfos got cleaned up!")
+}
+
+func TestRehandshakingRelaysPrimary(t *testing.T) {
+	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach "me" how to get to the relay and that "them" can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+
+	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
+	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
+	r.Log("Renew relay certificate and spin until me and them sees it")
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -611,7 +715,7 @@ func TestRehandshakingRelays(t *testing.T) {
 }
 
 func TestRehandshaking(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -633,7 +737,7 @@ func TestRehandshaking(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew my certificate and spin until their sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -707,7 +811,7 @@ func TestRehandshaking(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -733,7 +837,7 @@ func TestRehandshakingLoser(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew their certificate and spin until mine sees it")
-	_, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
+	_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -808,7 +912,7 @@ func TestRaceRegression(t *testing.T) {
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 

+ 118 - 0
e2e/helpers.go

@@ -0,0 +1,118 @@
+package e2e
+
+import (
+	"crypto/rand"
+	"io"
+	"net"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "test ca",
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           true,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	if len(ips) > 0 {
+		nc.Details.Ips = ips
+	}
+
+	if len(subnets) > 0 {
+		nc.Details.Subnets = subnets
+	}
+
+	if len(groups) > 0 {
+		nc.Details.Groups = groups
+	}
+
+	err = nc.Sign(cert.Curve_CURVE25519, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	issuer, err := ca.Sha256Sum()
+	if err != nil {
+		panic(err)
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	pub, rawPriv := x25519Keypair()
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           name,
+			Ips:            []*net.IPNet{ip},
+			Subnets:        subnets,
+			Groups:         groups,
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           false,
+			Issuer:         issuer,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	err = nc.Sign(ca.Details.Curve, key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
+}
+
+func x25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}

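Moving these helpers out of helpers_test.go exports them, so other packages (and the new service tests) can mint throwaway PKI material. A small usage sketch mirroring the calls in handshakes_test.go; zero times fall back to the now±60s defaults shown above:

```go
package e2e

import (
	"net"
	"time"
)

// newFixture is a hypothetical example of the exported helpers: one CA,
// one host certificate signed by it, both returned with PEM encodings.
func newFixture() {
	ca, _, caKey, _ := NewTestCaCert(time.Time{}, time.Time{}, nil, nil, nil)

	hostNet := &net.IPNet{IP: net.IP{10, 0, 0, 1}, Mask: net.CIDRMask(24, 32)}
	_, _, keyPEM, crtPEM := NewTestCert(ca, caKey, "host1",
		time.Time{}, time.Time{}, hostNet, nil, []string{"servers"})

	_, _ = keyPEM, crtPEM // feed these into a node's pki config block
}
```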
+ 1 - 110
e2e/helpers_test.go

@@ -4,7 +4,6 @@
 package e2e
 
 import (
-	"crypto/rand"
 	"fmt"
 	"io"
 	"net"
@@ -22,8 +21,6 @@ import (
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
 	"gopkg.in/yaml.v2"
 )
 
@@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		IP:   udpIp,
 		Port: 4242,
 	}
-	_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
@@ -108,112 +105,6 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 	return control, vpnIpNet, &udpAddr, c
 }
 
-// newTestCaCert will generate a CA cert
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// newTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
 type doneCb func()
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {

+ 3 - 2
examples/config.yml

@@ -11,7 +11,7 @@ pki:
   #blocklist:
   #  - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
-  #disconnect_invalid: false
+  #disconnect_invalid: true
 
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
@@ -171,7 +171,8 @@ punchy:
 # and has been deprecated for "preferred_ranges"
 #preferred_ranges: ["172.16.0.0/24"]
 
-# sshd can expose informational and administrative functions via ssh this is a
+# sshd can expose informational and administrative functions via ssh, and allows manual tweaking of various
+# network settings when debugging or testing.
 #sshd:
   # Toggles the feature
   #enabled: true

+ 100 - 0
examples/go_service/main.go

@@ -0,0 +1,100 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"log"
+
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/service"
+)
+
+func main() {
+	if err := run(); err != nil {
+		log.Fatalf("%+v", err)
+	}
+}
+
+func run() error {
+	configStr := `
+tun:
+  user: true
+
+static_host_map:
+  '192.168.100.1': ['localhost:4242']
+
+listen:
+  host: 0.0.0.0
+  port: 4241
+
+lighthouse:
+  am_lighthouse: false
+  interval: 60
+  hosts:
+    - '192.168.100.1'
+
+firewall:
+  outbound:
+    # Allow all outbound traffic from this node
+    - port: any
+      proto: any
+      host: any
+
+  inbound:
+    # Allow icmp between any nebula hosts
+    - port: any
+      proto: icmp
+      host: any
+    - port: any
+      proto: any
+      host: any
+
+pki:
+  ca: /home/rice/Developer/nebula-config/ca.crt
+  cert: /home/rice/Developer/nebula-config/app.crt
+  key: /home/rice/Developer/nebula-config/app.key
+`
+	var config config.C
+	if err := config.LoadString(configStr); err != nil {
+		return err
+	}
+	service, err := service.New(&config)
+	if err != nil {
+		return err
+	}
+
+	ln, err := service.Listen("tcp", ":1234")
+	if err != nil {
+		return err
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			log.Printf("accept error: %s", err)
+			break
+		}
+		defer conn.Close()
+
+		log.Printf("got connection")
+
+		conn.Write([]byte("hello world\n"))
+
+		scanner := bufio.NewScanner(conn)
+		for scanner.Scan() {
+			message := scanner.Text()
+			fmt.Fprintf(conn, "echo: %q\n", message)
+			log.Printf("got message %q", message)
+		}
+
+		if err := scanner.Err(); err != nil {
+			log.Printf("scanner error: %s", err)
+			break
+		}
+	}
+
+	service.Close()
+	if err := service.Wait(); err != nil {
+		return err
+	}
+	return nil
+}

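One caveat in this example: defer conn.Close() sits inside the accept loop, so the deferred closes only run when run() returns and pile up in a long-lived process, and the echo work blocks the loop so clients are served one at a time. A sketch of the more conventional shape, one goroutine per connection, against the listener from service.Listen (assuming it satisfies net.Listener, as its use above suggests):

```go
package main

import (
	"bufio"
	"fmt"
	"log"
	"net"
)

// serve handles each accepted connection in its own goroutine, so a slow
// client cannot stall the accept loop, and each connection is closed as
// soon as its handler returns rather than at process exit.
func serve(ln net.Listener) {
	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Printf("accept error: %s", err)
			return
		}
		go func(c net.Conn) {
			defer c.Close()
			c.Write([]byte("hello world\n"))
			scanner := bufio.NewScanner(c)
			for scanner.Scan() {
				fmt.Fprintf(c, "echo: %q\n", scanner.Text())
			}
		}(conn)
	}
}
```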
+ 36 - 14
firewall.go

@@ -6,6 +6,7 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"hash/fnv"
 	"net"
 	"reflect"
 	"strconv"
@@ -57,7 +58,7 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *cidr.Tree4
+	localIps *cidr.Tree4[struct{}]
 
 	rules        string
 	rulesVersion uint16
@@ -110,8 +111,8 @@ type FirewallRule struct {
 	Any       bool
 	Hosts     map[string]struct{}
 	Groups    [][]string
-	CIDR      *cidr.Tree4
-	LocalCIDR *cidr.Tree4
+	CIDR      *cidr.Tree4[struct{}]
+	LocalCIDR *cidr.Tree4[struct{}]
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -137,7 +138,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 	}
 
-	localIps := cidr.NewTree4()
+	localIps := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
@@ -278,6 +279,18 @@ func (f *Firewall) GetRuleHash() string {
 	return hex.EncodeToString(sum[:])
 }
 
+// GetRuleHashFNV returns a uint32 FNV-1a hash representation of the rules, for use as a metric value
+func (f *Firewall) GetRuleHashFNV() uint32 {
+	h := fnv.New32a()
+	h.Write([]byte(f.rules))
+	return h.Sum32()
+}
+
+// GetRuleHashes returns both the SHA-256 and FNV-1a hashes, suitable for logging
+func (f *Firewall) GetRuleHashes() string {
+	return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
+}
+
 func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
 	var table string
 	if inbound {
@@ -391,7 +404,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 
 	// Make sure remote address matches nebula certificate
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		if remoteCidr.Contains(fp.RemoteIP) == nil {
+		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+		if !ok {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -404,7 +418,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	if f.localIps.Contains(fp.LocalIP) == nil {
+	ok, _ := f.localIps.Contains(fp.LocalIP)
+	if !ok {
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		return ErrInvalidLocalIP
 	}
@@ -447,6 +462,7 @@ func (f *Firewall) EmitStats() {
 	conntrack.Unlock()
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
+	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
 func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
@@ -657,8 +673,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 		return &FirewallRule{
 			Hosts:     make(map[string]struct{}),
 			Groups:    make([][]string, 0),
-			CIDR:      cidr.NewTree4(),
-			LocalCIDR: cidr.NewTree4(),
+			CIDR:      cidr.NewTree4[struct{}](),
+			LocalCIDR: cidr.NewTree4[struct{}](),
 		}
 	}
 
@@ -726,8 +742,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = cidr.NewTree4()
-		fr.LocalCIDR = cidr.NewTree4()
+		fr.CIDR = cidr.NewTree4[struct{}]()
+		fr.LocalCIDR = cidr.NewTree4[struct{}]()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -809,12 +825,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 	}
 
-	if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
-		return true
+	if fr.CIDR != nil {
+		ok, _ := fr.CIDR.Contains(p.RemoteIP)
+		if ok {
+			return true
+		}
 	}
 
-	if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
-		return true
+	if fr.LocalCIDR != nil {
+		ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
+		if ok {
+			return true
+		}
 	}
 
 	// No host, group, or cidr matched, bye bye

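These call-site changes follow the cidr.Tree4 rewrite into a generic Tree4[T] (see cidr/tree4.go in the file list): lookups now return an explicit (ok, value) pair instead of a nillable result, and the firewall instantiates the tree with struct{} since it only cares about membership. A minimal sketch of the API exactly as the firewall uses it here:

```go
package main

import (
	"net"

	"github.com/slackhq/nebula/cidr"
	"github.com/slackhq/nebula/iputil"
)

func main() {
	// Membership-only tree: struct{} carries no payload.
	tree := cidr.NewTree4[struct{}]()
	tree.AddCIDR(&net.IPNet{
		IP:   net.IP{10, 0, 0, 1},
		Mask: net.IPMask{255, 255, 255, 255},
	}, struct{}{})

	// Contains now reports (ok, value) instead of a nil-checked interface.
	if ok, _ := tree.Contains(iputil.Ip2VpnIp(net.IP{10, 0, 0, 1})); ok {
		// the packet's address is one of ours
	}
}
```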
+ 8 - 4
firewall_test.go

@@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
+	ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion

+ 16 - 13
go.mod

@@ -11,26 +11,28 @@ require (
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.55
+	github.com/miekg/dns v1.1.56
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.16.0
+	github.com/prometheus/client_golang v1.17.0
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stretchr/testify v1.8.4
 	github.com/timandy/routine v1.1.1
-	github.com/vishvananda/netlink v1.1.0
-	golang.org/x/crypto v0.12.0
+	github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
+	golang.org/x/crypto v0.16.0
 	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
-	golang.org/x/net v0.14.0
-	golang.org/x/sys v0.11.0
-	golang.org/x/term v0.11.0
+	golang.org/x/net v0.19.0
+	golang.org/x/sync v0.5.0
+	golang.org/x/sys v0.15.0
+	golang.org/x/term v0.15.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	google.golang.org/protobuf v1.31.0
 	gopkg.in/yaml.v2 v2.4.0
+	gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
 )
 
 require (
@@ -38,14 +40,15 @@ require (
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
+	github.com/google/btree v1.0.1 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.4.0 // indirect
-	github.com/prometheus/common v0.42.0 // indirect
-	github.com/prometheus/procfs v0.10.1 // indirect
-	github.com/rogpeppe/go-internal v1.10.0 // indirect
+	github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
+	github.com/prometheus/common v0.44.0 // indirect
+	github.com/prometheus/procfs v0.11.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
-	golang.org/x/mod v0.10.0 // indirect
-	golang.org/x/tools v0.8.0 // indirect
+	golang.org/x/mod v0.12.0 // indirect
+	golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
+	golang.org/x/tools v0.13.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 35 - 28
go.sum

@@ -47,6 +47,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -78,8 +80,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
 github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
-github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
-github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
+github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE=
+github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -97,28 +99,27 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
-github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
+github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
+github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
-github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
+github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
+github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI1YM=
-github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc=
+github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
+github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
-github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
+github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
+github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
-github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
 github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -138,9 +139,9 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI=
 github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ=
-github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
-github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
-github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
+github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -150,16 +151,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
-golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
+golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
+golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
-golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
+golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -170,8 +171,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
-golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
+golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
+golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -179,38 +180,42 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
+golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
+golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
-golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
+golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0=
-golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
+golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
+golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
-golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
+golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
+golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -247,3 +252,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q=

+ 0 - 31
handshake.go

@@ -1,31 +0,0 @@
-package nebula
-
-import (
-	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
-)
-
-func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
-	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
-			f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
-			return
-		}
-	}
-
-	switch h.Subtype {
-	case header.HandshakeIXPSK0:
-		switch h.MessageCounter {
-		case 1:
-			ixHandshakeStage1(f, addr, via, packet, h)
-		case 2:
-			newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
-			tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
-			if tearDown && newHostinfo != nil {
-				f.handshakeManager.DeleteHostInfo(newHostinfo)
-			}
-		}
-	}
-
-}
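
The dispatch deleted here is not lost: it reappears below as HandleIncoming on the HandshakeManager, which now owns both the inbound and outbound sides of a handshake. For orientation, a minimal sketch of the new call site in the UDP read path; the outside.go hunk is not part of this excerpt, so the surrounding switch is an assumption:

    // Sketch (assumed): the packet read loop hands handshake packets to the
    // manager instead of calling the free function deleted above.
    case header.Handshake:
        f.messageMetrics.Rx(h.Type, h.Subtype, 1)
        f.handshakeManager.HandleIncoming(addr, via, packet, h)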

+ 63 - 80
handshake_ix.go

@@ -4,6 +4,7 @@ import (
 	"time"
 
 	"github.com/flynn/noise"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
@@ -13,27 +14,22 @@ import (
 
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
-	// This queries the lighthouse if we don't know a remote for the host
-	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
-	// more quickly, effect is a quicker handshake.
-	if hostinfo.remote == nil {
-		f.lightHouse.QueryServer(vpnIp, f)
-	}
-
-	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
+func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
+	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
-		return
+		return false
 	}
 
-	ci := hostinfo.ConnectionState
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
+	hh.hostinfo.ConnectionState = ci
 
 	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hostinfo.localIndexId,
+		InitiatorIndex: hh.hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
-		Cert:           ci.certState.RawCertificateNoKey,
+		Cert:           certState.RawCertificateNoKey,
 	}
 
 	hsBytes := []byte{}
@@ -44,9 +40,9 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hsBytes, err = hs.Marshal()
 
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
-		return
+		return false
 	}
 
 	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
@@ -54,22 +50,23 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
-		return
+		return false
 	}
 
 	// We are sending handshake packet 1, so we don't expect to receive
 	// handshake packet 1 from the responder
 	ci.window.Update(f.l, 1)
 
-	hostinfo.HandshakePacket[0] = msg
-	hostinfo.HandshakeReady = true
-	hostinfo.handshakeStart = time.Now()
+	hh.hostinfo.HandshakePacket[0] = msg
+	hh.ready = true
+	return true
 }
 
 func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
-	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
@@ -144,9 +141,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		},
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -156,7 +150,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = ci.certState.RawCertificateNoKey
+	hs.Details.Cert = certState.RawCertificateNoKey
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
@@ -212,19 +206,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	if err != nil {
 		switch err {
 		case ErrAlreadySeen:
-			// Update remote if preferred (Note we have to switch to locking
-			// the existing hostinfo, and then switch back so the defer Unlock
-			// higher in this function still works)
-			hostinfo.Unlock()
-			existing.Lock()
 			// Update remote if preferred
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
-			existing.Unlock()
-			hostinfo.Lock()
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
@@ -311,7 +298,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				WithField("sentCachedPackets", len(hostinfo.packetStore)).
 				Info("Handshake message sent")
 		}
 	} else {
@@ -327,25 +313,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-			WithField("sentCachedPackets", len(hostinfo.packetStore)).
 			Info("Handshake message sent")
 	}
 
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+	hostinfo.ConnectionState.messageCounter.Store(2)
+	hostinfo.remotes.ResetBlockedRemotes()
 
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
-	if hostinfo == nil {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	if addr != nil {
 		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@@ -354,22 +341,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	}
 
 	ci := hostinfo.ConnectionState
-	if ci.ready {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
-			Info("Handshake is already complete")
-
-		// Update remote if preferred
-		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
-			// Send a test packet to ensure the other side has also switched to
-			// the preferred remote
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
-		}
-
-		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
-		return false
-	}
-
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
@@ -426,31 +397,27 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		//TODO: this adds it to the timer wheel in a way that aggressively retries
-		newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
-		newHostInfo.Lock()
-
-		// Block the current used address
-		newHostInfo.remotes = hostinfo.remotes
-		newHostInfo.remotes.BlockRemote(addr)
+		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
+			//TODO: this doesn't know if it's being added or is being used for caching a packet
+			// Block the current used address
+			newHH.hostinfo.remotes = hostinfo.remotes
+			newHH.hostinfo.remotes.BlockRemote(addr)
 
-		// Get the correct remote list for the host we did handshake with
-		hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+			// Get the correct remote list for the host we did handshake with
+			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
-		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
-			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
-			Info("Blocked addresses for handshakes")
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+				Info("Blocked addresses for handshakes")
 
-		// Swap the packet store to benefit the original intended recipient
-		hostinfo.ConnectionState.queueLock.Lock()
-		newHostInfo.packetStore = hostinfo.packetStore
-		hostinfo.packetStore = []*cachedPacket{}
-		hostinfo.ConnectionState.queueLock.Unlock()
+			// Swap the packet store to benefit the original intended recipient
+			newHH.packetStore = hh.packetStore
+			hh.packetStore = []*cachedPacket{}
 
-		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-		hostinfo.vpnIp = vpnIp
-		f.sendCloseTunnel(hostinfo)
-		newHostInfo.Unlock()
+			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnIp = vpnIp
+			f.sendCloseTunnel(hostinfo)
+		})
 
 		return true
 	}
@@ -458,7 +425,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Mark packet 2 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 2)
 
-	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
+	duration := time.Since(hh.startTime).Nanoseconds()
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -466,7 +433,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 		WithField("durationNs", duration).
-		WithField("sentCachedPackets", len(hostinfo.packetStore)).
+		WithField("sentCachedPackets", len(hh.packetStore)).
 		Info("Handshake message received")
 
 	hostinfo.remoteIndexId = hs.Details.ResponderIndex
@@ -490,7 +457,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+
+	hostinfo.ConnectionState.messageCounter.Store(2)
+
+	if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
+	}
+
+	if len(hh.packetStore) > 0 {
+		nb := make([]byte, 12, 12)
+		out := make([]byte, mtu)
+		for _, cp := range hh.packetStore {
+			cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
+		}
+		f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
+	}
+
+	hostinfo.remotes.ResetBlockedRemotes()
 	f.metricHandshakes.Update(duration)
 
 	return false
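
The inline flush above is what replaced hostinfo.handshakeComplete: once ReadMessage succeeds and the keys are installed, every packet cached during the handshake is replayed through the send path that originally tried to use the tunnel. The callback shape below is inferred from the f.sendMessageNow and f.SendMessageToHostInfo call sites in this diff and is offered only as a sketch:

    // Assumed shape of packetCallback; it is invoked once per cached packet
    // with the completed hostinfo plus fresh nonce (nb) and output buffers.
    type packetCallback func(t header.MessageType, st header.MessageSubType,
        hostinfo *HostInfo, p, nb, out []byte)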

+ 241 - 125
handshake_manager.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"net"
 	"time"
+	"sync"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
@@ -45,8 +46,8 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	syncRWMutex
 
-	vpnIps  map[iputil.VpnIp]*HostInfo
-	indexes map[uint32]*HostInfo
+	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	indexes map[uint32]*HandshakeHostInfo
 
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
@@ -56,17 +57,55 @@ type HandshakeManager struct {
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
+	f                      *Interface
 	l                      *logrus.Logger
 
 	// can be used to trigger outbound handshake for the given vpnIp
 	trigger chan iputil.VpnIp
 }
 
-func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
+type HandshakeHostInfo struct {
+	sync.Mutex
+
+	startTime   time.Time       // Time that we first started trying with this handshake
+	ready       bool            // Is the handshake ready
+	counter     int             // How many attempts have we made so far
+	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+
+	hostinfo *HostInfo
+}
+
+func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+	if len(hh.packetStore) < 100 {
+		tempPacket := make([]byte, len(packet))
+		copy(tempPacket, packet)
+
+		hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", true).
+				Debugf("Packet store")
+		}
+
+	} else {
+		m.dropped.Inc(1)
+
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", false).
+				Debugf("Packet store")
+		}
+	}
+}
+
+func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
 		syncRWMutex:            newSyncRWMutex(mutexKey{Type: mutexKeyTypeHandshakeManager}),
-		vpnIps:                 map[iputil.VpnIp]*HostInfo{},
-		indexes:                map[uint32]*HostInfo{},
+		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
@@ -80,7 +119,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 
@@ -89,70 +128,92 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 		case <-ctx.Done():
 			return
 		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, f, true)
+			c.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now, f)
+			c.NextOutboundHandshakeTimerTick(now)
+		}
+	}
+}
+
+func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+	// First remote allow list check before we know the vpnIp
+	if addr != nil {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+			return
+		}
+	}
+
+	switch h.Subtype {
+	case header.HandshakeIXPSK0:
+		switch h.MessageCounter {
+		case 1:
+			ixHandshakeStage1(hm.f, addr, via, packet, h)
+
+		case 2:
+			newHostinfo := hm.queryIndex(h.RemoteIndex)
+			tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
+			if tearDown && newHostinfo != nil {
+				hm.DeleteHostInfo(newHostinfo.hostinfo)
+			}
 		}
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
 		vpnIp, has := c.OutboundHandshakeTimer.Purge()
 		if !has {
 			break
 		}
-		c.handleOutbound(vpnIp, f, false)
+		c.handleOutbound(vpnIp, false)
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
-	hostinfo := c.QueryVpnIp(vpnIp)
-	if hostinfo == nil {
-		return
-	}
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
-	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
-	if hostinfo.HandshakeComplete {
-		// Ensure we don't exist in the pending hostmap anymore since we have completed
-		c.DeleteHostInfo(hostinfo)
-		return
-	}
-
-	// Check if we have a handshake packet to transmit yet
-	if !hostinfo.HandshakeReady {
-		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
-		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh == nil {
 		return
 	}
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	// If we are out of time, clean up
-	if hostinfo.HandshakeCounter >= c.config.retries {
-		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
-			WithField("initiatorIndex", hostinfo.localIndexId).
-			WithField("remoteIndex", hostinfo.remoteIndexId).
+	if hh.counter >= hm.config.retries {
+		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
+			WithField("initiatorIndex", hh.hostinfo.localIndexId).
+			WithField("remoteIndex", hh.hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
+			WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
 			Info("Handshake timed out")
-		c.metricTimedOut.Inc(1)
-		c.DeleteHostInfo(hostinfo)
+		hm.metricTimedOut.Inc(1)
+		hm.DeleteHostInfo(hostinfo)
 		return
 	}
 
+	// Increment the counter to increase our delay, linear backoff
+	hh.counter++
+
+	// Check if we have a handshake packet to transmit yet
+	if !hh.ready {
+		if !ixHandshakeStage0(hm.f, hh) {
+			hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
+			return
+		}
+	}
+
 	// Get a remotes object if we don't already have one.
 	// This is mainly to protect us as this should never be the case
 	// NB ^ This comment doesn't jibe. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
+	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
+	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -161,25 +222,25 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		return
 	}
 
-	hostinfo.HandshakeLastRemotes = remotes
+	hh.lastRemotes = remotes
 
 	// TODO: this will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
-	if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
+	if len(remotes) <= 1 && hh.counter == 5 {
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIp, f)
+		hm.lightHouse.QueryServer(vpnIp, hm.f)
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
-		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-		err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+	hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
-			hostinfo.logger(c.l).WithField("udpAddr", addr).
+			hostinfo.logger(hm.l).WithField("udpAddr", addr).
 				WithField("initiatorIndex", hostinfo.localIndexId).
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake message")
@@ -192,63 +253,63 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 	// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
 	// so only log when the list of remotes has changed
 	if remotesHaveChanged {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("Handshake message sent")
-	} else if c.l.IsLevelEnabled(logrus.DebugLevel) {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+	} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Debug("Handshake message sent")
 	}
 
-	if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {
-		hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
+	if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
+		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
+			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
 				continue
 			}
-			relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
+			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
 			if relayHostInfo == nil || relayHostInfo.remote == nil {
-				hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				f.Handshake(*relay)
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
+				hm.f.Handshake(*relay)
 				continue
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
 			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
 				switch existingRelay.State {
 				case Established:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
 						// This must send over the hostinfo, not over hm.Hosts[ip]
-						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"relay":               *relay}).
 							Info("send CreateRelayRequest")
 					}
 				default:
-					hostinfo.logger(c.l).
+					hostinfo.logger(hm.l).
 						WithField("vpnIp", vpnIp).
 						WithField("state", existingRelay.State).
 						WithField("relay", relayHostInfo.vpnIp).
@@ -257,26 +318,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 			} else {
 				// No relays exist or requested yet.
 				if relayHostInfo.remote != nil {
-					idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested)
+					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
-						hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
+						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               *relay}).
@@ -287,24 +348,41 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		}
 	}
 
-	// Increment the counter to increase our delay, linear backoff
-	hostinfo.HandshakeCounter++
-
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
 	}
 }
 
-// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
-func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
-	// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
-	c.Lock()
-	defer c.Unlock()
+// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
+// The second return value will be true if the hostinfo is ready to transmit traffic
+func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	// Check the main hostmap and maintain a read lock if our host is not there
+	hm.mainHostMap.RLock()
+	if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
+		hm.mainHostMap.RUnlock()
+		// Do not attempt promotion if you are a lighthouse
+		if !hm.lightHouse.amLighthouse {
+			h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
+		}
+		return h, true
+	}
+
+	defer hm.mainHostMap.RUnlock()
+	return hm.StartHandshake(vpnIp, cacheCb), false
+}
+
+// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
+func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+	hm.Lock()
 
-	if hostinfo, ok := c.vpnIps[vpnIp]; ok {
-		// We are already tracking this vpn ip
-		return hostinfo
+	if hh, ok := hm.vpnIps[vpnIp]; ok {
+		// We are already trying to handshake with this vpn ip
+		if cacheCb != nil {
+			cacheCb(hh)
+		}
+		hm.Unlock()
+		return hh.hostinfo
 	}
 
 	hostinfo := &HostInfo{
@@ -318,14 +396,35 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
 		},
 	}
 
-	if init != nil {
-		init(hostinfo)
+	hh := &HandshakeHostInfo{
+		hostinfo:  hostinfo,
+		startTime: time.Now(),
 	}
+	hm.vpnIps[vpnIp] = hh
+	hm.metricInitiated.Inc(1)
+	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
 
-	c.vpnIps[vpnIp] = hostinfo
-	c.metricInitiated.Inc(1)
-	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+	if cacheCb != nil {
+		cacheCb(hh)
+	}
 
+	// If this is a static host, we don't need to wait for the HostQueryReply
+	// We can trigger the handshake right now
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	if !doTrigger {
+		// Add any calculated remotes, and trigger early handshake if one found
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+	}
+
+	if doTrigger {
+		select {
+		case hm.trigger <- vpnIp:
+		default:
+		}
+	}
+
+	hm.Unlock()
+	hm.lightHouse.QueryServer(vpnIp, hm.f)
 	return hostinfo
 }
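
Two details in StartHandshake are easy to miss: the lighthouse query runs only after the manager lock is released, and the trigger send is non-blocking, so a full channel simply means the timer wheel retries on its next tick instead of stalling the caller. The same Go pattern in isolation, with hypothetical names:

    // Non-blocking send: if the handshake goroutine cannot receive right
    // now, fall through and rely on the periodic retry instead of blocking
    // while state is held.
    select {
    case trigger <- vpnIp:
    default:
    }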
 
@@ -347,10 +446,10 @@ var (
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.Lock()
-	defer c.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
+	c.Lock()
+	defer c.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
@@ -379,8 +478,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingIndex, found = c.indexes[hostinfo.localIndexId]
-	if found && existingIndex != hostinfo {
+	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
-		return existingIndex, ErrLocalIndexCollision
+		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
@@ -401,47 +500,47 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap. An existing hostinfo is returned if there was one.
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
-	c.Lock()
-	defer c.Unlock()
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
+func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	if found && existingRemoteIndex != nil {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
+		hostinfo.logger(hm.l).
 			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 	}
 
 	// We need to remove from the pending hostmap first to avoid undoing work when adding to the main hostmap.
-	c.unlockedDeleteHostInfo(hostinfo)
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.unlockedDeleteHostInfo(hostinfo)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 }
 
-// AddIndexHostInfo generates a unique localIndexId for this HostInfo
+// allocateIndex generates a unique localIndexId for this HostInfo
 // and adds it to the pendingHostMap. Will error if we are unable to generate
 // a unique localIndexId
-func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
-	c.Lock()
-	defer c.Unlock()
-	c.mainHostMap.RLock()
-	defer c.mainHostMap.RUnlock()
+func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
+	hm.mainHostMap.RLock()
+	defer hm.mainHostMap.RUnlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 	for i := 0; i < 32; i++ {
-		index, err := generateIndex(c.l)
+		index, err := generateIndex(hm.l)
 		if err != nil {
 			return err
 		}
 
-		_, inPending := c.indexes[index]
-		_, inMain := c.mainHostMap.Indexes[index]
+		_, inPending := hm.indexes[index]
+		_, inMain := hm.mainHostMap.Indexes[index]
 
 		if !inMain && !inPending {
-			h.localIndexId = index
-			c.indexes[index] = h
+			hh.hostinfo.localIndexId = index
+			hm.indexes[index] = hh
 			return nil
 		}
 	}
@@ -458,12 +557,12 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	delete(c.vpnIps, hostinfo.vpnIp)
 	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HostInfo{}
+		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
 	}
 
 	delete(c.indexes, hostinfo.localIndexId)
 	if len(c.indexes) == 0 {
-		c.indexes = map[uint32]*HostInfo{}
+		c.indexes = map[uint32]*HandshakeHostInfo{}
 	}
 
 	if c.l.Level >= logrus.DebugLevel {
@@ -473,16 +572,33 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.vpnIps[vpnIp]
+func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
 }
 
-func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.indexes[index]
+func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.vpnIps[vpnIp]
+}
+
+func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo {
+	hh := hm.queryIndex(index)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
+}
+
+func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.indexes[index]
 }
 
 func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
@@ -494,7 +610,7 @@ func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.vpnIps {
-		f(v)
+		f(v.hostinfo)
 	}
 }
 
@@ -503,7 +619,7 @@ func (c *HandshakeManager) ForEachIndex(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.indexes {
-		f(v)
+		f(v.hostinfo)
 	}
 }
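
Note the lock-order change in CheckAndComplete and Complete above: both now take the main hostmap lock before the handshake manager's own lock, keeping a single acquisition order between the two maps. From a caller's perspective the new entry point is used roughly as below; l, pkt, sendNow and metrics are hypothetical stand-ins, not names from this diff:

    // Sketch: start or join a handshake, caching one packet for delivery
    // once the tunnel forms. ready is true only for an established tunnel.
    hostinfo, ready := hm.GetOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
        hh.cachePacket(l, header.Message, 0, pkt, sendNow, metrics)
    })
    if hostinfo == nil {
        return // vpnIp is not routable
    }
    if !ready {
        return // handshake in flight; pkt was cached by the callback
    }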
 

+ 16 - 18
handshake_manager_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
@@ -14,35 +15,32 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
-	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, vpncidr, preferredRanges)
 	lh := newTestLighthouse()
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextOutboundHandshakeTimerTick(now, mw)
-
-	var initCalled bool
-	initFunc := func(*HostInfo) {
-		initCalled = true
+	cs := &CertState{
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
-	i := blah.AddVpnIp(ip, initFunc)
-	assert.True(t, initCalled)
+	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
+	blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l}
+	blah.f.pki.cs.Store(cs)
+
+	now := time.Now()
+	blah.NextOutboundHandshakeTimerTick(now)
 
-	initCalled = false
-	i2 := blah.AddVpnIp(ip, initFunc)
-	assert.False(t, initCalled)
+	i := blah.StartHandshake(ip, nil)
+	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 
 	i.remotes = NewRemoteList(nil)
-	i.HandshakeReady = true
 
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
@@ -53,14 +51,14 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
 	for i := 1; i <= DefaultHandshakeRetries+1; i++ {
 		now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
-		blah.NextOutboundHandshakeTimerTick(now, mw)
+		blah.NextOutboundHandshakeTimerTick(now)
 	}
 
 	// Confirm they are still in the pending index list
 	assert.Contains(t, blah.vpnIps, ip)
 
 	// Tick 1 more time, a minute will certainly flush it out
-	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
+	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute))
 
 	// Confirm they have been removed
 	assert.NotContains(t, blah.vpnIps, ip)
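
The tick loop above encodes the linear backoff from handleOutbound: attempt n is scheduled tryInterval·n after attempt n-1, so a pending entry survives for tryInterval·(1+2+…+retries) before the timeout path runs. Assuming DefaultHandshakeTryInterval is 100 ms and DefaultHandshakeRetries is 10 (the constants are defined outside this excerpt), that sum is 0.1 s · 55 = 5.5 s, which is why the extra minute added in the final tick is guaranteed to flush the entry.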

+ 18 - 86
hostmap.go

@@ -21,6 +21,7 @@ const defaultPromoteEvery = 1000       // Count of packets sent before we try mo
 const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo to the lighthouse
 const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
 const MaxRemotes = 10
+const maxRecvError = 4
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
 // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -197,24 +198,20 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
 
 type HostInfo struct {
 	syncRWMutex
-
-	remote               *udp.Addr
-	remotes              *RemoteList
-	promoteCounter       atomic.Uint32
-	ConnectionState      *ConnectionState
-	handshakeStart       time.Time   //todo: this an entry in the handshake manager
-	HandshakeReady       bool        //todo: being in the manager means you are ready
-	HandshakeCounter     int         //todo: another handshake manager entry
-	HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
-	HandshakeComplete    bool        //todo: this should go away in favor of ConnectionState.ready
-	HandshakePacket      map[uint8][]byte
-	packetStore          []*cachedPacket //todo: this is other handshake manager entry
-	remoteIndexId        uint32
-	localIndexId         uint32
-	vpnIp                iputil.VpnIp
-	recvError            int
-	remoteCidr           *cidr.Tree4
-	relayState           RelayState
+	remote          *udp.Addr
+	remotes         *RemoteList
+	promoteCounter  atomic.Uint32
+	ConnectionState *ConnectionState
+	remoteIndexId   uint32
+	localIndexId    uint32
+	vpnIp           iputil.VpnIp
+	recvError       atomic.Uint32
+	remoteCidr      *cidr.Tree4[struct{}]
+	relayState      RelayState
+
+	// HandshakePacket records the packets used to create this hostinfo
+	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
+	HandshakePacket map[uint8][]byte
 
 	// nextLHQuery is the earliest we can ask the lighthouse for new information.
 	// This is used to limit lighthouse re-queries in chatty clients
@@ -413,7 +410,6 @@ func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
 }
 
 func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
 	hm.RLock()
 	if h, ok := hm.Relays[index]; ok {
 		hm.RUnlock()
@@ -457,12 +453,6 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
-// `PromoteEvery` calls to this function for a given host.
-func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) *HostInfo {
-	return hm.queryVpnIp(vpnIp, ifce)
-}
-
 func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
@@ -542,10 +532,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
-		// The lock here is currently protecting i.remote access
-		i.RLock()
 		remote := i.remote
-		i.RUnlock()
 
 		// return early if we are already on a preferred remote
 		if remote != nil {
@@ -580,60 +567,6 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 	}
 }
 
-func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
-	//TODO: return the error so we can log with more context
-	if len(i.packetStore) < 100 {
-		tempPacket := make([]byte, len(packet))
-		copy(tempPacket, packet)
-		//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
-		i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
-		if l.Level >= logrus.DebugLevel {
-			i.logger(l).
-				WithField("length", len(i.packetStore)).
-				WithField("stored", true).
-				Debugf("Packet store")
-		}
-
-	} else if l.Level >= logrus.DebugLevel {
-		m.dropped.Inc(1)
-		i.logger(l).
-			WithField("length", len(i.packetStore)).
-			WithField("stored", false).
-			Debugf("Packet store")
-	}
-}
-
-// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
-func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
-	//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
-	//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
-	//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
-
-	i.ConnectionState.queueLock.Lock()
-	i.HandshakeComplete = true
-	//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
-	// Clamping it to 2 gets us out of the woods for now
-	i.ConnectionState.messageCounter.Store(2)
-
-	if l.Level >= logrus.DebugLevel {
-		i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
-	}
-
-	if len(i.packetStore) > 0 {
-		nb := make([]byte, 12, 12)
-		out := make([]byte, mtu)
-		for _, cp := range i.packetStore {
-			cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
-		}
-		m.sent.Inc(int64(len(i.packetStore)))
-	}
-
-	i.remotes.ResetBlockedRemotes()
-	i.packetStore = make([]*cachedPacket, 0)
-	i.ConnectionState.ready = true
-	i.ConnectionState.queueLock.Unlock()
-}
-
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
@@ -690,9 +623,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 }
 
 func (i *HostInfo) RecvErrorExceeded() bool {
-	if i.recvError < 3 {
-		i.recvError += 1
-		return false
+	if i.recvError.Add(1) >= maxRecvError {
+		return true
 	}
-	return true
+	return false
 }
@@ -703,7 +635,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 	}
 
-	remoteCidr := cidr.NewTree4()
+	remoteCidr := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
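
remoteCidr now uses the generic tree from cidr/tree4.go, making the stored value type explicit instead of relying on interface{}. A minimal sketch of the generic API as it is exercised here; only AddCIDR appears in this hunk, so the lookup side is an assumption:

    // struct{} as the value type: only membership matters for remoteCidr.
    tree := cidr.NewTree4[struct{}]()
    _, ipNet, _ := net.ParseCIDR("10.1.1.1/32")
    tree.AddCIDR(ipNet, struct{}{})
    // Lookups elsewhere then check whether a packet's source ip falls
    // inside the certificate's networks (method names not shown here).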

+ 41 - 87
inside.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
@@ -45,7 +44,10 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+	})
+
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
@@ -55,23 +57,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		}
 		return
 	}
-	ci := hostinfo.ConnectionState
-
-	if !ci.ready {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		ci.queueLock.Lock()
-		if !ci.ready {
-			hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
-			ci.queueLock.Unlock()
-			return
-		}
-		ci.queueLock.Unlock()
+
+	if !ready {
+		return
 	}
 
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
 
 	} else {
 		f.rejectInside(packet, out, q)
@@ -90,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 	}
 
 	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
 	_, err := f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
@@ -101,77 +98,39 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	// Use some out buffer space to build the packet before encryption
-	// Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet)
-	// Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet)
-	out = out[:140]
-	outPacket := iputil.CreateRejectPacket(packet, out[100:])
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q)
+	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
+	if len(out) > iputil.MaxRejectPacketSize {
+		if f.l.GetLevel() >= logrus.InfoLevel {
+			f.l.
+				WithField("packet", packet).
+				WithField("outPacket", out).
+				Info("rejectOutside: packet too big, not sending")
+		}
+		return
+	}
+
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
 }
 
 func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
-	f.getOrHandshake(vpnIp)
+	f.getOrHandshake(vpnIp, nil)
 }
 
-// getOrHandshake returns nil if the vpnIp is not routable
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
+// getOrHandshake returns nil if the vpnIp is not routable.
+// If the second return value is false, the hostinfo is not ready to be used in a tunnel
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
 		if vpnIp == 0 {
-			return nil
-		}
-	}
-
-	hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
-	if hostinfo == nil {
-		hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
-	}
-	ci := hostinfo.ConnectionState
-
-	if ci != nil && ci.eKey != nil && ci.ready {
-		return hostinfo
-	}
-
-	// Handshake is not ready, we need to grab the lock now before we start the handshake process
-	//TODO: move this to handshake manager
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
-	// Double check, now that we have the lock
-	ci = hostinfo.ConnectionState
-	if ci != nil && ci.eKey != nil && ci.ready {
-		return hostinfo
-	}
-
-	// If we have already created the handshake packet, we don't want to call the function at all.
-	if !hostinfo.HandshakeReady {
-		ixHandshakeStage0(f, vpnIp, hostinfo)
-		// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
-		//xx_handshakeStage0(f, ip, hostinfo)
-
-		// If this is a static host, we don't need to wait for the HostQueryReply
-		// We can trigger the handshake right now
-		_, doTrigger := f.lightHouse.GetStaticHostList()[vpnIp]
-		if !doTrigger {
-			// Add any calculated remotes, and trigger early handshake if one found
-			doTrigger = f.lightHouse.addCalculatedRemotes(vpnIp)
-		}
-
-		if doTrigger {
-			select {
-			case f.handshakeManager.trigger <- vpnIp:
-			default:
-			}
+			return nil, false
 		}
 	}
 
-	return hostinfo
-}
-
-// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
-// will create the initial Noise ConnectionState
-func (f *Interface) initHostInfo(hostinfo *HostInfo) {
-	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
+	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -198,7 +157,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
-	hostInfo := f.getOrHandshake(vpnIp)
+	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
+	})
+
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnIp", vpnIp).
@@ -207,16 +169,8 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu
 		return
 	}
 
-	if !hostInfo.ConnectionState.ready {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		hostInfo.ConnectionState.queueLock.Lock()
-		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
-			hostInfo.ConnectionState.queueLock.Unlock()
-			return
-		}
-		hostInfo.ConnectionState.queueLock.Unlock()
+	if !ready {
+		return
 	}
 
 	f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out)
@@ -236,7 +190,7 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 
-// sendVia sends a payload through a Relay tunnel. No authentication or encryption is done
+// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
 // to the payload for the ultimate target host, making this a useful method for sending
 // handshake messages to peers through relay tunnels.
 // via is the HostInfo through which the message is relayed.
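
The refactor moves all handshake staging into the handshake manager, so a tunnel-initiating call site now follows the pattern sketched below (hypothetical caller; signatures taken from the hunks above, with vpnIp, p, nb, and out assumed in scope). The callback caches packets until the handshake completes, and the boolean gates immediate sends.

	hostinfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
		// Queue the packet; the handshake manager replays it once the tunnel is up.
		hh.cachePacket(f.l, header.Message, 0, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
	})
	if hostinfo == nil {
		return // vpnIp is not routable
	}
	if !ready {
		return // the callback cached the packet; nothing more to do yet
	}
	f.SendMessageToHostInfo(header.Message, 0, hostinfo, p, nb, out)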

+ 17 - 7
interface.go

@@ -40,7 +40,6 @@ type InterfaceConfig struct {
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
-	disconnectInvalid       bool
 	relayManager            *relayManager
 	punchy                  *Punchy
 
@@ -69,7 +68,7 @@ type Interface struct {
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	routines           int
-	disconnectInvalid  bool
+	disconnectInvalid  atomic.Bool
 	closed             atomic.Bool
 	relayManager       *relayManager
 
@@ -176,7 +175,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		disconnectInvalid:  c.disconnectInvalid,
 		myVpnIp:            myVpnIp,
 		relayManager:       c.relayManager,
 
@@ -294,12 +292,24 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
+	c.RegisterReloadCallback(f.reloadDisconnectInvalid)
 	c.RegisterReloadCallback(f.reloadMisc)
+
 	for _, udpConn := range f.writers {
 		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
 }
 
+func (f *Interface) reloadDisconnectInvalid(c *config.C) {
+	initial := c.InitialLoad()
+	if initial || c.HasChanged("pki.disconnect_invalid") {
+		f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
+		if !initial {
+			f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
+		}
+	}
+}
+
 func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
@@ -322,8 +332,8 @@ func (f *Interface) reloadFirewall(c *config.C) {
 	// If rulesVersion is back to zero, we have wrapped all the way around. Be
 	// safe and just reset conntrack in this case.
 	if fw.rulesVersion == 0 {
-		f.l.WithField("firewallHash", fw.GetRuleHash()).
-			WithField("oldFirewallHash", oldFw.GetRuleHash()).
+		f.l.WithField("firewallHashes", fw.GetRuleHashes()).
+			WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
 			WithField("rulesVersion", fw.rulesVersion).
 			Warn("firewall rulesVersion has overflowed, resetting conntrack")
 	} else {
@@ -333,8 +343,8 @@ func (f *Interface) reloadFirewall(c *config.C) {
 	f.firewall = fw
 
 	oldFw.Destroy()
-	f.l.WithField("firewallHash", fw.GetRuleHash()).
-		WithField("oldFirewallHash", oldFw.GetRuleHash()).
+	f.l.WithField("firewallHashes", fw.GetRuleHashes()).
+		WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
 		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")
 }
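
With disconnectInvalid behind an atomic.Bool, readers and the reload path no longer race on the flag. A minimal sketch of the read side (hypothetical call site, not part of this hunk):

	if f.disconnectInvalid.Load() {
		// The peer's certificate no longer validates; tear the tunnel down.
		// (Actual teardown call elided; it is not shown in this hunk.)
	}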

+ 34 - 7
iputil/packet.go

@@ -6,8 +6,19 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
+const (
+	// Need 96 bytes for the largest reject packet:
+	// - 20 byte ipv4 header
+	// - 8 byte icmpv4 header
+	// - 68 byte body (60 byte max orig ipv4 header + first 8 bytes of orig payload)
+	MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8
+)
+
 func CreateRejectPacket(packet []byte, out []byte) []byte {
-	// TODO ipv4 only, need to fix when inside supports ipv6
+	if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version {
+		return nil
+	}
+
 	switch packet[9] {
 	case 6: // tcp
 		return ipv4CreateRejectTCPPacket(packet, out)
@@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte {
 func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 
-	// ICMP reply includes header and first 8 bytes of the packet
+	if len(packet) < ihl {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+
+	// ICMP reply includes original header and first 8 bytes of the packet
 	packetLen := len(packet)
 	if packetLen > ihl+8 {
 		packetLen = ihl + 8
 	}
 
 	outLen := ipv4.HeaderLen + 8 + packetLen
+	if outLen > cap(out) {
+		return nil
+	}
 
-	out = out[:(outLen)]
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
-	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)                        // version, ihl
-	ipHdr[1] = 0                                                              // DSCP, ECN
-	binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length
+	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
+	ipHdr[1] = 0                                          // DSCP, ECN
+	binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length
 
 	ipHdr[4] = 0  // id
 	ipHdr[5] = 0  //  .
@@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 	outLen := ipv4.HeaderLen + tcpLen
 
-	out = out[:(outLen)]
+	if len(packet) < ihl+tcpLen {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+	if outLen > cap(out) {
+		return nil
+	}
+
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
 	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
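
CreateRejectPacket now signals failure by returning an empty slice, so callers size the scratch buffer up front and bail on empty output. A minimal sketch of the new calling convention (packet assumed in scope):

	out := make([]byte, 0, iputil.MaxRejectPacketSize)
	out = iputil.CreateRejectPacket(packet, out)
	if len(out) == 0 {
		return // not IPv4, truncated, or the buffer was too small
	}
	// out now holds a complete TCP RST or ICMP unreachable reply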

+ 73 - 0
iputil/packet_test.go

@@ -0,0 +1,73 @@
+package iputil
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/net/ipv4"
+)
+
+func Test_CreateRejectPacket(t *testing.T) {
+	h := ipv4.Header{
+		Len:      20,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+	}
+
+	b, err := h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+
+	expectedLen := ipv4.HeaderLen + 8 + h.Len + 4
+	out := make([]byte, expectedLen)
+	rejectPacket := CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// ICMP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...)
+
+	expectedLen = MaxRejectPacketSize
+	out = make([]byte, MaxRejectPacketSize)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// TCP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 6, // TCP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+	b = append(b, make([]byte, 16)...)
+
+	expectedLen = ipv4.HeaderLen + 20
+	out = make([]byte, expectedLen)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+}

+ 4 - 5
lighthouse.go

@@ -74,7 +74,7 @@ type LightHouse struct {
 	// IP's of relays that can be used by peers to access me
 	relaysForMe atomic.Pointer[[]iputil.VpnIp]
 
-	calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
+	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
@@ -166,7 +166,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
 	return *lh.relaysForMe.Load()
 }
 
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
+func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
 	return lh.calculatedRemotes.Load()
 }
 
@@ -594,11 +594,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 	if tree == nil {
 		return false
 	}
-	value := tree.MostSpecificContains(vpnIp)
-	if value == nil {
+	ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+	if !ok {
 		return false
 	}
-	calculatedRemotes := value.([]*calculatedRemote)
 
 	var calculated []*Ip4AndPort
 	for _, cr := range calculatedRemotes {
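
The generic cidr.Tree4[T] removes the interface{} round-trip the old API forced on every lookup. A sketch of the new shape (ipNet, remotes, and vpnIp assumed in scope; types from the hunks above):

	tree := cidr.NewTree4[[]*calculatedRemote]()
	tree.AddCIDR(ipNet, remotes)
	ok, found := tree.MostSpecificContains(vpnIp)
	if !ok {
		return false
	}
	// found is already []*calculatedRemote; no type assertion needed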

+ 16 - 6
main.go

@@ -18,7 +18,7 @@ import (
 
 type m map[string]interface{}
 
-func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
+func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	defer func() {
@@ -65,12 +65,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
-	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
+	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
 	// TODO: make sure mask is 4 bytes
 	tunCidr := certificate.Details.Ips[0]
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
+	if err != nil {
+		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
+	}
 	wireSSHReload(l, ssh, c)
 	var sshStart func()
 	if c.GetBool("sshd.enabled", false) {
@@ -125,7 +128,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if !configTest {
 		c.CatchHUP(ctx)
 
-		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
+		if deviceFactory == nil {
+			deviceFactory = overlay.NewDeviceFromConfig
+		}
+
+		tun, err = deviceFactory(c, l, tunCidr, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -156,6 +163,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 
 		for i := 0; i < routines; i++ {
+			l.Infof("listening %q %d", listenHost.IP, port)
 			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -235,7 +243,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		messageMetrics: messageMetrics,
 	}
 
-	handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
+	handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
 	lightHouse.handshakeTrigger = handshakeManager.trigger
 
 	serveDns := false
@@ -270,7 +278,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
-		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
 		punchy:                  punchy,
 
@@ -300,9 +307,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		lightHouse.ifce = ifce
 
 		ifce.RegisterConfigChangeCallbacks(c)
+		ifce.reloadDisconnectInvalid(c)
 		ifce.reloadSendRecvError(c)
 
-		go handshakeManager.Run(ctx, ifce)
+		handshakeManager.f = ifce
+		go handshakeManager.Run(ctx)
 	}
 
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
@@ -331,6 +340,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	return &Control{
 		ifce,
 		l,
+		ctx,
 		cancel,
 		sshStart,
 		statsStart,
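
Main no longer takes a raw tun fd; embedders now pass a DeviceFactory. A sketch of the two common call shapes (c, l, buildVersion, and fd assumed in scope; error handling elided; factories defined in the overlay hunks below):

	// Standard daemon: a nil factory falls back to overlay.NewDeviceFromConfig.
	ctrl, err := nebula.Main(c, false, buildVersion, l, nil)

	// Embedders that already hold a tun fd (e.g. mobile) wrap it instead:
	ctrl, err = nebula.Main(c, false, buildVersion, l, overlay.NewFdDeviceFromConfig(&fd))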

+ 22 - 6
mutex_debug.go

@@ -5,7 +5,9 @@ package nebula
 
 import (
 	"fmt"
+	"log"
 	"runtime"
+	"runtime/debug"
 	"sync"
 
 	"github.com/timandy/routine"
@@ -42,24 +44,38 @@ func newSyncRWMutex(key mutexKey) syncRWMutex {
 	}
 }
 
+func alertMutex(err error) {
+	log.Print(err, string(debug.Stack()))
+}
+
 func checkMutex(state map[mutexKey]mutexValue, add mutexKey) {
+	for k := range state {
+		if add == k {
+			alertMutex(fmt.Errorf("re-entrant lock: state=%v add=%v", state, add))
+		}
+	}
+
 	switch add.Type {
 	case mutexKeyTypeHostInfo:
 		// Check for any other hostinfo keys:
 		for k := range state {
 			if k.Type == mutexKeyTypeHostInfo {
-				panic(fmt.Errorf("grabbing hostinfo lock and already have a hostinfo lock: state=%v add=%v", state, add))
+				alertMutex(fmt.Errorf("grabbing hostinfo lock and already have a hostinfo lock: state=%v add=%v", state, add))
 			}
 		}
 		if _, ok := state[mutexKey{Type: mutexKeyTypeHostMap}]; ok {
-			panic(fmt.Errorf("grabbing hostinfo lock and already have hostmap: state=%v add=%v", state, add))
+			alertMutex(fmt.Errorf("grabbing hostinfo lock and already have hostmap: state=%v add=%v", state, add))
 		}
 		if _, ok := state[mutexKey{Type: mutexKeyTypeHandshakeManager}]; ok {
-			panic(fmt.Errorf("grabbing hostinfo lock and already have handshake-manager: state=%v add=%v", state, add))
+			alertMutex(fmt.Errorf("grabbing hostinfo lock and already have handshake-manager: state=%v add=%v", state, add))
 		}
-	case mutexKeyTypeHandshakeManager:
-		if _, ok := state[mutexKey{Type: mutexKeyTypeHostMap}]; ok {
-			panic(fmt.Errorf("grabbing handshake-manager lock and already have hostmap: state=%v add=%v", state, add))
+		// case mutexKeyTypeHandshakeManager:
+		// 	if _, ok := state[mutexKey{Type: mutexKeyTypeHostMap}]; ok {
+		// 		alertMutex(fmt.Errorf("grabbing handshake-manager lock and already have hostmap: state=%v add=%v", state, add))
+		// 	}
+	case mutexKeyTypeHostMap:
+		if _, ok := state[mutexKey{Type: mutexKeyTypeHandshakeManager}]; ok {
+			alertMutex(fmt.Errorf("grabbing hostmap lock and already have handshake-manager: state=%v add=%v", state, add))
 		}
 	}
 }
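
Read together, these checks encode the acquisition order hostinfo, then hostmap, then handshake-manager; the old hostmap-before-handshake-manager rule is relaxed (commented out) while its reverse is now flagged, and violations, including locking the same key twice, are logged with a stack trace instead of panicking. A sketch of what gets flagged (hypothetical goroutine, illustrative only):

	f.hostMap.Lock()  // state: {hostmap}
	hostinfo.Lock()   // alertMutex: hostinfo grabbed while holding hostmap (order violation)
	hostinfo.Lock()   // alertMutex: re-entrant lock on the same key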

+ 4 - 5
outside.go

@@ -198,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		HandleIncomingHandshake(f, addr, via, packet, h, hostinfo)
+		f.handshakeManager.HandleIncoming(addr, via, packet, h)
 		return
 
 	case header.RecvError:
@@ -406,7 +406,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
-		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
+		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
+		// This gives us a buffer to build the reject packet in
+		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).
@@ -455,9 +457,6 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 		return
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	if !hostinfo.RecvErrorExceeded() {
 		return
 	}

+ 2 - 2
overlay/route.go

@@ -21,8 +21,8 @@ type Route struct {
 	Install bool
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
-	routeTree := cidr.NewTree4()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
+	routeTree := cidr.NewTree4[iputil.VpnIp]()
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)

+ 8 - 10
overlay/route_test.go

@@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
 	assert.NoError(t, err)
 
 	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
-	r := routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
+	ok, r := routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.Nil(t, r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.False(t, ok)
 }

+ 24 - 8
overlay/tun.go

@@ -10,7 +10,9 @@ import (
 
 const DefaultMTU = 1300
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) {
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
 	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
 		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
@@ -27,27 +29,41 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 
-	case fd != nil:
-		return newTunFromFd(
+	default:
+		return newTun(
 			l,
-			*fd,
+			c.GetString("tun.dev", ""),
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
+			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+	}
+}
 
-	default:
-		return newTun(
+func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+	return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+		routes, err := parseRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
+		}
+
+		unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+		}
+		routes = append(routes, unsafeRoutes...)
+		return newTunFromFd(
 			l,
-			c.GetString("tun.dev", ""),
+			*fd,
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
-			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+
 	}
 }

+ 3 - 7
overlay/tun_android.go

@@ -18,7 +18,7 @@ type tun struct {
 	io.ReadWriteCloser
 	fd        int
 	cidr      *net.IPNet
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 }
 
@@ -46,12 +46,8 @@ func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t tun) Activate() error {

+ 4 - 4
overlay/tun_darwin.go

@@ -25,7 +25,7 @@ type tun struct {
 	cidr       *net.IPNet
 	DefaultMTU int
 	Routes     []Route
-	routeTree  *cidr.Tree4
+	routeTree  *cidr.Tree4[iputil.VpnIp]
 	l          *logrus.Logger
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -304,9 +304,9 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
+	ok, r := t.routeTree.MostSpecificContains(ip)
+	if ok {
+		return r
 	}
 
 	return 0

+ 3 - 7
overlay/tun_freebsd.go

@@ -48,7 +48,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -192,12 +192,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_ios.go

@@ -20,7 +20,7 @@ import (
 type tun struct {
 	io.ReadWriteCloser
 	cidr      *net.IPNet
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 }
 
 func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
@@ -46,12 +46,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 // The following is hoisted up from water, we do this so we can inject our own fd on iOS

+ 5 - 9
overlay/tun_linux.go

@@ -30,7 +30,7 @@ type tun struct {
 	TXQueueLen int
 
 	Routes          []Route
-	routeTree       atomic.Pointer[cidr.Tree4]
+	routeTree       atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.Load().MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.Load().MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	newTree := cidr.NewTree4()
+	newTree := cidr.NewTree4[iputil.VpnIp]()
 	if r.Type == unix.RTM_NEWROUTE {
 		for _, oldR := range t.routeTree.Load().List() {
 			newTree.AddCIDR(oldR.CIDR, oldR.Value)
@@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	} else {
 		gw := iputil.Ip2VpnIp(r.Gw)
 		for _, oldR := range t.routeTree.Load().List() {
-			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
+			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
 				// This is the record to delete
 				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 				continue

+ 3 - 7
overlay/tun_netbsd.go

@@ -29,7 +29,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -134,12 +134,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_openbsd.go

@@ -23,7 +23,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -115,12 +115,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_tester.go

@@ -19,7 +19,7 @@ type TestTun struct {
 	Device    string
 	cidr      *net.IPNet
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	closed    atomic.Bool
@@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
 //********************************************************************************************************************//
 
 func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *TestTun) Activate() error {

+ 3 - 7
overlay/tun_water_windows.go

@@ -18,7 +18,7 @@ type waterTun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	*water.Interface
 }
@@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
 }
 
 func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *waterTun) Cidr() *net.IPNet {

+ 12 - 9
overlay/tun_wintun_windows.go

@@ -24,7 +24,7 @@ type winTun struct {
 	prefix    netip.Prefix
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	tun *wintun.NativeTun
 }
@@ -54,9 +54,16 @@ func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 	}
 
-	tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+	var tunDevice wintun.Device
+	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
 	if err != nil {
-		return nil, fmt.Errorf("create TUN device failed: %w", err)
+		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
+		// Trying a second time resolves the issue.
+		l.WithError(err).Debug("Failed to create wintun device, retrying")
+		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+		if err != nil {
+			return nil, fmt.Errorf("create TUN device failed: %w", err)
+		}
 	}
 
 	routeTree, err := makeRouteTree(l, routes, false)
@@ -139,12 +146,8 @@ func (t *winTun) Activate() error {
 }
 
 func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *winTun) Cidr() *net.IPNet {

+ 63 - 0
overlay/user.go

@@ -0,0 +1,63 @@
+package overlay
+
+import (
+	"io"
+	"net"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
+)
+
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+	return NewUserDevice(tunCidr)
+}
+
+func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+	// these pipes guarantee each write/read will match 1:1
+	or, ow := io.Pipe()
+	ir, iw := io.Pipe()
+	return &UserDevice{
+		tunCidr:        tunCidr,
+		outboundReader: or,
+		outboundWriter: ow,
+		inboundReader:  ir,
+		inboundWriter:  iw,
+	}, nil
+}
+
+type UserDevice struct {
+	tunCidr *net.IPNet
+
+	outboundReader *io.PipeReader
+	outboundWriter *io.PipeWriter
+
+	inboundReader *io.PipeReader
+	inboundWriter *io.PipeWriter
+}
+
+func (d *UserDevice) Activate() error {
+	return nil
+}
+func (d *UserDevice) Cidr() *net.IPNet                      { return d.tunCidr }
+func (d *UserDevice) Name() string                          { return "faketun0" }
+func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return d, nil
+}
+
+func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
+	return d.inboundReader, d.outboundWriter
+}
+
+func (d *UserDevice) Read(p []byte) (n int, err error) {
+	return d.outboundReader.Read(p)
+}
+func (d *UserDevice) Write(p []byte) (n int, err error) {
+	return d.inboundWriter.Write(p)
+}
+func (d *UserDevice) Close() error {
+	d.inboundWriter.Close()
+	d.outboundWriter.Close()
+	return nil
+}
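
UserDevice stands in for a kernel tun: Nebula reads and writes one end of each pipe, and the embedding application drives the other end through Pipe(). A sketch of the application side (tunCidr assumed in scope; error handling elided):

	dev, _ := overlay.NewUserDevice(tunCidr)
	fromNebula, toNebula := dev.Pipe()

	// One packet per Read: traffic Nebula has decrypted for us.
	buf := make([]byte, 1500)
	n, _ := fromNebula.Read(buf)

	// One packet per Write: an IPv4 packet to send into the overlay.
	toNebula.Write(buf[:n])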

+ 7 - 1
relay_manager.go

@@ -179,6 +179,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		"vpnIp":               h.vpnIp})
 
 	logMsg.Info("handleCreateRelayRequest")
+	// Is the source of the relay me? This should never happen, but did happen due to
+	// an issue migrating relays over to newly re-handshaked host info objects.
+	if from == f.myVpnIp {
+		logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+		return
+	}
 	// Is the target of the relay me?
 	if target == f.myVpnIp {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
@@ -244,7 +250,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		if peer == nil {
 			// Try to establish a connection to this host. If we get a future relay request,
 			// we'll be ready!
-			f.getOrHandshake(target)
+			f.Handshake(target)
 			return
 		}
 		if peer.remote == nil {

+ 36 - 0
service/listener.go

@@ -0,0 +1,36 @@
+package service
+
+import (
+	"io"
+	"net"
+)
+
+type tcpListener struct {
+	port   uint16
+	s      *Service
+	addr   *net.TCPAddr
+	accept chan net.Conn
+}
+
+func (l *tcpListener) Accept() (net.Conn, error) {
+	conn, ok := <-l.accept
+	if !ok {
+		return nil, io.EOF
+	}
+	return conn, nil
+}
+
+func (l *tcpListener) Close() error {
+	l.s.mu.Lock()
+	defer l.s.mu.Unlock()
+	delete(l.s.mu.listeners, uint16(l.addr.Port))
+
+	close(l.accept)
+
+	return nil
+}
+
+// Addr returns the listener's network address.
+func (l *tcpListener) Addr() net.Addr {
+	return l.addr
+}

+ 248 - 0
service/service.go

@@ -0,0 +1,248 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"math"
+	"net"
+	"os"
+	"strings"
+	"sync"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay"
+	"golang.org/x/sync/errgroup"
+	"gvisor.dev/gvisor/pkg/bufferv2"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const nicID = 1
+
+type Service struct {
+	eg      *errgroup.Group
+	control *nebula.Control
+	ipstack *stack.Stack
+
+	mu struct {
+		sync.Mutex
+
+		listeners map[uint16]*tcpListener
+	}
+}
+
+func New(config *config.C) (*Service, error) {
+	logger := logrus.New()
+	logger.Out = os.Stdout
+
+	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
+	if err != nil {
+		return nil, err
+	}
+	control.Start()
+
+	ctx := control.Context()
+	eg, ctx := errgroup.WithContext(ctx)
+	s := Service{
+		eg:      eg,
+		control: control,
+	}
+	s.mu.listeners = map[uint16]*tcpListener{}
+
+	device, ok := control.Device().(*overlay.UserDevice)
+	if !ok {
+		return nil, errors.New("must be using user device")
+	}
+
+	s.ipstack = stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
+	})
+	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+	tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+	if tcpipErr != nil {
+		return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+	}
+	linkEP := channel.New(512 /*size*/, 1280 /*mtu*/, "")
+	if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
+		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
+	}
+	ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
+	s.ipstack.SetRouteTable([]tcpip.Route{
+		{
+			Destination: ipv4Subnet,
+			NIC:         nicID,
+		},
+	})
+
+	ipNet := device.Cidr()
+	pa := tcpip.ProtocolAddress{
+		AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(),
+		Protocol:          ipv4.ProtocolNumber,
+	}
+	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
+		PEB:        stack.CanBePrimaryEndpoint, // zero value default
+		ConfigType: stack.AddressConfigStatic,  // zero value default
+	}); err != nil {
+		return nil, fmt.Errorf("error creating IP: %s", err)
+	}
+
+	const tcpReceiveBufferSize = 0
+	const maxInFlightConnectionAttempts = 1024
+	tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
+	s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
+
+	reader, writer := device.Pipe()
+
+	go func() {
+		<-ctx.Done()
+		reader.Close()
+		writer.Close()
+	}()
+
+	// create goroutines to forward packets between Nebula and gVisor
+	eg.Go(func() error {
+		buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
+		for {
+			// this will read exactly one packet
+			n, err := reader.Read(buf)
+			if err != nil {
+				return err
+			}
+			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+				Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])),
+			})
+			linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
+
+			if err := ctx.Err(); err != nil {
+				return err
+			}
+		}
+	})
+	eg.Go(func() error {
+		for {
+			packet := linkEP.ReadContext(ctx)
+			if packet.IsNil() {
+				if err := ctx.Err(); err != nil {
+					return err
+				}
+				continue
+			}
+			bufView := packet.ToView()
+			if _, err := bufView.WriteTo(writer); err != nil {
+				return err
+			}
+			bufView.Release()
+		}
+	})
+
+	return &s, nil
+}
+
+// DialContext dials the provided address. Currently only TCP is supported.
+func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+
+	fullAddr := tcpip.FullAddress{
+		NIC:  nicID,
+		Addr: tcpip.Address(addr.IP),
+		Port: uint16(addr.Port),
+	}
+
+	return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
+}
+
+// Listen listens on the provided address. Currently only TCP with wildcard
+// addresses is supported.
+func (s *Service) Listen(network, address string) (net.Listener, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+	if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
+		return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
+	}
+	if addr.Port == 0 {
+		return nil, errors.New("specific port required, got 0")
+	}
+	if addr.Port < 0 || addr.Port >= math.MaxUint16 {
+		return nil, fmt.Errorf("invalid port %d", addr.Port)
+	}
+	port := uint16(addr.Port)
+
+	l := &tcpListener{
+		port:   port,
+		s:      s,
+		addr:   addr,
+		accept: make(chan net.Conn),
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, ok := s.mu.listeners[port]; ok {
+		return nil, fmt.Errorf("already listening on port %d", port)
+	}
+	s.mu.listeners[port] = l
+
+	return l, nil
+}
+
+func (s *Service) Wait() error {
+	return s.eg.Wait()
+}
+
+func (s *Service) Close() error {
+	s.control.Stop()
+	return nil
+}
+
+func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
+	endpointID := r.ID()
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	l, ok := s.mu.listeners[endpointID.LocalPort]
+	if !ok {
+		r.Complete(true)
+		return
+	}
+
+	var wq waiter.Queue
+	ep, err := r.CreateEndpoint(&wq)
+	if err != nil {
+		log.Printf("got error creating endpoint %q", err)
+		r.Complete(true)
+		return
+	}
+	r.Complete(false)
+	ep.SocketOptions().SetKeepAlive(true)
+
+	conn := gonet.NewTCPConn(&wq, ep)
+	l.accept <- conn
+}
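
The exported surface is deliberately small: New boots Nebula on a UserDevice plus a gVisor netstack, and Listen/DialContext speak TCP over the overlay with no kernel tun. A sketch of library use (cfg and ctx assumed in scope; the test below is the full worked example):

	svc, err := service.New(cfg) // cfg: a loaded *config.C
	if err != nil {
		return err
	}
	defer svc.Close()

	// Dial a peer's overlay address directly from the embedding process.
	conn, err := svc.DialContext(ctx, "tcp", "10.0.0.1:8080")
	if err != nil {
		return err
	}
	defer conn.Close()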

+ 165 - 0
service/service_test.go

@@ -0,0 +1,165 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"net"
+	"testing"
+	"time"
+
+	"dario.cat/mergo"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/e2e"
+	"golang.org/x/sync/errgroup"
+	"gopkg.in/yaml.v2"
+)
+
+type m map[string]interface{}
+
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
+
+	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
+	copy(vpnIpNet.IP, udpIp)
+
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	caB, err := caCrt.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	mc := m{
+		"pki": m{
+			"ca":   string(caB),
+			"cert": string(myPEM),
+			"key":  string(myPrivKey),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": m{
+			"outbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+			"inbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+		},
+		"timers": m{
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
+		},
+		"handshakes": m{
+			"try_interval": "200ms",
+		},
+	}
+
+	if overrides != nil {
+		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = overrides
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	var c config.C
+	if err := c.LoadString(string(cb)); err != nil {
+		panic(err)
+	}
+
+	s, err := New(&c)
+	if err != nil {
+		panic(err)
+	}
+	return s
+}
+
+func TestService(t *testing.T) {
+	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+		"static_host_map": m{},
+		"lighthouse": m{
+			"am_lighthouse": true,
+		},
+		"listen": m{
+			"host": "0.0.0.0",
+			"port": 4243,
+		},
+	})
+	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+		"static_host_map": m{
+			"10.0.0.1": []string{"localhost:4243"},
+		},
+		"lighthouse": m{
+			"hosts":    []string{"10.0.0.1"},
+			"interval": 1,
+		},
+	})
+
+	ln, err := a.Listen("tcp", ":1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var eg errgroup.Group
+	eg.Go(func() error {
+		conn, err := ln.Accept()
+		if err != nil {
+			return err
+		}
+		defer conn.Close()
+
+		t.Log("accepted connection")
+
+		if _, err := conn.Write([]byte("server msg")); err != nil {
+			return err
+		}
+
+		t.Log("server: wrote message")
+
+		data := make([]byte, 100)
+		n, err := conn.Read(data)
+		if err != nil {
+			return err
+		}
+		data = data[:n]
+		if !bytes.Equal(data, []byte("client msg")) {
+			return errors.New("got invalid message from client")
+		}
+		t.Log("server: read message")
+		return conn.Close()
+	})
+
+	c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := c.Write([]byte("client msg")); err != nil {
+		t.Fatal(err)
+	}
+
+	data := make([]byte, 100)
+	n, err := c.Read(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+	data = data[:n]
+	if !bytes.Equal(data, []byte("server msg")) {
+		t.Fatal("got invalid message from client")
+	}
+
+	if err := c.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := eg.Wait(); err != nil {
+		t.Fatal(err)
+	}
+}

+ 2 - 4
ssh.go

@@ -6,7 +6,6 @@ import (
 	"errors"
 	"flag"
 	"fmt"
-	"io/ioutil"
 	"net"
 	"os"
 	"reflect"
@@ -96,7 +95,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 		return nil, fmt.Errorf("sshd.host_key must be provided")
 	}
 
-	hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
+	hostKeyBytes, err := os.ReadFile(hostKeyFile)
 	if err != nil {
 		return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
 	}
@@ -607,11 +606,10 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
+	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 	}
-	ifce.getOrHandshake(vpnIp)
 
 	return w.WriteLine("Created")
 }

+ 2 - 2
test/logger.go

@@ -1,7 +1,7 @@
 package test
 
 import (
-	"io/ioutil"
+	"io"
 	"os"
 
 	"github.com/sirupsen/logrus"
@@ -12,7 +12,7 @@ func NewLogger() *logrus.Logger {
 
 	v := os.Getenv("TEST_LOGS")
 	if v == "" {
-		l.SetOutput(ioutil.Discard)
+		l.SetOutput(io.Discard)
 		return l
 	}
 

+ 7 - 2
udp/udp_darwin.go

@@ -43,10 +43,15 @@ func NewListenConfig(multi bool) net.ListenConfig {
 }
 
 func (u *GenericConn) Rebind() error {
-	file, err := u.File()
+	rc, err := u.UDPConn.SyscallConn()
 	if err != nil {
 		return err
 	}
 
-	return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+	return rc.Control(func(fd uintptr) {
+		err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
+		if err != nil {
+			u.l.WithError(err).Error("Failed to rebind udp socket")
+		}
+	})
 }
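
Worth noting, though the hunk does not say so: (*net.UDPConn).File returns a duplicated descriptor that the old code never closed, so each Rebind leaked an fd. SyscallConn().Control runs the setsockopt against the original socket instead, and the setsockopt error is now logged inside the closure because Control's own return value only reports failures of Control itself.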