Browse Source

add PKCS11 support (#1153)

* add PKCS11 support

* add pkcs11 build option to the makefile, add a stub pkclient to avoid forcing CGO onto people

* don't print the pkcs11 option on nebula-cert keygen if not compiled in

* remove linux-arm64-pkcs11 from the all target to fix CI

* correctly serialize ec keys

* nebula-cert: support PKCS#11 for sign and ca

* fix gofmt lint

* clean up some logic with regard to closing sessions

* pkclient: handle empty correctly for TPM2

* Update Makefile and Actions

---------

Co-authored-by: Morgan Jones <[email protected]>
Co-authored-by: John Maguire <[email protected]>
Jack Doan 10 months ago
parent
commit
35603d1c39

+ 18 - 0
.github/workflows/test.yml

@@ -67,6 +67,24 @@ jobs:
     - name: End 2 end
       run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 
+  test-linux-pkcs11:
+    name: Build and test on linux with pkcs11
+    runs-on: ubuntu-latest
+    steps:
+
+    - uses: actions/checkout@v4
+
+    - uses: actions/setup-go@v5
+      with:
+        go-version: '1.22'
+        check-latest: true
+
+    - name: Build
+      run: make bin-pkcs11
+
+    - name: Test
+      run: make test-pkcs11
+
   test:
     name: Build and test on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}

+ 1 - 0
.gitignore

@@ -13,5 +13,6 @@
 **.crt
 **.key
 **.pem
+**.pub
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt

+ 10 - 3
Makefile

@@ -40,7 +40,7 @@ ALL_LINUX = linux-amd64 \
 	linux-mips64le \
 	linux-mips-softfloat \
 	linux-riscv64 \
-        linux-loong64
+	linux-loong64
 
 ALL_FREEBSD = freebsd-amd64 \
 	freebsd-arm64
@@ -63,7 +63,7 @@ ALL = $(ALL_LINUX) \
 e2e:
 	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 
-e2ev: TEST_FLAGS = -v
+e2ev: TEST_FLAGS += -v
 e2ev: e2e
 
 e2evv: TEST_ENV += TEST_LOGS=1
@@ -96,7 +96,7 @@ release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
 
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 
-BUILD_ARGS = -trimpath
+BUILD_ARGS += -trimpath
 
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
 	mv $? .
@@ -116,6 +116,10 @@ bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 	mv $? .
 
+bin-pkcs11: BUILD_ARGS += -tags pkcs11
+bin-pkcs11: CGO_ENABLED = 1
+bin-pkcs11: bin
+
 bin:
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
@@ -168,6 +172,9 @@ test:
 test-boringcrypto:
 	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
 
+test-pkcs11:
+	CGO_ENABLED=1 go test -v -tags pkcs11 ./...
+
 test-cov-html:
 	go test -coverprofile=coverage.out
 	go tool cover -html=coverage.out

+ 35 - 2
cert/cert.go

@@ -20,6 +20,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 	"google.golang.org/protobuf/proto"
 )
@@ -41,8 +42,9 @@ const (
 )
 
 type NebulaCertificate struct {
-	Details   NebulaCertificateDetails
-	Signature []byte
+	Details      NebulaCertificateDetails
+	Pkcs11Backed bool
+	Signature    []byte
 
 	// the cached hex string of the calculated sha256sum
 	// for VerifyWithCache
@@ -555,6 +557,34 @@ func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
 	return nil
 }
 
+// SignPkcs11 signs a nebula cert with the provided private key
+func (nc *NebulaCertificate) SignPkcs11(curve Curve, client *pkclient.PKClient) error {
+	if !nc.Pkcs11Backed {
+		return fmt.Errorf("certificate is not PKCS#11 backed")
+	}
+
+	if curve != nc.Details.Curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+
+	if curve != Curve_P256 {
+		return fmt.Errorf("only P256 is supported by PKCS#11")
+	}
+
+	b, err := proto.Marshal(nc.getRawDetails())
+	if err != nil {
+		return err
+	}
+
+	sig, err := client.SignASN1(b)
+	if err != nil {
+		return err
+	}
+
+	nc.Signature = sig
+	return nil
+}
+
 // CheckSignature verifies the signature against the provided public key
 func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
 	b, err := proto.Marshal(nc.getRawDetails())
