| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 | package certimport (	"crypto/ed25519"	"crypto/rand"	"encoding/hex"	"net/netip"	"slices"	"testing"	"time"	"github.com/slackhq/nebula/test"	"github.com/stretchr/testify/assert"	"github.com/stretchr/testify/require")func TestCertificateV2_Marshal(t *testing.T) {	before := time.Now().Add(time.Second * -60).Round(time.Second)	after := time.Now().Add(time.Second * 60).Round(time.Second)	pubKey := []byte("1234567890abcedfghij1234567890ab")	nc := certificateV2{		details: detailsV2{			name: "testing",			networks: []netip.Prefix{				mustParsePrefixUnmapped("10.1.1.2/16"),				mustParsePrefixUnmapped("10.1.1.1/24"),			},			unsafeNetworks: []netip.Prefix{				mustParsePrefixUnmapped("9.1.1.3/16"),				mustParsePrefixUnmapped("9.1.1.2/24"),			},			groups:    []string{"test-group1", "test-group2", "test-group3"},			notBefore: before,			notAfter:  after,			isCA:      false,			issuer:    "1234567890abcdef1234567890abcdef",		},		signature: []byte("1234567890abcdef1234567890abcdef"),		publicKey: pubKey,	}	db, err := nc.details.Marshal()	require.NoError(t, err)	nc.rawDetails = db	b, err := nc.Marshal()	require.NoError(t, err)	//t.Log("Cert size:", len(b))	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)	require.NoError(t, err)	assert.Equal(t, Version2, nc.Version())	assert.Equal(t, Curve_CURVE25519, nc.Curve())	assert.Equal(t, nc.Signature(), nc2.Signature())	assert.Equal(t, nc.Name(), nc2.Name())	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())	assert.Equal(t, nc.IsCA(), nc2.IsCA())	assert.Equal(t, nc.Issuer(), nc2.Issuer())	// unmarshalling will sort networks and unsafeNetworks, we need to do the same	// but first make sure it fails	assert.NotEqual(t, nc.Networks(), nc2.Networks())	assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())	slices.SortFunc(nc.details.networks, comparePrefix)	slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)	assert.Equal(t, nc.Networks(), nc2.Networks())	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())	assert.Equal(t, nc.Groups(), nc2.Groups())}func TestCertificateV2_Expired(t *testing.T) {	nc := certificateV2{		details: detailsV2{			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),		},	}	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))	assert.False(t, nc.Expired(time.Now()))}func TestCertificateV2_MarshalJSON(t *testing.T) {	time.Local = time.UTC	pubKey := []byte("1234567890abcedf1234567890abcedf")	nc := certificateV2{		details: detailsV2{			name: "testing",			networks: []netip.Prefix{				mustParsePrefixUnmapped("10.1.1.1/24"),				mustParsePrefixUnmapped("10.1.1.2/16"),			},			unsafeNetworks: []netip.Prefix{				mustParsePrefixUnmapped("9.1.1.2/24"),				mustParsePrefixUnmapped("9.1.1.3/16"),			},			groups:    []string{"test-group1", "test-group2", "test-group3"},			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),			isCA:      false,			issuer:    "1234567890abcedf1234567890abcedf",		},		publicKey: pubKey,		signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),	}	b, err := nc.MarshalJSON()	require.ErrorIs(t, err, ErrMissingDetails)	rd, err := nc.details.Marshal()	require.NoError(t, err)	nc.rawDetails = rd	b, err = nc.MarshalJSON()	require.NoError(t, err)	assert.JSONEq(		t,		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",		string(b),	)}func TestCertificateV2_VerifyPrivateKey(t *testing.T) {	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)	require.NoError(t, err)	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])	require.ErrorIs(t, err, ErrInvalidPrivateKey)	_, caKey2, err := ed25519.GenerateKey(rand.Reader)	require.NoError(t, err)	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)	require.NoError(t, err)	assert.Empty(t, b)	assert.Equal(t, Curve_CURVE25519, curve)	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)	require.NoError(t, err)	_, priv2 := X25519Keypair()	err = c.VerifyPrivateKey(Curve_P256, priv2)	require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])	require.ErrorIs(t, err, ErrInvalidPrivateKey)	ac, ok := c.(*certificateV2)	require.True(t, ok)	ac.curve = Curve(99)	err = c.VerifyPrivateKey(Curve(99), priv2)	require.EqualError(t, err, "invalid curve: 99")	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)	require.NoError(t, err)	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])	require.ErrorIs(t, err, ErrInvalidPrivateKey)	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)	err = c.VerifyPrivateKey(Curve_P256, priv[:16])	require.ErrorIs(t, err, ErrInvalidPrivateKey)	err = c.VerifyPrivateKey(Curve_P256, priv)	require.ErrorIs(t, err, ErrInvalidPrivateKey)	aCa, ok := ca2.(*certificateV2)	require.True(t, ok)	aCa.curve = Curve(99)	err = aCa.VerifyPrivateKey(Curve(99), priv2)	require.EqualError(t, err, "invalid curve: 99")}func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)	err := ca.VerifyPrivateKey(Curve_P256, caKey)	require.NoError(t, err)	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)	require.NoError(t, err)	err = ca.VerifyPrivateKey(Curve_P256, caKey2)	require.Error(t, err)	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)	require.NoError(t, err)	assert.Empty(t, b)	assert.Equal(t, Curve_P256, curve)	err = c.VerifyPrivateKey(Curve_P256, rawPriv)	require.NoError(t, err)	_, priv2 := P256Keypair()	err = c.VerifyPrivateKey(Curve_P256, priv2)	require.Error(t, err)}func TestCertificateV2_Copy(t *testing.T) {	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)	cc := c.Copy()	test.AssertDeepCopyEqual(t, c, cc)}func TestUnmarshalCertificateV2(t *testing.T) {	data := []byte("\x98\x00\x00")	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)	require.EqualError(t, err, "bad wire format")}func TestCertificateV2_marshalForSigningStability(t *testing.T) {	before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)	after := before.Add(time.Second * 60).Round(time.Second)	pubKey := []byte("1234567890abcedfghij1234567890ab")	nc := certificateV2{		details: detailsV2{			name: "testing",			networks: []netip.Prefix{				mustParsePrefixUnmapped("10.1.1.2/16"),				mustParsePrefixUnmapped("10.1.1.1/24"),			},			unsafeNetworks: []netip.Prefix{				mustParsePrefixUnmapped("9.1.1.3/16"),				mustParsePrefixUnmapped("9.1.1.2/24"),			},			groups:    []string{"test-group1", "test-group2", "test-group3"},			notBefore: before,			notAfter:  after,			isCA:      false,			issuer:    "1234567890abcdef1234567890abcdef",		},		signature: []byte("1234567890abcdef1234567890abcdef"),		publicKey: pubKey,	}	const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"	expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)	require.NoError(t, err)	db, err := nc.details.Marshal()	require.NoError(t, err)	assert.Equal(t, expectedRawDetails, db)	expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")	b, err := nc.marshalForSigning()	require.NoError(t, err)	assert.Equal(t, expectedForSigning, b)}
 |