Browse Source

Enable running testifylint in CI (#1350)

Caleb Jasik 4 months ago
parent
commit
088af8edb2

+ 10 - 0
.github/workflows/test.yml

@@ -31,6 +31,11 @@ jobs:
     - name: Vet
       run: make vet
 
+    - name: golangci-lint
+      uses: golangci/golangci-lint-action@v6
+      with:
+        version: v1.64
+
     - name: Test
       run: make test
 
@@ -109,6 +114,11 @@ jobs:
     - name: Vet
       run: make vet
 
+    - name: golangci-lint
+      uses: golangci/golangci-lint-action@v6
+      with:
+        version: v1.64
+
     - name: Test
       run: make test
 

+ 9 - 0
.golangci.yaml

@@ -0,0 +1,9 @@
+# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
+linters:
+  # Disable all linters.
+  # Default: false
+  disable-all: true
+  # Enable specific linter
+  # https://golangci-lint.run/usage/linters/#enabled-by-default
+  enable:
+    - testifylint

+ 7 - 6
allow_list_test.go

@@ -9,6 +9,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewAllowListFromConfig(t *testing.T) {
@@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		"192.168.0.0": true,
 	}
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
+	require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
 	assert.Nil(t, r)
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"192.168.0.0/16": "abc",
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
+	require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"192.168.0.0/16": true,
 		"10.0.0.0/8":     false,
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
+	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"0.0.0.0/0":      true,
@@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		"fd00:fd00::/16": false,
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
+	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"0.0.0.0/0":     true,
@@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		},
 	}
 	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
+	require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"interfaces": map[interface{}]interface{}{
@@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		},
 	}
 	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
+	require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"interfaces": map[interface{}]interface{}{

+ 8 - 8
calculated_remote_test.go

@@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	require.NoError(t, err)
 
 	input, err := netip.ParseAddr("10.0.10.182")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	expected, err := netip.ParseAddr("192.168.1.182")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
 
@@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	require.NoError(t, err)
 
 	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
 
@@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	require.NoError(t, err)
 
 	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
 
@@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	require.NoError(t, err)
 
 	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
 }

+ 76 - 75
cert/ca_pool_test.go

@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewCAPoolFromBytes(t *testing.T) {
@@ -82,12 +83,12 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
 	}
 
 	p, err := NewCAPoolFromPEM([]byte(noNewLines))
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
 	assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
 
 	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
 	assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
 
@@ -105,7 +106,7 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
 	assert.Len(t, pppp.CAs, 3)
 
 	ppppp, err := NewCAPoolFromPEM([]byte(p256))
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
 	assert.Len(t, ppppp.CAs, 1)
 }