@@ -693,6 +723,9 @@ func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) erro
 
 // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
 func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
+	if nc.Pkcs11Backed {
+		return nil //todo!
+	}
 	if curve != nc.Details.Curve {
 		return fmt.Errorf("curve in cert and private key supplied don't match")
 	}

+ 84 - 37
cmd/nebula-cert/ca.go

@@ -4,6 +4,7 @@ import (
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/rand"
+
 	"flag"
 	"fmt"
 	"io"
@@ -15,6 +16,7 @@ import (
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -33,7 +35,8 @@ type caFlags struct {
 	argonParallelism *uint
 	encryption       *bool
 
-	curve *string
+	curve  *string
+	p11url *string
 }
 
 func newCaFlags() *caFlags {
@@ -52,6 +55,7 @@ func newCaFlags() *caFlags {
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
 	return &cf
 }
 
@@ -76,17 +80,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return err
 	}
 
+	isP11 := len(*cf.p11url) > 0
+
 	if err := mustFlagString("name", cf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 		return err
 	}
 	var kdfParams *cert.Argon2Parameters
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
 			return err
 		}
@@ -143,7 +151,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 	}
 
 	var passphrase []byte
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		for i := 0; i < 5; i++ {
 			out.Write([]byte("Enter passphrase: "))
 			passphrase, err = pr.ReadPassword()
@@ -166,29 +174,54 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 
 	var curve cert.Curve
 	var pub, rawPriv []byte
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		curve = cert.Curve_CURVE25519
-		pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+
+		p11Client, err = pkclient.FromUrl(*cf.p11url)
 		if err != nil {
-			return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
 		}
-	case "P256":
-		var key *ecdsa.PrivateKey
-		curve = cert.Curve_P256
-		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
 		if err != nil {
-			return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
 		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			curve = cert.Curve_CURVE25519
+			pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			}
+		case "P256":
+			var key *ecdsa.PrivateKey
+			curve = cert.Curve_P256
+			key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			}
 
-		// ecdh.PrivateKey lets us get at the encoded bytes, even though
-		// we aren't using ECDH here.
-		eKey, err := key.ECDH()
-		if err != nil {
-			return fmt.Errorf("error while converting ecdsa key: %s", err)
+			// ecdh.PrivateKey lets us get at the encoded bytes, even though
+			// we aren't using ECDH here.
+			eKey, err := key.ECDH()
+			if err != nil {
+				return fmt.Errorf("error while converting ecdsa key: %s", err)
+			}
+			rawPriv = eKey.Bytes()
+			pub = eKey.PublicKey().Bytes()
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
 		}
-		rawPriv = eKey.Bytes()
-		pub = eKey.PublicKey().Bytes()
 	}
 
 	nc := cert.NebulaCertificate{
@@ -203,34 +236,48 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 			IsCA:      true,
 			Curve:     curve,
 		},
+		Pkcs11Backed: isP11,
 	}
 
-	if _, err := os.Stat(*cf.outKeyPath); err == nil {
-		return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+	if !isP11 {
+		if _, err := os.Stat(*cf.outKeyPath); err == nil {
+			return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+		}
 	}
 
 	if _, err := os.Stat(*cf.outCertPath); err == nil {
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 	}
 
-	err = nc.Sign(curve, rawPriv)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
-	}
-
 	var b []byte
-	if *cf.encryption {
-		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+
+	if isP11 {
+		err = nc.SignPkcs11(curve, p11Client)
 		if err != nil {
-			return fmt.Errorf("error while encrypting out-key: %s", err)
+			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}
 	} else {
-		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
-	}
+		err = nc.Sign(curve, rawPriv)
+		if err != nil {
+			return fmt.Errorf("error while signing: %s", err)
+		}
 
-	err = os.WriteFile(*cf.outKeyPath, b, 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+		if *cf.encryption {
+			b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+			if err != nil {
+				return fmt.Errorf("error while encrypting out-key: %s", err)
+			}
+		} else {
+			b = cert.MarshalSigningPrivateKey(curve, rawPriv)
+		}
+
+		err = os.WriteFile(*cf.outKeyPath, b, 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
+		if _, err := os.Stat(*cf.outCertPath); err == nil {
+			return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
+		}
 	}
 
 	b, err = nc.MarshalToPEM()

+ 1 - 0
cmd/nebula-cert/ca_test.go

@@ -52,6 +52,7 @@ func Test_caHelp(t *testing.T) {
 			"    \tOptional: path to write the private key to (default \"ca.key\")\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
 		ob.String(),

+ 48 - 19
cmd/nebula-cert/keygen.go

@@ -6,6 +6,8 @@ import (
 	"io"
 	"os"
 
+	"github.com/slackhq/nebula/pkclient"
+
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -13,8 +15,8 @@ type keygenFlags struct {
 	set        *flag.FlagSet
 	outKeyPath *string
 	outPubPath *string
-
-	curve *string
+	curve      *string
+	p11url     *string
 }
 
 func newKeygenFlags() *keygenFlags {
@@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags {
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
 	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
 	return &cf
 }
 
@@ -33,31 +36,57 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	isP11 := len(*cf.p11url) > 0
+
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
-	if err := mustFlagString("out-pub", cf.outPubPath); err != nil {
+	if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
 		return err
 	}
 
 	var pub, rawPriv []byte
 	var curve cert.Curve
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		pub, rawPriv = x25519Keypair()
-		curve = cert.Curve_CURVE25519
-	case "P256":
-		pub, rawPriv = p256Keypair()
-		curve = cert.Curve_P256
-	default:
-		return fmt.Errorf("invalid curve: %s", *cf.curve)
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			pub, rawPriv = x25519Keypair()
+			curve = cert.Curve_CURVE25519
+		case "P256":
+			pub, rawPriv = p256Keypair()
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
+		}
 	}
 
-	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+	if isP11 {
+		p11Client, err := pkclient.FromUrl(*cf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key: %w", err)
+		}
+	} else {
+		err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
-
 	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
@@ -72,7 +101,7 @@ func keygenSummary() string {
 
 func keygenHelp(out io.Writer) {
 	cf := newKeygenFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
 	cf.set.SetOutput(out)
 	cf.set.PrintDefaults()
 }

+ 2 - 1
cmd/nebula-cert/keygen_test.go

@@ -26,7 +26,8 @@ func Test_keygenHelp(t *testing.T) {
 			"  -out-key string\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"  -out-pub string\n"+
-			"    \tRequired: path to write the public key to\n",
+			"    \tRequired: path to write the public key to\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n"),
 		ob.String(),
 	)
 }

+ 10 - 1
cmd/nebula-cert/main_test.go

@@ -3,6 +3,7 @@ package main
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"testing"
@@ -77,8 +78,16 @@ func assertHelpError(t *testing.T, err error, msg string) {
 	case *helpError:
 		// good
 	default:
-		t.Fatal("err was not a helpError")
+		t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
 	}
 
 	assert.EqualError(t, err, msg)
 }
+
+func optionalPkcs11String(msg string) string {
+	if p11Supported() {
+		return msg
+	} else {
+		return ""
+	}
+}

+ 15 - 0
cmd/nebula-cert/p11_cgo.go

@@ -0,0 +1,15 @@
+//go:build cgo && pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return true
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key")
+}

+ 16 - 0
cmd/nebula-cert/p11_stub.go

@@ -0,0 +1,16 @@
+//go:build !cgo || !pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return false
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	var ret = ""
+	return &ret
+}