@@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) {
 	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
 
 	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
+	require.NoError(t, caPool.AddCA(ca))
 
 	f, err := c.Fingerprint()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	caPool.BlocklistFingerprint(f)
 
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
+	require.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
+	require.EqualError(t, err, "root certificate is expired")
 
 	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
 		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) {
 	// Test group assertion
 	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool = NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) {
 	})
 
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV1_VerifyP256(t *testing.T) {
@@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
 	c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
 
 	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
+	require.NoError(t, caPool.AddCA(ca))
 
 	f, err := c.Fingerprint()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	caPool.BlocklistFingerprint(f)
 
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
+	require.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
+	require.EqualError(t, err, "root certificate is expired")
 
 	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
 		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
 	// Test group assertion
 	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool = NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
 
 	c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV1_Verify_IPs(t *testing.T) {
@@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
 
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool := NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	// ip is outside the network
@@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
 	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
 	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed with just 1
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV1_Verify_Subnets(t *testing.T) {
@@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
 
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool := NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	// ip is outside the network
@@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
 	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
 	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
 	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed with just 1
 	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV2_Verify(t *testing.T) {
@@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) {
 	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
 
 	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
+	require.NoError(t, caPool.AddCA(ca))
 
 	f, err := c.Fingerprint()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	caPool.BlocklistFingerprint(f)
 
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
+	require.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
+	require.EqualError(t, err, "root certificate is expired")
 
 	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
 		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) {
 	// Test group assertion
 	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool = NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) {
 	})
 
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV2_VerifyP256(t *testing.T) {
@@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
 	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
 
 	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
+	require.NoError(t, caPool.AddCA(ca))
 
 	f, err := c.Fingerprint()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	caPool.BlocklistFingerprint(f)
 
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
+	require.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
+	require.EqualError(t, err, "root certificate is expired")
 
 	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
 		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
 	// Test group assertion
 	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool = NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
 
 	c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV2_Verify_IPs(t *testing.T) {
@@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
 
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool := NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	// ip is outside the network
@@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
 	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
 	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed with just 1
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestCertificateV2_Verify_Subnets(t *testing.T) {
@@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
 
 	caPem, err := ca.MarshalPEM()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	caPool := NewCAPool()
 	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 
 	// ip is outside the network
@@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
 	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
 	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
 	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Exact matches reversed with just 1
 	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }

+ 17 - 17
cert/cert_v1_test.go

@@ -39,11 +39,11 @@ func TestCertificateV1_Marshal(t *testing.T) {
 	}
 
 	b, err := nc.Marshal()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	//t.Log("Cert size:", len(b))
 
 	nc2, err := unmarshalCertificateV1(b, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, Version1, nc.Version())
 	assert.Equal(t, Curve_CURVE25519, nc.Curve())
@@ -99,7 +99,7 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
 	}
 
 	b, err := nc.MarshalJSON()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.JSONEq(
 		t,
 		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
@@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
 func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
 	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.Error(t, err)
+	require.Error(t, err)
 
 	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
 	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, priv2 := X25519Keypair()
 	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.Error(t, err)
+	require.Error(t, err)
 }
 
 func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
 	err := ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.Error(t, err)
+	require.Error(t, err)
 
 	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
 	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 	assert.Equal(t, Curve_P256, curve)
 	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, priv2 := P256Keypair()
 	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.Error(t, err)
+	require.Error(t, err)
 }
 
 // Ensure that upgrading the protobuf library does not change how certificates
@@ -186,7 +186,7 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
 	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
 
 	b, err = proto.Marshal(nc.getRawDetails())
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
 }
 
@@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) {
 	// Test that we don't panic with an invalid certificate (#332)
 	data := []byte("\x98\x00\x00")
 	_, err := unmarshalCertificateV1(data, nil)
-	assert.EqualError(t, err, "encoded Details was nil")
+	require.EqualError(t, err, "encoded Details was nil")
 }
 
 func appendByteSlices(b ...[]byte) []byte {

+ 25 - 25
cert/cert_v2_test.go

@@ -49,7 +49,7 @@ func TestCertificateV2_Marshal(t *testing.T) {
 	//t.Log("Cert size:", len(b))
 
 	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, Version2, nc.Version())
 	assert.Equal(t, Curve_CURVE25519, nc.Curve())
@@ -114,14 +114,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
 	}
 
 	b, err := nc.MarshalJSON()
-	assert.ErrorIs(t, err, ErrMissingDetails)
+	require.ErrorIs(t, err, ErrMissingDetails)
 
 	rd, err := nc.details.Marshal()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	nc.rawDetails = rd
 	b, err = nc.MarshalJSON()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.JSONEq(
 		t,
 		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
@@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
 func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
 	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
-	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
 
 	_, caKey2, err := ed25519.GenerateKey(rand.Reader)
 	require.NoError(t, err)
 	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
 
 	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
 	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, priv2 := X25519Keypair()
 	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
+	require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
 
 	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
 
 	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
-	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
 
 	ac, ok := c.(*certificateV2)
 	require.True(t, ok)
 	ac.curve = Curve(99)
 	err = c.VerifyPrivateKey(Curve(99), priv2)
-	assert.EqualError(t, err, "invalid curve: 99")
+	require.EqualError(t, err, "invalid curve: 99")
 
 	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
 	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
-	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
 
 	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
 	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
 
 	err = c.VerifyPrivateKey(Curve_P256, priv[:16])
-	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
 
 	err = c.VerifyPrivateKey(Curve_P256, priv)
-	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
 
 	aCa, ok := ca2.(*certificateV2)
 	require.True(t, ok)
 	aCa.curve = Curve(99)
 	err = aCa.VerifyPrivateKey(Curve(99), priv2)
-	assert.EqualError(t, err, "invalid curve: 99")
+	require.EqualError(t, err, "invalid curve: 99")
 
 }
 
 func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
 	err := ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.Error(t, err)
+	require.Error(t, err)
 
 	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
 	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 	assert.Equal(t, Curve_P256, curve)
 	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	_, priv2 := P256Keypair()
 	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.Error(t, err)