+ 79 - 43
cmd/nebula-cert/sign.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 )
 
@@ -29,6 +30,7 @@ type signFlags struct {
 	outQRPath   *string
 	groups      *string
 	subnets     *string
+	p11url      *string
 }
 
 func newSignFlags() *signFlags {
@@ -45,8 +47,8 @@ func newSignFlags() *signFlags {
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
 	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
+	sf.p11url = p11Flag(sf.set)
 	return &sf
-
 }
 
 func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
@@ -56,8 +58,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return err
 	}
 
-	if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
-		return err
+	isP11 := len(*sf.p11url) > 0
+
+	if !isP11 {
+		if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
 		return err
@@ -68,47 +74,49 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("ip", sf.ip); err != nil {
 		return err
 	}
-	if *sf.inPubPath != "" && *sf.outKeyPath != "" {
+	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
-	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
-	if err != nil {
-		return fmt.Errorf("error while reading ca-key: %s", err)
-	}
-
 	var curve cert.Curve
 	var caKey []byte
+	if !isP11 {
+		var rawCAKey []byte
+		rawCAKey, err := os.ReadFile(*sf.caKeyPath)
+		if err != nil {
+			return fmt.Errorf("error while reading ca-key: %s", err)
+		}
 
-	// naively attempt to decode the private key as though it is not encrypted
-	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
-	if err == cert.ErrPrivateKeyEncrypted {
-		// ask for a passphrase until we get one
-		var passphrase []byte
-		for i := 0; i < 5; i++ {
-			out.Write([]byte("Enter passphrase: "))
-			passphrase, err = pr.ReadPassword()
-
-			if err == ErrNoTerminal {
-				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-			} else if err != nil {
-				return fmt.Errorf("error reading password: %s", err)
-			}
+		// naively attempt to decode the private key as though it is not encrypted
+		caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
+		if err == cert.ErrPrivateKeyEncrypted {
+			// ask for a passphrase until we get one
+			var passphrase []byte
+			for i := 0; i < 5; i++ {
+				out.Write([]byte("Enter passphrase: "))
+				passphrase, err = pr.ReadPassword()
+
+				if err == ErrNoTerminal {
+					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+				} else if err != nil {
+					return fmt.Errorf("error reading password: %s", err)
+				}
 
-			if len(passphrase) > 0 {
-				break
+				if len(passphrase) > 0 {
+					break
+				}
+			}
+			if len(passphrase) == 0 {
+				return fmt.Errorf("cannot open encrypted ca-key without passphrase")
 			}
-		}
-		if len(passphrase) == 0 {
-			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
-		}
 
-		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
-		if err != nil {
-			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+			curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
+			if err != nil {
+				return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+			}
+		} else if err != nil {
+			return fmt.Errorf("error while parsing ca-key: %s", err)
 		}
-	} else if err != nil {
-		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 
 	rawCACert, err := os.ReadFile(*sf.caCertPath)
@@ -121,8 +129,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 	}
 
-	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate does not match private key")
+	if !isP11 {
+		if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
+			return fmt.Errorf("refusing to sign, root certificate does not match private key")
+		}
 	}
 
 	issuer, err := caCert.Sha256Sum()
@@ -176,12 +186,25 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	}
 
 	var pub, rawPriv []byte
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		curve = cert.Curve_P256
+		p11Client, err = pkclient.FromUrl(*sf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+	}
+
 	if *sf.inPubPath != "" {
+		var pubCurve cert.Curve
 		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
-		var pubCurve cert.Curve
 		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
 		if err != nil {
 			return fmt.Errorf("error while parsing in-pub: %s", err)
@@ -189,6 +212,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		if pubCurve != curve {
 			return fmt.Errorf("curve of in-pub does not match ca")
 		}
+	} else if isP11 {
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
+		}
 	} else {
 		pub, rawPriv = newKeypair(curve)
 	}
@@ -206,6 +234,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			Issuer:    issuer,
 			Curve:     curve,
 		},
+		Pkcs11Backed: isP11,
+	}
+
+	if p11Client == nil {
+		err = nc.Sign(curve, caKey)
+		if err != nil {
+			return fmt.Errorf("error while signing: %w", err)
+		}
+	} else {
+		err = nc.SignPkcs11(curve, p11Client)
+		if err != nil {
+			return fmt.Errorf("error while signing with PKCS#11: %w", err)
+		}
 	}
 
 	if err := nc.CheckRootConstrains(caCert); err != nil {
@@ -224,12 +265,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 
-	err = nc.Sign(curve, caKey)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
-	}
-
-	if *sf.inPubPath == "" {
+	if !isP11 && *sf.inPubPath == "" {
 		if _, err := os.Stat(*sf.outKeyPath); err == nil {
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}

+ 1 - 0
cmd/nebula-cert/sign_test.go

@@ -48,6 +48,7 @@ func Test_signHelp(t *testing.T) {
 			"    \tOptional (if in-pub not set): path to write the private key to\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
 		ob.String(),

+ 5 - 1
connection_state.go

@@ -32,7 +32,11 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
-		dhFunc = noiseutil.DHP256
+		if certState.Certificate.Pkcs11Backed {
+			dhFunc = noiseutil.DHP256PKCS11
+		} else {
+			dhFunc = noiseutil.DHP256
+		}
 	default:
 		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
 		return nil

+ 2 - 0
go.mod

@@ -15,12 +15,14 @@ require (
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
 	github.com/miekg/dns v1.1.61
+	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
 	github.com/prometheus/client_golang v1.19.1
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
+	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.9.0
 	github.com/vishvananda/netlink v1.2.1-beta.2
 	golang.org/x/crypto v0.26.0

+ 4 - 0
go.sum

@@ -83,6 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
 github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -131,6 +133,8 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=

+ 50 - 0
noiseutil/pkcs11.go

@@ -0,0 +1,50 @@
+package noiseutil
+
+import (
+	"crypto/ecdh"
+	"fmt"
+	"strings"
+
+	"github.com/slackhq/nebula/pkclient"
+
+	"github.com/flynn/noise"
+)
+
+// DHP256PKCS11 is the NIST P-256 ECDH function
+var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32)
+
+type nistP11Curve struct {
+	nistCurve
+}
+
+func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve {
+	return nistP11Curve{
+		newNISTCurve(name, curve, byteLen),
+	}
+}
+
+func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) {
+	//for this function "privkey" is actually a pkcs11 URI
+	pkStr := string(privkey)
+
+	//to set up a handshake, we need to also do non-pkcs11-DH. Handle that here.
+	if !strings.HasPrefix(pkStr, "pkcs11:") {
+		return DHP256.DH(privkey, pubkey)
+	}
+	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+	}
+
+	//this is not the most performant way to do this (a long-lived client would be better)
+	//but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users.
+	client, err := pkclient.FromUrl(pkStr)
+	if err != nil {
+		return nil, err
+	}
+	defer func(client *pkclient.PKClient) {
+		_ = client.Close()
+	}(client)
+
+	return client.DeriveNoise(ecdhPubKey.Bytes())
+}

+ 87 - 0
pkclient/pkclient.go

@@ -0,0 +1,87 @@
+package pkclient
+
+import (
+	"crypto/ecdsa"
+	"crypto/x509"
+	"fmt"
+	"io"
+	"strconv"
+
+	"github.com/stefanberger/go-pkcs11uri"
+)
+
+type Client interface {
+	io.Closer
+	GetPubKey() ([]byte, error)
+	DeriveNoise(peerPubKey []byte) ([]byte, error)
+	Test() error
+}
+
+const NoiseKeySize = 32
+
+func FromUrl(pkurl string) (*PKClient, error) {
+	uri := pkcs11uri.New()
+	uri.SetAllowAnyModule(true) //todo
+	err := uri.Parse(pkurl)
+	if err != nil {
+		return nil, err
+	}
+
+	module, err := uri.GetModule()
+	if err != nil {
+		return nil, err
+	}
+
+	slotid := 0
+	slot, ok := uri.GetPathAttribute("slot-id", false)
+	if !ok {
+		slotid = 0
+	} else {
+		slotid, err = strconv.Atoi(slot)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	pin, _ := uri.GetPIN()
+	id, _ := uri.GetPathAttribute("id", false)
+	label, _ := uri.GetPathAttribute("object", false)
+
+	return New(module, uint(slotid), pin, id, label)
+}
+
+func ecKeyToArray(key *ecdsa.PublicKey) []byte {
+	x := make([]byte, 32)
+	y := make([]byte, 32)
+	key.X.FillBytes(x)
+	key.Y.FillBytes(y)
+	return append([]byte{0x04}, append(x, y...)...)
+}
+
+func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) {
+	e, err := x509.ParsePKIXPublicKey(d)
+	if err != nil {
+		return nil, err
+	}
+	switch t := e.(type) {
+	case *ecdsa.PublicKey:
+		return ecKeyToArray(e.(*ecdsa.PublicKey)), nil
+	default:
+		return nil, fmt.Errorf("unknown public key type: %T", t)
+	}
+}
+
+func (c *PKClient) Test() error {
+	pub, err := c.GetPubKey()
+	if err != nil {
+		return fmt.Errorf("failed to get public key: %w", err)
+	}
+	out, err := c.DeriveNoise(pub) //do an ECDH with ourselves as a quick test
+	if err != nil {
+		return err
+	}
+	if len(out) != NoiseKeySize {
+		return fmt.Errorf("got a key of %d bytes, expected %d", len(out), NoiseKeySize)
+	}
+	return nil
+}

+ 229 - 0
pkclient/pkclient_cgo.go

@@ -0,0 +1,229 @@
+//go:build cgo && pkcs11
+
+package pkclient
+
+import (
+	"encoding/asn1"
+	"errors"
+	"fmt"
+	"log"
+	"math/big"
+
+	"github.com/miekg/pkcs11"
+	"github.com/miekg/pkcs11/p11"
+)
+
+type PKClient struct {
+	module     p11.Module
+	session    p11.Session
+	id         []byte
+	label      []byte
+	privKeyObj p11.Object
+	pubKeyObj  p11.Object
+}
+
+type ecdsaSignature struct {
+	R, S *big.Int
+}
+
+// New tries to open a session with the HSM, select the slot and login to it
+func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) {
+	module, err := p11.OpenModule(hsmPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load module library: %s", hsmPath)
+	}
+
+	slots, err := module.Slots()
+	if err != nil {
+		module.Destroy()
+		return nil, err
+	}
+
+	// Try to open a session on the slot
+	slotIdx := 0
+	for i, slot := range slots {
+		if slot.ID() == slotId {
+			slotIdx = i
+			break
+		}
+	}
+
+	client := &PKClient{
+		module: module,
+		id:     []byte(id),
+		label:  []byte(label),
+	}
+
+	client.session, err = slots[slotIdx].OpenWriteSession()
+	if err != nil {
+		module.Destroy()
+		return nil, fmt.Errorf("failed to open session on slot %d", slotId)
+	}
+
+	if len(pin) != 0 {
+		err = client.session.Login(pin)
+		if err != nil {
+			// ignore "already logged in"
+			if !errors.Is(err, pkcs11.Error(256)) {
+				_ = client.session.Close()
+				return nil, fmt.Errorf("unable to login. error: %w", err)
+			}
+		}
+	}
+
+	// Make sure the hsm has a private key for deriving
+	client.privKeyObj, err = client.findDeriveKey(client.id, client.label, true)
+	if err != nil {
+		_ = client.Close() //log out, close session, destroy module
+		return nil, fmt.Errorf("failed to find private key for deriving: %w", err)
+	}
+
+	return client, nil
+}
+
+// Close cleans up properly and logs out
+func (c *PKClient) Close() error {
+	var err error = nil
+	if c.session != nil {
+		_ = c.session.Logout() //if logout fails, we still want to close
+		err = c.session.Close()
+	}
+
+	c.module.Destroy()
+	return err
+}
+
+// Try to find a suitable key on the hsm for key derivation
+// parameter GET_PUB_KEY sets the search pattern for a public or private key
+func (c *PKClient) findDeriveKey(id []byte, label []byte, private bool) (key p11.Object, err error) {
+	keyClass := pkcs11.CKO_PRIVATE_KEY
+	if !private {
+		keyClass = pkcs11.CKO_PUBLIC_KEY
+	}
+	keyAttrs := []*pkcs11.Attribute{
+		//todo, not all HSMs seem to report this, even if its true: pkcs11.NewAttribute(pkcs11.CKA_DERIVE, true),
+		pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass),
+	}
+
+	if id != nil && len(id) != 0 {
+		keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id))
+	}
+	if label != nil && len(label) != 0 {
+		keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label))
+	}
+
+	return c.session.FindObject(keyAttrs)
+}
+
+func (c *PKClient) listDeriveKeys(id []byte, label []byte, private bool) {
+	keyClass := pkcs11.CKO_PRIVATE_KEY
+	if !private {
+		keyClass = pkcs11.CKO_PUBLIC_KEY
+	}
+	keyAttrs := []*pkcs11.Attribute{
+		pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass),
+	}
+
+	if id != nil && len(id) != 0 {
+		keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id))
+	}
+	if label != nil && len(label) != 0 {
+		keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label))
+	}
+
+	objects, err := c.session.FindObjects(keyAttrs)
+	if err != nil {
+		return
+	}
+
+	for _, obj := range objects {
+		l, err := obj.Label()
+		log.Printf("%s, %v", l, err)
+		a, err := obj.Attribute(pkcs11.CKA_DERIVE)
+		log.Printf("DERIVE: %s %v, %v", l, a, err)
+	}
+}
+
+// SignASN1 signs some data. Returns the ASN.1 encoded signature.
+func (c *PKClient) SignASN1(data []byte) ([]byte, error) {
+	mech := pkcs11.NewMechanism(pkcs11.CKM_ECDSA_SHA256, nil)
+	sk := p11.PrivateKey(c.privKeyObj)
+	rawSig, err := sk.Sign(*mech, data)
+	if err != nil {
+		return nil, err
+	}
+
+	// PKCS #11 Mechanisms v2.30:
+	// "The signature octets correspond to the concatenation of the ECDSA values r and s,
+	// both represented as an octet string of equal length of at most nLen with the most
+	// significant byte first. If r and s have different octet length, the shorter of both
+	// must be padded with leading zero octets such that both have the same octet length.
+	// Loosely spoken, the first half of the signature is r and the second half is s."
+	r := new(big.Int).SetBytes(rawSig[:len(rawSig)/2])
+	s := new(big.Int).SetBytes(rawSig[len(rawSig)/2:])
+	return asn1.Marshal(ecdsaSignature{r, s})
+}
+
+// DeriveNoise derives a shared secret using the input public key against the private key that was found during setup.
+// Returns a fixed 32 byte array.
+func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
+	// Before we call derive, we need to have an array of attributes which specify the type of
+	// key to be returned, in our case, it's the shared secret key, produced via deriving
+	// This template pulled from OpenSC pkclient-tool.c line 4038
+	attrTemplate := []*pkcs11.Attribute{
+		pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false),
+		pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY),
+		pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_GENERIC_SECRET),
+		pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, false),
+		pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true),
+		pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true),
+		pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
+		pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
+		pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
+	}
+
+	// Set up the parameters which include the peer's public key
+	ecdhParams := pkcs11.NewECDH1DeriveParams(pkcs11.CKD_NULL, nil, peerPubKey)
+	mech := pkcs11.NewMechanism(pkcs11.CKM_ECDH1_DERIVE, ecdhParams)
+	sk := p11.PrivateKey(c.privKeyObj)
+
+	tmpKey, err := sk.Derive(*mech, attrTemplate)
+	if err != nil {
+		return nil, err
+	}
+	if tmpKey == nil || len(tmpKey) == 0 {
+		return nil, fmt.Errorf("got an empty secret key")
+	}
+	secret := make([]byte, NoiseKeySize)
+	copy(secret[:], tmpKey[:NoiseKeySize])
+	return secret, nil
+}
+
+func (c *PKClient) GetPubKey() ([]byte, error) {
+	d, err := c.privKeyObj.Attribute(pkcs11.CKA_PUBLIC_KEY_INFO)
+	if err != nil {
+		return nil, err
+	}
+	if d != nil && len(d) > 0 {
+		return formatPubkeyFromPublicKeyInfoAttr(d)
+	}
+	c.pubKeyObj, err = c.findDeriveKey(c.id, c.label, false)
+	if err != nil {
+		return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and looking up the public key also failed: %w", err)
+	}
+	d, err = c.pubKeyObj.Attribute(pkcs11.CKA_EC_POINT)
+	if err != nil {
+		return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and reading CKA_EC_POINT also failed: %w", err)
+	}
+	if d == nil || len(d) < 1 {
+		return nil, fmt.Errorf("pkcs11 module gave us a nil or empty CKA_EC_POINT")
+	}
+	switch len(d) {
+	case 65: //length of 0x04 + len(X) + len(Y)
+		return d, nil
+	case 67: //as above, DER-encoded IIRC?
+		return d[2:], nil
+	default:
+		return nil, fmt.Errorf("unknown public key length: %d", len(d))
+	}
+}