+	require.Error(t, err)
 }
 
 func TestCertificateV2_Copy(t *testing.T) {
@@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) {
 func TestUnmarshalCertificateV2(t *testing.T) {
 	data := []byte("\x98\x00\x00")
 	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
-	assert.EqualError(t, err, "bad wire format")
+	require.EqualError(t, err, "bad wire format")
 }
 
 func TestCertificateV2_marshalForSigningStability(t *testing.T) {

+ 8 - 7
cert/crypto_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/argon2"
 )
 
@@ -61,33 +62,33 @@ qrlJ69wer3ZUHFXA
 
 	// Success test case
 	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Len(t, k, 64)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 
 	// Fail due to short key
 	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
 
 	// Fail due to invalid banner
 	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
 	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to invalid passphrase
 	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
-	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
+	require.EqualError(t, err, "invalid passphrase or corrupt private key")
 	assert.Nil(t, k)
 	assert.Equal(t, []byte{}, rest)
 }
@@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
 	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
 	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
 	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Verify the "key" can be decrypted successfully
 	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
 	assert.Len(t, k, 64)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Equal(t, []byte{}, rest)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
 }

+ 23 - 22
cert/pem_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestUnmarshalCertificateFromPEM(t *testing.T) {
@@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
 	cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
 	assert.NotNil(t, cert)
 	assert.Equal(t, rest, append(badBanner, invalidPem...))
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Fail due to invalid banner.
 	cert, rest, err = UnmarshalCertificateFromPEM(rest)
 	assert.Nil(t, cert)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper certificate banner")
+	require.EqualError(t, err, "bytes did not contain a proper certificate banner")
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
 	cert, rest, err = UnmarshalCertificateFromPEM(rest)
 	assert.Nil(t, cert)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
 func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
@@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 	assert.Len(t, k, 64)
 	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Success test case
 	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
 	assert.Len(t, k, 32)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_P256, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Fail due to short key
 	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
+	require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
 
 	// Fail due to invalid banner
 	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
+	require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
 	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
 func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
@@ -146,33 +147,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	assert.Len(t, k, 32)
 	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Success test case
 	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
 	assert.Len(t, k, 32)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_P256, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// Fail due to short key
 	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
 
 	// Fail due to invalid banner
 	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper private key banner")
+	require.EqualError(t, err, "bytes did not contain a proper private key banner")
 
 	// Fail due to ivalid PEM format, because
 	// it's missing the requisite pre-encapsulation boundary.
 	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
 func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
@@ -202,7 +203,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
 	assert.Len(t, k, 32)
 	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 
 	// Fail due to short key
@@ -210,13 +211,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	assert.Nil(t, k)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
 
 	// Fail due to invalid banner
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	require.EqualError(t, err, "bytes did not contain a proper public key banner")
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to ivalid PEM format, because
@@ -225,7 +226,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	assert.Nil(t, k)
 	assert.Equal(t, Curve_CURVE25519, curve)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
 func TestUnmarshalX25519PublicKey(t *testing.T) {
@@ -260,14 +261,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	// Success test case
 	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
 	assert.Len(t, k, 32)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_CURVE25519, curve)
 
 	// Success test case
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Len(t, k, 65)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_P256, curve)
 
@@ -275,12 +276,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
 
 	// Fail due to invalid banner
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	require.EqualError(t, err, "bytes did not contain a proper public key banner")
 	assert.Equal(t, rest, invalidPem)
 
 	// Fail due to ivalid PEM format, because
@@ -288,5 +289,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Nil(t, k)
 	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }

+ 8 - 7
cert/sign_test.go

@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestCertificateV1_Sign(t *testing.T) {
@@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) {
 
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.NotNil(t, c)
 	assert.True(t, c.CheckSignature(pub))
 
 	b, err := c.Marshal()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	uc, err := unmarshalCertificateV1(b, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.NotNil(t, uc)
 }
 
@@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) {
 	}
 
 	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
 	rawPriv := priv.D.FillBytes(make([]byte, 32))
 
 	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.NotNil(t, c)
 	assert.True(t, c.CheckSignature(pub))
 
 	b, err := c.Marshal()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	uc, err := unmarshalCertificateV1(b, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.NotNil(t, uc)
 }

+ 19 - 18
cmd/nebula-cert/ca_test.go

@@ -14,6 +14,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_caSummary(t *testing.T) {
@@ -106,34 +107,34 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.NoError(t, err)
-	assert.NoError(t, os.Remove(keyF.Name()))
+	require.NoError(t, err)
+	require.NoError(t, os.Remove(keyF.Name()))
 
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
-	assert.NoError(t, err)
-	assert.NoError(t, os.Remove(crtF.Name()))
-	assert.NoError(t, os.Remove(keyF.Name()))
+	require.NoError(t, err)
+	require.NoError(t, os.Remove(crtF.Name()))
+	require.NoError(t, os.Remove(keyF.Name()))
 
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.NoError(t, ca(args, ob, eb, nopw))
+	require.NoError(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -142,13 +143,13 @@ func Test_ca(t *testing.T) {
 	lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
 	assert.Equal(t, cert.Curve_CURVE25519, c)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 64)
 
 	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, "test", lCrt.Name())
 	assert.Empty(t, lCrt.Networks())
@@ -166,7 +167,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.NoError(t, ca(args, ob, eb, testpw))
+	require.NoError(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -174,7 +175,7 @@ func Test_ca(t *testing.T) {
 	rb, _ = os.ReadFile(keyF.Name())
 	k, _ := pem.Decode(rb)
 	ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	// we won't know salt in advance, so just check start of string
 	assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
 	assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
@@ -184,7 +185,7 @@ func Test_ca(t *testing.T) {
 	var curve cert.Curve
 	curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
 	assert.Equal(t, cert.Curve_CURVE25519, curve)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, b)
 	assert.Len(t, lKey, 64)
 
@@ -194,7 +195,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Error(t, ca(args, ob, eb, errpw))
+	require.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -204,7 +205,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+	require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, "", eb.String())
 
@@ -214,13 +215,13 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.NoError(t, ca(args, ob, eb, nopw))
+	require.NoError(t, ca(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
+	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -229,7 +230,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
+	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	os.Remove(keyF.Name())

+ 8 - 7
cmd/nebula-cert/keygen_test.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_keygenSummary(t *testing.T) {
@@ -47,33 +48,33 @@ func Test_keygen(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
-	assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(keyF.Name())
 
 	// failed pub write
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
-	assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
+	require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// create temp pub file
 	pubF, err := os.CreateTemp("", "test.pub")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(pubF.Name())
 
 	// test proper keygen
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
-	assert.NoError(t, keygen(args, ob, eb))
+	require.NoError(t, keygen(args, ob, eb))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -82,13 +83,13 @@ func Test_keygen(t *testing.T) {
 	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
 	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(pubF.Name())
 	lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
 	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, lPub, 32)
 }

+ 2 - 1
cmd/nebula-cert/main_test.go

@@ -9,6 +9,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_help(t *testing.T) {
@@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
 		t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
 	}
 
-	assert.EqualError(t, err, msg)
+	require.EqualError(t, err, msg)
 }
 
 func optionalPkcs11String(msg string) string {

+ 6 - 5
cmd/nebula-cert/print_test.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_printSummary(t *testing.T) {
@@ -52,20 +53,20 @@ func Test_printCert(t *testing.T) {
 	err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
+	require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
 
 	// invalid cert at path
 	ob.Reset()
 	eb.Reset()
 	tf, err := os.CreateTemp("", "print-cert")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(tf.Name())
 
 	tf.WriteString("-----BEGIN NOPE-----")
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
 
 	// test multiple certs
 	ob.Reset()
@@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) {
 	fp, _ := c.Fingerprint()
 	pk := hex.EncodeToString(c.PublicKey())
 	sig := hex.EncodeToString(c.Signature())
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(
 		t,
 		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
@@ -169,7 +170,7 @@ func Test_printCert(t *testing.T) {
 	fp, _ = c.Fingerprint()
 	pk = hex.EncodeToString(c.PublicKey())
 	sig = hex.EncodeToString(c.Signature())
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(
 		t,
 		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]

+ 32 - 31
cmd/nebula-cert/sign_test.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 	// failed to unmarshal key
 	ob.Reset()
 	eb.Reset()
 	caKeyF, err := os.CreateTemp("", "sign-cert.key")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF.Name())
 
 	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) {
 
 	// failed to read cert
 	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	caCrtF, err := os.CreateTemp("", "sign-cert.crt")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caCrtF.Name())
 
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) {
 
 	// failed to read pub
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	inPubF, err := os.CreateTemp("", "in.pub")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(inPubF.Name())
 
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) {
 	// mismatched ca key
 	_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
 	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF2.Name())
 	caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
 
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	os.Remove(keyF.Name())
 
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	os.Remove(keyF.Name())
 
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	os.Remove(crtF.Name())
 
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.NoError(t, signCert(args, ob, eb, nopw))
+	require.NoError(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -261,13 +262,13 @@ func Test_signCert(t *testing.T) {
 	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
 	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	assert.Equal(t, "test", lCrt.Name())
 	assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
@@ -295,7 +296,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
-	assert.NoError(t, signCert(args, ob, eb, nopw))
+	require.NoError(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -303,7 +304,7 @@ func Test_signCert(t *testing.T) {
 	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
 	assert.Empty(t, b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, lCrt.PublicKey(), inPub)
 
 	// test refuse to sign cert with duration beyond root
@@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.NoError(t, signCert(args, ob, eb, nopw))
+	require.NoError(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.NoError(t, signCert(args, ob, eb, nopw))
+	require.NoError(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	caKeyF, err = os.CreateTemp("", "sign-cert.key")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF.Name())
 
 	caCrtF, err = os.CreateTemp("", "sign-cert.crt")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caCrtF.Name())
 
 	// generate the encrypted key
@@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) {
 
 	// test with the proper password
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.NoError(t, signCert(args, ob, eb, testpw))
+	require.NoError(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
@@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) {
 
 	testpw.password = []byte("invalid password")
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, testpw))
+	require.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
@@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, nopw))
+	require.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, errpw))
+	require.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 }

+ 9 - 8
cmd/nebula-cert/verify_test.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -50,20 +51,20 @@ func Test_verify(t *testing.T) {
 	err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
+	require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
 
 	// invalid ca at path
 	ob.Reset()
 	eb.Reset()
 	caFile, err := os.CreateTemp("", "verify-ca")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caFile.Name())
 
 	caFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
 
 	// make a ca for later
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
@@ -77,20 +78,20 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
+	require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
 	ob.Reset()
 	eb.Reset()
 	certFile, err := os.CreateTemp("", "verify-cert")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	defer os.Remove(certFile.Name())
 
 	certFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
+	require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 
 	// unverifiable cert at path
 	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
@@ -107,7 +108,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.ErrorIs(t, err, cert.ErrSignatureMismatch)
+	require.ErrorIs(t, err, cert.ErrSignatureMismatch)
 
 	// verified cert at path
 	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
@@ -119,5 +120,5 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }

+ 5 - 5
config/config_test.go

@@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) {
 	// invalid yaml
 	c := NewC(l)
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
-	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
+	require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 	// simple multi config merge
 	c = NewC(l)
 	os.RemoveAll(dir)
 	os.Mkdir(dir, 0755)
 
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
-	assert.NoError(t, c.Load(dir))
+	require.NoError(t, c.Load(dir))
 	expected := map[interface{}]interface{}{
 		"outer": map[interface{}]interface{}{
 			"inner": "override",
@@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
 	l := test.NewLogger()
 	done := make(chan bool, 1)
 	dir, err := os.MkdirTemp("", "config-test")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
 	c := NewC(l)
-	assert.NoError(t, c.Load(dir))
+	require.NoError(t, c.Load(dir))
 
 	assert.False(t, c.HasChanged("outer.inner"))
 	assert.False(t, c.HasChanged("outer"))

+ 4 - 3
connection_manager_test.go

@@ -14,6 +14,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func newTestLighthouse() *LightHouse {
@@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	}
 
 	caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	ncp := cert.NewCAPool()
-	assert.NoError(t, ncp.AddCA(caCert))
+	require.NoError(t, ncp.AddCA(caCert))
 
 	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
 	tbs = &cert.TBSCertificate{
@@ -237,7 +238,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		PublicKey: pubCrt,
 	}
 	peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 

+ 11 - 10
e2e/handshakes_test.go

@@ -19,6 +19,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"gopkg.in/yaml.v2"
 )
 
@@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) {
 		"key":  string(myNextPrivKey),
 	}
 	rc, err := yaml.Marshal(relayConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	relayConfig.ReloadConfigString(string(rc))
 
 	for {
@@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 		"key":  string(myNextPrivKey),
 	}
 	rc, err := yaml.Marshal(relayConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	relayConfig.ReloadConfigString(string(rc))
 
 	for {
@@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) {
 		"key":  string(myNextPrivKey),
 	}
 	rc, err := yaml.Marshal(myConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	myConfig.ReloadConfigString(string(rc))
 
 	for {
@@ -987,9 +988,9 @@ func TestRehandshaking(t *testing.T) {
 	r.Log("Got the new cert")
 	// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
 	rc, err = yaml.Marshal(theirConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	var theirNewConfig m
-	assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
+	require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
 	theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
 	theirFirewall["inbound"] = []m{{
 		"proto": "any",
@@ -997,7 +998,7 @@ func TestRehandshaking(t *testing.T) {
 		"group": "new group",
 	}}
 	rc, err = yaml.Marshal(theirNewConfig)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	theirConfig.ReloadConfigString(string(rc))
 
 	r.Log("Spin until there is only 1 tunnel")
@@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) {
 		"key":  string(theirNextPrivKey),
 	}
 	rc, err := yaml.Marshal(theirConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	theirConfig.ReloadConfigString(string(rc))
 
 	for {
@@ -1083,9 +1084,9 @@ func TestRehandshakingLoser(t *testing.T) {
 
 	// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
 	rc, err = yaml.Marshal(myConfig.Settings)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	var myNewConfig m
-	assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
+	require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
 	theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
 	theirFirewall["inbound"] = []m{{
 		"proto": "any",
@@ -1093,7 +1094,7 @@ func TestRehandshakingLoser(t *testing.T) {
 		"group": "their new group",
 	}}
 	rc, err = yaml.Marshal(myNewConfig)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	myConfig.ReloadConfigString(string(rc))
 
 	r.Log("Spin until there is only 1 tunnel")

+ 75 - 75
firewall_test.go

@@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.NotNil(t, fw.OutRules)
 
 	ti, err := netip.ParsePrefix("1.2.3.4/32")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
 	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
 	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	anyIp, err := netip.ParsePrefix("0.0.0.0/0")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
-	assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
-	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
@@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
 	h.buildNetworks(c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
 	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
 	// Allow inbound
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	// Allow outbound because conntrack
-	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	// test remote mismatch
 	oldRemote := p.RemoteAddr
@@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) {
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
-	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
-	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
 func BenchmarkFirewallTable_match(b *testing.B) {
@@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
-	assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
+	require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
 	// c has the proper groups
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
 func TestFirewall_Drop3(t *testing.T) {
@@ -428,23 +428,23 @@ func TestFirewall_Drop3(t *testing.T) {
 	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
-	assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
+	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 	// c2 should pass because ca sha match
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
+	require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
 	// c3 should fail because no match
 	resetConntrack(fw)
 	assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
 
 	// Test a remote address match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
-	assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
+	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -480,29 +480,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
 	// Allow inbound
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	// Allow outbound because conntrack
-	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 	// Allow outbound because conntrack and new rules allow port 10
-	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -585,42 +585,42 @@ func BenchmarkLookup(b *testing.B) {
 
 func Test_parsePort(t *testing.T) {
 	_, _, err := parsePort("")
-	assert.EqualError(t, err, "was not a number; ``")
+	require.EqualError(t, err, "was not a number; ``")
 
 	_, _, err = parsePort("  ")
-	assert.EqualError(t, err, "was not a number; `  `")
+	require.EqualError(t, err, "was not a number; `  `")
 
 	_, _, err = parsePort("-")
-	assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
+	require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
 
 	_, _, err = parsePort(" - ")
-	assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
+	require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
 
 	_, _, err = parsePort("a-b")
-	assert.EqualError(t, err, "beginning range was not a number; `a`")
+	require.EqualError(t, err, "beginning range was not a number; `a`")
 
 	_, _, err = parsePort("1-b")
-	assert.EqualError(t, err, "ending range was not a number; `b`")
+	require.EqualError(t, err, "ending range was not a number; `b`")
 
 	s, e, err := parsePort(" 1 - 2    ")
 	assert.Equal(t, int32(1), s)
 	assert.Equal(t, int32(2), e)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	s, e, err = parsePort("0-1")
 	assert.Equal(t, int32(0), s)
 	assert.Equal(t, int32(0), e)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	s, e, err = parsePort("9919")
 	assert.Equal(t, int32(9919), s)
 	assert.Equal(t, int32(9919), e)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	s, e, err = parsePort("any")
 	assert.Equal(t, int32(0), s)
 	assert.Equal(t, int32(0), e)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func TestNewFirewallFromConfig(t *testing.T) {
@@ -633,53 +633,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
+	require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
+	require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
+	require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
+	require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
+	require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
+	require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
+	require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
+	require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
-	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
+	require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 
 func TestAddFirewallRulesFromConfig(t *testing.T) {
@@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	conf := config.NewC(l)
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding udp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding any rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding rule with cidr
@@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding rule with local_cidr
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
 
 	// Test single group
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test single groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
-	assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test Add error
@@ -767,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf = &mockFirewall{}
 	mf.nextCallReturn = errors.New("test error")
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
-	assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
+	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
 }
 
 func TestFirewall_convertRule(t *testing.T) {
@@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) {
 
 	r, err := convertRule(l, c, "test", 1)
 	assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, "group1", r.Group)
 
 	// Ensure group array of > 1 is errord
@@ -793,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) {
 
 	r, err = convertRule(l, c, "test", 1)
 	assert.Equal(t, "", ob.String())
-	assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
+	require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
 
 	// Make sure a well formed group is alright
 	ob.Reset()
@@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) {
 	}
 
 	r, err = convertRule(l, c, "test", 1)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, "group1", r.Group)
 }
 

+ 2 - 1
header/header_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 type headerTest struct {
@@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
 
 func TestHeader_MarshalJSON(t *testing.T) {
 	b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(
 		t,
 		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",

+ 15 - 16
lighthouse_test.go

@@ -13,6 +13,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"gopkg.in/yaml.v2"
 )
 
@@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) {
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
 	var m V4AddrPort
 	err := m.Unmarshal(b)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	ip := netip.MustParseAddr("10.1.1.1")
 	bp := ip.As4()
 	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
@@ -42,14 +43,14 @@ func Test_lhStaticMapping(t *testing.T) {
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
 	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
+	require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func TestReloadLighthouseInterval(t *testing.T) {
@@ -71,19 +72,19 @@ func TestReloadLighthouseInterval(t *testing.T) {
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 
 	// The first one routine is kicked off by main.go currently, lets make sure that one dies
-	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 5"))
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 5"))
 	assert.Equal(t, int64(5), lh.interval.Load())
 
 	// Subsequent calls are killed off by the LightHouse.Reload function
-	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 10"))
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 10"))
 	assert.Equal(t, int64(10), lh.interval.Load())
 
 	// If this completes then nothing is stealing our reload routine
-	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 11"))
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 11"))
 	assert.Equal(t, int64(11), lh.interval.Load())
 }
 
@@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 	c := config.NewC(l)
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-	if !assert.NoError(b, err) {
-		b.Fatal()
-	}
+	require.NoError(b, err)
 
 	hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
@@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			},
 		}
 		p, err := req.Marshal()
-		assert.NoError(b, err)
+		require.NoError(b, err)
 		for n := 0; n < b.N; n++ {
 			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
@@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			},
 		}
 		p, err := req.Marshal()
-		assert.NoError(b, err)
+		require.NoError(b, err)
 
 		for n := 0; n < b.N; n++ {
 			lhh.HandleRequest(rAddr, hi, p, mw)
@@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	}
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	lh.ifce = &mockEncWriter{}
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
@@ -290,7 +289,7 @@ func TestLighthouse_reload(t *testing.T) {
 	}
 
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	nc := map[interface{}]interface{}{
 		"static_host_map": map[interface{}]interface{}{
@@ -298,11 +297,11 @@ func TestLighthouse_reload(t *testing.T) {
 		},
 	}
 	rc, err := yaml.Marshal(nc)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	c.ReloadConfigString(string(rc))
 
 	err = lh.reload(c, false)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
 func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {

+ 29 - 28
outside_test.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/net/ipv4"
 )
 
@@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) {
 
 	// length fails
 	err := newPacket([]byte{}, true, p)
-	assert.ErrorIs(t, err, ErrPacketTooShort)
+	require.ErrorIs(t, err, ErrPacketTooShort)
 
 	err = newPacket([]byte{0x40}, true, p)
-	assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
+	require.ErrorIs(t, err, ErrIPv4PacketTooShort)
 
 	err = newPacket([]byte{0x60}, true, p)
-	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+	require.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 	// length fail with ip options
 	h := ipv4.Header{
@@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) {
 
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
-	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
+	require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.ErrorIs(t, err, ErrUnknownIPVersion)
+	require.ErrorIs(t, err, ErrUnknownIPVersion)
 
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
+	require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
@@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) {
 	b = append(b, []byte{0, 3, 0, 4}...)
 	err = newPacket(b, true, p)
 
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
@@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) {
 	b = append(b, []byte{0, 5, 0, 6}...)
 	err = newPacket(b, false, p)
 
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(2), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
@@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) {
 		FixLengths:       false,
 	}
 	err := gopacket.SerializeLayers(buffer, opt, &ip)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	err = newPacket(buffer.Bytes(), true, p)
-	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
 
 	// A good ICMP packet
 	ip = layers.IPv6{
@@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) {
 	}
 
 	err = newPacket(buffer.Bytes(), true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) {
 	b := buffer.Bytes()
 	b[6] = byte(layers.IPProtocolESP)
 	err = newPacket(b, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) {
 	b = buffer.Bytes()
 	b[6] = byte(layers.IPProtocolNoNextHeader)
 	err = newPacket(b, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) {
 	b = buffer.Bytes()
 	b[6] = 255 // 255 is a reserved protocol number
 	err = newPacket(b, true, p)
-	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
 
 	// A good UDP packet
 	ip = layers.IPv6{
@@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) {
 		DstPort: layers.UDPPort(22),
 	}
 	err = udp.SetNetworkLayerForChecksum(&ip)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	buffer.Clear()
 	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
@@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) {
 
 	// incoming
 	err = newPacket(b, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) {
 
 	// outgoing
 	err = newPacket(b, false, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) {
 
 	// Too short UDP packet
 	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
-	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+	require.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 	// A good TCP packet
 	b[6] = byte(layers.IPProtocolTCP)
 
 	// incoming
 	err = newPacket(b, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) {
 
 	// outgoing
 	err = newPacket(b, false, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) {
 
 	// Too short TCP packet
 	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
-	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+	require.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 	// A good UDP packet with an AH header
 	ip = layers.IPv6{
@@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) {
 	b = append(b, udpHeader...)
 
 	err = newPacket(b, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) {
 	// Invalid AH header
 	b = buffer.Bytes()
 	err = newPacket(b, true, p)
-	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+	require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
 }
 
 func Test_newPacket_ipv6Fragment(t *testing.T) {
@@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
 
 	// Test first fragment incoming
 	err = newPacket(firstFrag, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
 	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
 
 	// Test first fragment outgoing
 	err = newPacket(firstFrag, false, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
 	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
 
 	// Test second fragment incoming
 	err = newPacket(secondFrag, true, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
 	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
 
 	// Test second fragment outgoing
 	err = newPacket(secondFrag, false, p)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
 	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
 	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
 
 	// Too short of a fragment packet
 	err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
-	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+	require.ErrorIs(t, err, ErrIPv6PacketTooShort)
 }
 
 func BenchmarkParseV6(b *testing.B) {

+ 40 - 39
overlay/route_test.go

@@ -8,84 +8,85 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_parseRoutes(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 	n, err := netip.ParsePrefix("10.0.0.0/24")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// test no routes config
 	routes, err := parseRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "tun.routes is not an array")
+	require.EqualError(t, err, "tun.routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
+	require.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
+	require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
+	require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
+	require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
+	require.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
+	require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
+	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
+	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
 
 	// Not in multiple ranges
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
+	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
@@ -93,7 +94,7 @@ func Test_parseRoutes(t *testing.T) {
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, routes, 2)
 
 	tested := 0
@@ -119,36 +120,36 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 	n, err := netip.ParsePrefix("10.0.0.0/24")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// test no routes config
 	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
+	require.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
+	require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
+	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
 	// invalid via
 	for _, invalidValue := range []interface{}{
@@ -157,44 +158,44 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
 		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
-		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
+		require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
+	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
+	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
+	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
+	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
@@ -206,19 +207,19 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
+	require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
+	require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 	// bad install
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
+	require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
@@ -228,7 +229,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, routes, 4)
 
 	tested := 0
@@ -260,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 	n, err := netip.ParsePrefix("10.0.0.0/24")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
 	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Len(t, routes, 2)
 	routeTree, err := makeRouteTree(l, routes, true)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 
 	ip, err := netip.ParseAddr("1.0.0.2")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	r, ok := routeTree.Lookup(ip)
 	assert.True(t, ok)
 
 	nip, err := netip.ParseAddr("192.168.0.1")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, nip, r)
 
 	ip, err = netip.ParseAddr("1.0.0.1")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	r, ok = routeTree.Lookup(ip)
 	assert.True(t, ok)
 
 	nip, err = netip.ParseAddr("192.168.0.2")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, nip, r)
 
 	ip, err = netip.ParseAddr("1.1.0.1")
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	r, ok = routeTree.Lookup(ip)
 	assert.False(t, ok)
 }

+ 3 - 2
punchy_test.go

@@ -7,6 +7,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewPunchyFromConfig(t *testing.T) {
@@ -56,7 +57,7 @@ func TestPunchy_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 	delay, _ := time.ParseDuration("1m")
-	assert.NoError(t, c.LoadString(`
+	require.NoError(t, c.LoadString(`
 punchy:
   delay: 1m
   respond: false
@@ -66,7 +67,7 @@ punchy:
 	assert.False(t, p.GetRespond())
 
 	newDelay, _ := time.ParseDuration("10m")
-	assert.NoError(t, c.ReloadConfigString(`
+	require.NoError(t, c.ReloadConfigString(`
 punchy:
   delay: 10m
   respond: true