+ 30 - 0
pkclient/pkclient_stub.go

@@ -0,0 +1,30 @@
+//go:build !cgo || !pkcs11
+
+package pkclient
+
+import "errors"
+
+type PKClient struct {
+}
+
+var notImplemented = errors.New("not implemented")
+
+func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) {
+	return nil, notImplemented
+}
+
+func (c *PKClient) Close() error {
+	return nil
+}
+
+func (c *PKClient) SignASN1(data []byte) ([]byte, error) {
+	return nil, notImplemented
+}
+
+func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) {
+	return nil, notImplemented
+}
+
+func (c *PKClient) GetPubKey() ([]byte, error) {
+	return nil, notImplemented
+}

+ 27 - 13
pki.go

@@ -141,29 +141,43 @@ func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
 	return cs, nil
 }
 
-func newCertStateFromConfig(c *config.C) (*CertState, error) {
+func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
 	var pemPrivateKey []byte
-	var err error
-
-	privPathOrPEM := c.GetString("pki.key", "")
-	if privPathOrPEM == "" {
-		return nil, errors.New("no pki.key path or PEM data provided")
-	}
-
 	if strings.Contains(privPathOrPEM, "-----BEGIN") {
 		pemPrivateKey = []byte(privPathOrPEM)
 		privPathOrPEM = "<inline>"
-
+		rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		}
+	} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
+		rawKey = []byte(privPathOrPEM)
+		return rawKey, cert.Curve_P256, true, nil
 	} else {
 		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
 		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+			return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+		}
+		rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
 		}
 	}
 
-	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
+	return
+}
+
+func newCertStateFromConfig(c *config.C) (*CertState, error) {
+	var err error
+
+	privPathOrPEM := c.GetString("pki.key", "")
+	if privPathOrPEM == "" {
+		return nil, errors.New("no pki.key path or PEM data provided")
+	}
+
+	rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM)
 	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		return nil, err
 	}
 
 	var rawCert []byte
@@ -188,7 +202,7 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
 	if err != nil {
 		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
 	}
-
+	nebulaCert.Pkcs11Backed = isPkcs11
 	if nebulaCert.Expired(time.Now()) {
 		return nil, fmt.Errorf("nebula certificate for this host is expired")
 	}