Explorar el Código

Merge remote-tracking branch 'origin/master' into multiport

Wade Simmons hace 2 años
padre
commit
28ecfcbc03

+ 7 - 0
CHANGELOG.md

@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+- `nebula-cert ca` now supports encrypting the CA's private key with a
+  passphrase. Pass `-encrypt` in order to be prompted for a passphrase.
+  Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF
+  parameters default to RFC recommendations, but can be overridden via CLI
+  flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`.
+
 ## [1.6.1] - 2022-09-26
 
 ### Fixed

+ 145 - 6
cert/cert.go

@@ -9,7 +9,9 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"encoding/pem"
+	"errors"
 	"fmt"
+	"math"
 	"net"
 	"time"
 
@@ -21,11 +23,12 @@ import (
 const publicKeyLen = 32
 
 const (
-	CertBanner              = "NEBULA CERTIFICATE"
-	X25519PrivateKeyBanner  = "NEBULA X25519 PRIVATE KEY"
-	X25519PublicKeyBanner   = "NEBULA X25519 PUBLIC KEY"
-	Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
-	Ed25519PublicKeyBanner  = "NEBULA ED25519 PUBLIC KEY"
+	CertBanner                       = "NEBULA CERTIFICATE"
+	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
+	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
+	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
+	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
 )
 
 type NebulaCertificate struct {
@@ -48,8 +51,21 @@ type NebulaCertificateDetails struct {
 	InvertedGroups map[string]struct{}
 }
 
+type NebulaEncryptedData struct {
+	EncryptionMetadata NebulaEncryptionMetadata
+	Ciphertext         []byte
+}
+
+type NebulaEncryptionMetadata struct {
+	EncryptionAlgorithm string
+	Argon2Parameters    Argon2Parameters
+}
+
 type m map[string]interface{}
 
+// Returned if we try to unmarshal an encrypted private key without a passphrase
+var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
+
 // UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
 func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
 	if len(b) == 0 {
@@ -144,6 +160,30 @@ func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
 	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
 }
 
+// EncryptAndMarshalX25519PrivateKey is a simple helper to encrypt and PEM encode an X25519 private key
+func EncryptAndMarshalEd25519PrivateKey(b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
+	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err = proto.Marshal(&RawNebulaEncryptedData{
+		EncryptionMetadata: &RawNebulaEncryptionMetadata{
+			EncryptionAlgorithm: "AES-256-GCM",
+			Argon2Parameters: &RawNebulaArgon2Parameters{
+				Version:     kdfParams.version,
+				Memory:      kdfParams.Memory,
+				Parallelism: uint32(kdfParams.Parallelism),
+				Iterations:  kdfParams.Iterations,
+				Salt:        kdfParams.salt,
+			},
+		},
+		Ciphertext: ciphertext,
+	})
+
+	return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+}
+
 // UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
 // or an error on failure
 func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
@@ -168,9 +208,13 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
 	if k == nil {
 		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
 	}
-	if k.Type != Ed25519PrivateKeyBanner {
+
+	if k.Type == EncryptedEd25519PrivateKeyBanner {
+		return nil, r, ErrPrivateKeyEncrypted
+	} else if k.Type != Ed25519PrivateKeyBanner {
 		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
 	}
+
 	if len(k.Bytes) != ed25519.PrivateKeySize {
 		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
 	}
@@ -178,6 +222,101 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
 	return k.Bytes, r, nil
 }
 
+// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its
+// protobuf-generated struct.
+func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rned RawNebulaEncryptedData
+	err := proto.Unmarshal(b, &rned)
+	if err != nil {
+		return nil, err
+	}
+
+	if rned.EncryptionMetadata == nil {
+		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
+	}
+
+	if rned.EncryptionMetadata.Argon2Parameters == nil {
+		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
+	}
+
+	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
+	if err != nil {
+		return nil, err
+	}
+
+	ned := NebulaEncryptedData{
+		EncryptionMetadata: NebulaEncryptionMetadata{
+			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
+			Argon2Parameters:    *params,
+		},
+		Ciphertext: rned.Ciphertext,
+	}
+
+	return &ned, nil
+}
+
+func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
+	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
+		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
+	}
+	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
+		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
+	}
+	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
+		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
+	}
+	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
+		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
+	}
+
+	return &Argon2Parameters{
+		version:     rune(params.Version),
+		Memory:      uint32(params.Memory),
+		Parallelism: uint8(params.Parallelism),
+		Iterations:  uint32(params.Iterations),
+		salt:        params.Salt,
+	}, nil
+
+}
+
+// DecryptAndUnmarshalEd25519PrivateKey will try to pem decode and decrypt an Ed25519 private key with
+// the given passphrase, returning any other bytes b or an error on failure
+func DecryptAndUnmarshalEd25519PrivateKey(passphrase, b []byte) (ed25519.PrivateKey, []byte, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+
+	if k.Type != EncryptedEd25519PrivateKeyBanner {
+		return nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519 private key banner")
+	}
+
+	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
+	if err != nil {
+		return nil, r, err
+	}
+
+	var bytes []byte
+	switch ned.EncryptionMetadata.EncryptionAlgorithm {
+	case "AES-256-GCM":
+		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
+		if err != nil {
+			return nil, r, err
+		}
+	default:
+		return nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
+	}
+
+	if len(bytes) != ed25519.PrivateKeySize {
+		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+	}
+
+	return bytes, r, nil
+}
+
 // MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
 func MarshalX25519PublicKey(b []byte) []byte {
 	return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})

+ 270 - 11
cert/cert.pb.go

@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // 	protoc-gen-go v1.28.0
-// 	protoc        v3.20.0
+// 	protoc        v3.19.4
 // source: cert.proto
 
 package cert
@@ -188,6 +188,195 @@ func (x *RawNebulaCertificateDetails) GetIssuer() []byte {
 	return nil
 }
 
+type RawNebulaEncryptedData struct {
+	state         protoimpl.MessageState
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+
+	EncryptionMetadata *RawNebulaEncryptionMetadata `protobuf:"bytes,1,opt,name=EncryptionMetadata,proto3" json:"EncryptionMetadata,omitempty"`
+	Ciphertext         []byte                       `protobuf:"bytes,2,opt,name=Ciphertext,proto3" json:"Ciphertext,omitempty"`
+}
+
+func (x *RawNebulaEncryptedData) Reset() {
+	*x = RawNebulaEncryptedData{}
+	if protoimpl.UnsafeEnabled {
+		mi := &file_cert_proto_msgTypes[2]
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		ms.StoreMessageInfo(mi)
+	}
+}
+
+func (x *RawNebulaEncryptedData) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RawNebulaEncryptedData) ProtoMessage() {}
+
+func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
+	mi := &file_cert_proto_msgTypes[2]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
+func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
+	return file_cert_proto_rawDescGZIP(), []int{2}
+}
+
+func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
+	if x != nil {
+		return x.EncryptionMetadata
+	}
+	return nil
+}
+
+func (x *RawNebulaEncryptedData) GetCiphertext() []byte {
+	if x != nil {
+		return x.Ciphertext
+	}
+	return nil
+}
+
+type RawNebulaEncryptionMetadata struct {
+	state         protoimpl.MessageState
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+
+	EncryptionAlgorithm string                     `protobuf:"bytes,1,opt,name=EncryptionAlgorithm,proto3" json:"EncryptionAlgorithm,omitempty"`
+	Argon2Parameters    *RawNebulaArgon2Parameters `protobuf:"bytes,2,opt,name=Argon2Parameters,proto3" json:"Argon2Parameters,omitempty"`
+}
+
+func (x *RawNebulaEncryptionMetadata) Reset() {
+	*x = RawNebulaEncryptionMetadata{}
+	if protoimpl.UnsafeEnabled {
+		mi := &file_cert_proto_msgTypes[3]
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		ms.StoreMessageInfo(mi)
+	}
+}
+
+func (x *RawNebulaEncryptionMetadata) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
+
+func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
+	mi := &file_cert_proto_msgTypes[3]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
+func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
+	return file_cert_proto_rawDescGZIP(), []int{3}
+}
+
+func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
+	if x != nil {
+		return x.EncryptionAlgorithm
+	}
+	return ""
+}
+
+func (x *RawNebulaEncryptionMetadata) GetArgon2Parameters() *RawNebulaArgon2Parameters {
+	if x != nil {
+		return x.Argon2Parameters
+	}
+	return nil
+}
+
+type RawNebulaArgon2Parameters struct {
+	state         protoimpl.MessageState
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+
+	Version     int32  `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` // rune in Go
+	Memory      uint32 `protobuf:"varint,2,opt,name=memory,proto3" json:"memory,omitempty"`
+	Parallelism uint32 `protobuf:"varint,4,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // uint8 in Go
+	Iterations  uint32 `protobuf:"varint,3,opt,name=iterations,proto3" json:"iterations,omitempty"`
+	Salt        []byte `protobuf:"bytes,5,opt,name=salt,proto3" json:"salt,omitempty"`
+}
+
+func (x *RawNebulaArgon2Parameters) Reset() {
+	*x = RawNebulaArgon2Parameters{}
+	if protoimpl.UnsafeEnabled {
+		mi := &file_cert_proto_msgTypes[4]
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		ms.StoreMessageInfo(mi)
+	}
+}
+
+func (x *RawNebulaArgon2Parameters) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*RawNebulaArgon2Parameters) ProtoMessage() {}
+
+func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
+	mi := &file_cert_proto_msgTypes[4]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
+func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
+	return file_cert_proto_rawDescGZIP(), []int{4}
+}
+
+func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
+	if x != nil {
+		return x.Version
+	}
+	return 0
+}
+
+func (x *RawNebulaArgon2Parameters) GetMemory() uint32 {
+	if x != nil {
+		return x.Memory
+	}
+	return 0
+}
+
+func (x *RawNebulaArgon2Parameters) GetParallelism() uint32 {
+	if x != nil {
+		return x.Parallelism
+	}
+	return 0
+}
+
+func (x *RawNebulaArgon2Parameters) GetIterations() uint32 {
+	if x != nil {
+		return x.Iterations
+	}
+	return 0
+}
+
+func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
+	if x != nil {
+		return x.Salt
+	}
+	return nil
+}
+
 var File_cert_proto protoreflect.FileDescriptor
 
 var file_cert_proto_rawDesc = []byte{
@@ -215,9 +404,38 @@ var file_cert_proto_rawDesc = []byte{
 	0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
 	0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
 	0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
-	0x72, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
-	0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63,
-	0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+	0x72, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
+	0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12,
+	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
+	0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e,
+	0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
+	0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63,
+	0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12,
+	0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20,
+	0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22,
+	0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63,
+	0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12,
+	0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67,
+	0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e,
+	0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68,
+	0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d,
+	0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65,
+	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f,
+	0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, 0x10, 0x41, 0x72,
+	0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3,
+	0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f,
+	0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07,
+	0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76,
+	0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79,
+	0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20,
+	0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20,
+	0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
+	0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03,
+	0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73,
+	0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04,
+	0x73, 0x61, 0x6c, 0x74, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
+	0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c,
+	0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
@@ -232,18 +450,23 @@ func file_cert_proto_rawDescGZIP() []byte {
 	return file_cert_proto_rawDescData
 }
 
-var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
+var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
 var file_cert_proto_goTypes = []interface{}{
 	(*RawNebulaCertificate)(nil),        // 0: cert.RawNebulaCertificate
 	(*RawNebulaCertificateDetails)(nil), // 1: cert.RawNebulaCertificateDetails
+	(*RawNebulaEncryptedData)(nil),      // 2: cert.RawNebulaEncryptedData
+	(*RawNebulaEncryptionMetadata)(nil), // 3: cert.RawNebulaEncryptionMetadata
+	(*RawNebulaArgon2Parameters)(nil),   // 4: cert.RawNebulaArgon2Parameters
 }
 var file_cert_proto_depIdxs = []int32{
 	1, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
-	1, // [1:1] is the sub-list for method output_type
-	1, // [1:1] is the sub-list for method input_type
-	1, // [1:1] is the sub-list for extension type_name
-	1, // [1:1] is the sub-list for extension extendee
-	0, // [0:1] is the sub-list for field type_name
+	3, // 1: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
+	4, // 2: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters
+	3, // [3:3] is the sub-list for method output_type
+	3, // [3:3] is the sub-list for method input_type
+	3, // [3:3] is the sub-list for extension type_name
+	3, // [3:3] is the sub-list for extension extendee
+	0, // [0:3] is the sub-list for field type_name
 }
 
 func init() { file_cert_proto_init() }
@@ -276,6 +499,42 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
+		file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*RawNebulaEncryptedData); i {
+			case 0:
+				return &v.state
+			case 1:
+				return &v.sizeCache
+			case 2:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
+		file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*RawNebulaEncryptionMetadata); i {
+			case 0:
+				return &v.state
+			case 1:
+				return &v.sizeCache
+			case 2:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
+		file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*RawNebulaArgon2Parameters); i {
+			case 0:
+				return &v.state
+			case 1:
+				return &v.sizeCache
+			case 2:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
 	}
 	type x struct{}
 	out := protoimpl.TypeBuilder{
@@ -283,7 +542,7 @@ func file_cert_proto_init() {
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: file_cert_proto_rawDesc,
 			NumEnums:      0,
-			NumMessages:   2,
+			NumMessages:   5,
 			NumExtensions: 0,
 			NumServices:   0,
 		},

+ 19 - 1
cert/cert.proto

@@ -26,4 +26,22 @@ message RawNebulaCertificateDetails {
 
     // sha-256 of the issuer certificate, if this field is blank the cert is self-signed
     bytes Issuer = 9;
-}
+}
+
+message RawNebulaEncryptedData {
+	RawNebulaEncryptionMetadata EncryptionMetadata = 1;
+	bytes Ciphertext = 2;
+}
+
+message RawNebulaEncryptionMetadata {
+	string EncryptionAlgorithm = 1;
+	RawNebulaArgon2Parameters Argon2Parameters = 2;
+}
+
+message RawNebulaArgon2Parameters {
+	int32 version = 1; // rune in Go
+	uint32 memory = 2;
+	uint32 parallelism = 4; // uint8 in Go
+	uint32 iterations = 3;
+	bytes salt = 5;
+}

+ 85 - 0
cert/cert_test.go

@@ -578,6 +578,91 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
 }
 
+func TestDecryptAndUnmarshalEd25519PrivateKey(t *testing.T) {
+	passphrase := []byte("DO NOT USE THIS KEY")
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A key which, once decrypted, is too short
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
+k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
+GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
+rQr3bdH3Oy/WiYU=
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner (not encrypted)
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
+XgLvodMXZJuaFPssp+WwtA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+
+	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, keyBundle)
+	assert.Nil(t, err)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+
+	// Fail due to invalid banner
+	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519 private key banner")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to invalid passphrase
+	k, rest, err = DecryptAndUnmarshalEd25519PrivateKey([]byte("invalid passphrase"), privKey)
+	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, []byte{})
+}
+
+func TestEncryptAndMarshalEd25519PrivateKey(t *testing.T) {
+	// Having proved that decryption works correctly above, we can test the
+	// encryption function produces a value which can be decrypted
+	passphrase := []byte("passphrase")
+	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
+	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
+	key, err := EncryptAndMarshalEd25519PrivateKey(bytes, passphrase, kdfParams)
+	assert.Nil(t, err)
+
+	// Verify the "key" can be decrypted successfully
+	k, rest, err := DecryptAndUnmarshalEd25519PrivateKey(passphrase, key)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, []byte{})
+	assert.Nil(t, err)
+
+	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
+}
+
 func TestUnmarshalX25519PrivateKey(t *testing.T) {
 	privKey := []byte(`# A good key
 -----BEGIN NEBULA X25519 PRIVATE KEY-----

+ 140 - 0
cert/crypto.go

@@ -0,0 +1,140 @@
+package cert
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"fmt"
+	"io"
+
+	"golang.org/x/crypto/argon2"
+)
+
+// KDF factors
+type Argon2Parameters struct {
+	version     rune
+	Memory      uint32 // KiB
+	Parallelism uint8
+	Iterations  uint32
+	salt        []byte
+}
+
+// Returns a new Argon2Parameters object with current version set
+func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
+	return &Argon2Parameters{
+		version:     argon2.Version,
+		Memory:      memory, // KiB
+		Parallelism: parallelism,
+		Iterations:  iterations,
+	}
+}
+
+// Encrypts data using AES-256-GCM and the Argon2id key derivation function
+func aes256Encrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) {
+	key, err := aes256DeriveKey(passphrase, kdfParams)
+	if err != nil {
+		return nil, err
+	}
+
+	// this should never happen, but since this dictates how our calls into the
+	// aes package behave and could be catastraphic, let's sanity check this
+	if len(key) != 32 {
+		return nil, fmt.Errorf("invalid AES-256 key length (%d) - cowardly refusing to encrypt", len(key))
+	}
+
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
+
+	nonce := make([]byte, gcm.NonceSize())
+	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+		return nil, err
+	}
+
+	ciphertext := gcm.Seal(nil, nonce, data, nil)
+	blob := joinNonceCiphertext(nonce, ciphertext)
+
+	return blob, nil
+}
+
+// Decrypts data using AES-256-GCM and the Argon2id key derivation function
+// Expects the data to include an Argon2id parameter string before the encrypted data
+func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) {
+	key, err := aes256DeriveKey(passphrase, kdfParams)
+	if err != nil {
+		return nil, err
+	}
+
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	gcm, err := cipher.NewGCM(block)
+
+	nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize())
+	if err != nil {
+		return nil, err
+	}
+
+	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
+	if err != nil {
+		return nil, fmt.Errorf("invalid passphrase or corrupt private key")
+	}
+
+	return plaintext, nil
+}
+
+func aes256DeriveKey(passphrase []byte, params *Argon2Parameters) ([]byte, error) {
+	if params.salt == nil {
+		params.salt = make([]byte, 32)
+		if _, err := rand.Read(params.salt); err != nil {
+			return nil, err
+		}
+	}
+
+	// keySize of 32 bytes will result in AES-256 encryption
+	key, err := deriveKey(passphrase, 32, params)
+	if err != nil {
+		return nil, err
+	}
+
+	return key, nil
+}
+
+// Derives a key from a passphrase using Argon2id
+func deriveKey(passphrase []byte, keySize uint32, params *Argon2Parameters) ([]byte, error) {
+	if params.version != argon2.Version {
+		return nil, fmt.Errorf("incompatible Argon2 version: %d", params.version)
+	}
+
+	if params.salt == nil {
+		return nil, fmt.Errorf("salt must be set in argon2Parameters")
+	} else if len(params.salt) < 16 {
+		return nil, fmt.Errorf("salt must be at least 128  bits")
+	}
+
+	key := argon2.IDKey(passphrase, params.salt, params.Iterations, params.Memory, params.Parallelism, keySize)
+
+	return key, nil
+}
+
+// Prepends nonce to ciphertext
+func joinNonceCiphertext(nonce []byte, ciphertext []byte) []byte {
+	return append(nonce, ciphertext...)
+}
+
+// Splits nonce from ciphertext
+func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
+	if len(blob) <= nonceSize {
+		return nil, nil, fmt.Errorf("invalid ciphertext blob - blob shorter than nonce length")
+	}
+
+	return blob[:nonceSize], blob[nonceSize:], nil
+}

+ 25 - 0
cert/crypto_test.go

@@ -0,0 +1,25 @@
+package cert
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/crypto/argon2"
+)
+
+func TestNewArgon2Parameters(t *testing.T) {
+	p := NewArgon2Parameters(64*1024, 4, 3)
+	assert.EqualValues(t, &Argon2Parameters{
+		version:     argon2.Version,
+		Memory:      64 * 1024,
+		Parallelism: 4,
+		Iterations:  3,
+	}, p)
+	p = NewArgon2Parameters(2*1024*1024, 2, 1)
+	assert.EqualValues(t, &Argon2Parameters{
+		version:     argon2.Version,
+		Memory:      2 * 1024 * 1024,
+		Parallelism: 2,
+		Iterations:  1,
+	}, p)
+}

+ 72 - 11
cmd/nebula-cert/ca.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math"
 	"net"
 	"os"
 	"strings"
@@ -17,15 +18,19 @@ import (
 )
 
 type caFlags struct {
-	set         *flag.FlagSet
-	name        *string
-	duration    *time.Duration
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	ips         *string
-	subnets     *string
+	set              *flag.FlagSet
+	name             *string
+	duration         *time.Duration
+	outKeyPath       *string
+	outCertPath      *string
+	outQRPath        *string
+	groups           *string
+	ips              *string
+	subnets          *string
+	argonMemory      *uint
+	argonIterations  *uint
+	argonParallelism *uint
+	encryption       *bool
 }
 
 func newCaFlags() *caFlags {
@@ -39,10 +44,28 @@ func newCaFlags() *caFlags {
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
 	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
 	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
+	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
+	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
+	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	return &cf
 }
 
-func ca(args []string, out io.Writer, errOut io.Writer) error {
+func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert.Argon2Parameters, error) {
+	if memory <= 0 || memory > math.MaxUint32 {
+		return nil, newHelpErrorf("-argon-memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
+	}
+	if parallelism <= 0 || parallelism > math.MaxUint8 {
+		return nil, newHelpErrorf("-argon-parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
+	}
+	if iterations <= 0 || iterations > math.MaxUint32 {
+		return nil, newHelpErrorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
+	}
+
+	return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil
+}
+
+func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
 	cf := newCaFlags()
 	err := cf.set.Parse(args)
 	if err != nil {
@@ -58,6 +81,12 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 		return err
 	}
+	var kdfParams *cert.Argon2Parameters
+	if *cf.encryption {
+		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
+			return err
+		}
+	}
 
 	if *cf.duration <= 0 {
 		return &helpError{"-duration must be greater than 0"}
@@ -109,6 +138,28 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
 		}
 	}
 
+	var passphrase []byte
+	if *cf.encryption {
+		for i := 0; i < 5; i++ {
+			out.Write([]byte("Enter passphrase: "))
+			passphrase, err = pr.ReadPassword()
+
+			if err == ErrNoTerminal {
+				return fmt.Errorf("out-key must be encrypted interactively")
+			} else if err != nil {
+				return fmt.Errorf("error reading passphrase: %s", err)
+			}
+
+			if len(passphrase) > 0 {
+				break
+			}
+		}
+
+		if len(passphrase) == 0 {
+			return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+		}
+	}
+
 	pub, rawPriv, err := ed25519.GenerateKey(rand.Reader)
 	if err != nil {
 		return fmt.Errorf("error while generating ed25519 keys: %s", err)
@@ -140,7 +191,17 @@ func ca(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("error while signing: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600)
+	if *cf.encryption {
+		b, err := cert.EncryptAndMarshalEd25519PrivateKey(rawPriv, passphrase, kdfParams)
+		if err != nil {
+			return fmt.Errorf("error while encrypting out-key: %s", err)
+		}
+
+		err = ioutil.WriteFile(*cf.outKeyPath, b, 0600)
+	} else {
+		err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalEd25519PrivateKey(rawPriv), 0600)
+	}
+
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}

+ 86 - 9
cmd/nebula-cert/ca_test.go

@@ -5,8 +5,11 @@ package main
 
 import (
 	"bytes"
+	"encoding/pem"
+	"errors"
 	"io/ioutil"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -26,8 +29,16 @@ func Test_caHelp(t *testing.T) {
 	assert.Equal(
 		t,
 		"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
+			"  -argon-iterations uint\n"+
+			"    \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
+			"  -argon-memory uint\n"+
+			"    \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+
+			"  -argon-parallelism uint\n"+
+			"    \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+
 			"  -duration duration\n"+
 			"    \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+
+			"  -encrypt\n"+
+			"    \tOptional: prompt for passphrase and write out-key in an encrypted format\n"+
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
@@ -50,18 +61,38 @@ func Test_ca(t *testing.T) {
 	ob := &bytes.Buffer{}
 	eb := &bytes.Buffer{}
 
+	nopw := &StubPasswordReader{
+		password: []byte(""),
+		err:      nil,
+	}
+
+	errpw := &StubPasswordReader{
+		password: []byte(""),
+		err:      errors.New("stub error"),
+	}
+
+	passphrase := []byte("DO NOT USE THIS KEY")
+	testpw := &StubPasswordReader{
+		password: passphrase,
+		err:      nil,
+	}
+
+	pwPromptOb := "Enter passphrase: "
+
 	// required args
-	assertHelpError(t, ca([]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb), "-name is required")
+	assertHelpError(t, ca(
+		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+	), "-name is required")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -69,7 +100,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
-	assert.EqualError(t, ca(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -82,7 +113,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -96,7 +127,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, ca(args, ob, eb))
+	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -122,19 +153,65 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", lCrt.Details.Issuer)
 	assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
 
+	// test encrypted key
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	assert.Nil(t, ca(args, ob, eb, testpw))
+	assert.Equal(t, pwPromptOb, ob.String())
+	assert.Equal(t, "", eb.String())
+
+	// read encrypted key file and verify default params
+	rb, _ = ioutil.ReadFile(keyF.Name())
+	k, _ := pem.Decode(rb)
+	ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
+	assert.Nil(t, err)
+	// we won't know salt in advance, so just check start of string
+	assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
+	assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
+	assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations)
+
+	// verify the key is valid and decrypt-able
+	lKey, b, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rb)
+	assert.Nil(t, err)
+	assert.Len(t, b, 0)
+	assert.Len(t, lKey, 64)
+
+	// test when reading passsword results in an error
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	assert.Error(t, ca(args, ob, eb, errpw))
+	assert.Equal(t, pwPromptOb, ob.String())
+	assert.Equal(t, "", eb.String())
+
+	// test when user fails to enter a password
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
+	assert.Equal(t, "", eb.String())
+
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, ca(args, ob, eb))
+	assert.Nil(t, ca(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA key: "+keyF.Name())
+	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
@@ -143,7 +220,7 @@ func Test_ca(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb), "refusing to overwrite existing CA cert: "+crtF.Name())
+	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	os.Remove(keyF.Name())

+ 2 - 2
cmd/nebula-cert/main.go

@@ -62,11 +62,11 @@ func main() {
 
 	switch args[0] {
 	case "ca":
-		err = ca(args[1:], os.Stdout, os.Stderr)
+		err = ca(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{})
 	case "keygen":
 		err = keygen(args[1:], os.Stdout, os.Stderr)
 	case "sign":
-		err = signCert(args[1:], os.Stdout, os.Stderr)
+		err = signCert(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{})
 	case "print":
 		err = printCert(args[1:], os.Stdout, os.Stderr)
 	case "verify":

+ 28 - 0
cmd/nebula-cert/passwords.go

@@ -0,0 +1,28 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"os"
+
+	"golang.org/x/term"
+)
+
+var ErrNoTerminal = errors.New("cannot read password from nonexistent terminal")
+
+type PasswordReader interface {
+	ReadPassword() ([]byte, error)
+}
+
+type StdinPasswordReader struct{}
+
+func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
+	if !term.IsTerminal(int(os.Stdin.Fd())) {
+		return nil, ErrNoTerminal
+	}
+
+	password, err := term.ReadPassword(int(os.Stdin.Fd()))
+	fmt.Println()
+
+	return password, err
+}

+ 10 - 0
cmd/nebula-cert/passwords_test.go

@@ -0,0 +1,10 @@
+package main
+
+type StubPasswordReader struct {
+	password []byte
+	err      error
+}
+
+func (pr *StubPasswordReader) ReadPassword() ([]byte, error) {
+	return pr.password, pr.err
+}

+ 32 - 3
cmd/nebula-cert/sign.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"crypto/ed25519"
 	"crypto/rand"
 	"flag"
 	"fmt"
@@ -49,7 +50,7 @@ func newSignFlags() *signFlags {
 
 }
 
-func signCert(args []string, out io.Writer, errOut io.Writer) error {
+func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
 	sf := newSignFlags()
 	err := sf.set.Parse(args)
 	if err != nil {
@@ -77,8 +78,36 @@ func signCert(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("error while reading ca-key: %s", err)
 	}
 
-	caKey, _, err := cert.UnmarshalEd25519PrivateKey(rawCAKey)
-	if err != nil {
+	var caKey ed25519.PrivateKey
+
+	// naively attempt to decode the private key as though it is not encrypted
+	caKey, _, err = cert.UnmarshalEd25519PrivateKey(rawCAKey)
+	if err == cert.ErrPrivateKeyEncrypted {
+		// ask for a passphrase until we get one
+		var passphrase []byte
+		for i := 0; i < 5; i++ {
+			out.Write([]byte("Enter passphrase: "))
+			passphrase, err = pr.ReadPassword()
+
+			if err == ErrNoTerminal {
+				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+			} else if err != nil {
+				return fmt.Errorf("error reading password: %s", err)
+			}
+
+			if len(passphrase) > 0 {
+				break
+			}
+		}
+		if len(passphrase) == 0 {
+			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+		}
+
+		caKey, _, err = cert.DecryptAndUnmarshalEd25519PrivateKey(passphrase, rawCAKey)
+		if err != nil {
+			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+		}
+	} else if err != nil {
 		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 

+ 115 - 23
cmd/nebula-cert/sign_test.go

@@ -6,6 +6,7 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
 	"io/ioutil"
 	"os"
 	"testing"
@@ -58,17 +59,39 @@ func Test_signCert(t *testing.T) {
 	ob := &bytes.Buffer{}
 	eb := &bytes.Buffer{}
 
+	nopw := &StubPasswordReader{
+		password: []byte(""),
+		err:      nil,
+	}
+
+	errpw := &StubPasswordReader{
+		password: []byte(""),
+		err:      errors.New("stub error"),
+	}
+
+	passphrase := []byte("DO NOT USE THIS KEY")
+	testpw := &StubPasswordReader{
+		password: passphrase,
+		err:      nil,
+	}
+
 	// required args
-	assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-name is required")
+	assertHelpError(t, signCert(
+		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
-	assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb), "-ip is required")
+	assertHelpError(t, signCert(
+		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-ip is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// cannot set -in-pub and -out-key
-	assertHelpError(t, signCert([]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb), "cannot set both -in-pub and -out-key")
+	assertHelpError(t, signCert(
+		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -76,7 +99,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-key: open ./nope: "+NoSuchFileError)
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 	// failed to unmarshal key
 	ob.Reset()
@@ -86,7 +109,7 @@ func Test_signCert(t *testing.T) {
 	defer os.Remove(caKeyF.Name())
 
 	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-key: input did not contain a valid PEM encoded block")
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -98,7 +121,7 @@ func Test_signCert(t *testing.T) {
 
 	// failed to read cert
 	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -110,7 +133,7 @@ func Test_signCert(t *testing.T) {
 	defer os.Remove(caCrtF.Name())
 
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -129,7 +152,7 @@ func Test_signCert(t *testing.T) {
 
 	// failed to read pub
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while reading in-pub: open ./nope: "+NoSuchFileError)
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -141,7 +164,7 @@ func Test_signCert(t *testing.T) {
 	defer os.Remove(inPubF.Name())
 
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while parsing in-pub: input did not contain a valid PEM encoded block")
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -155,14 +178,14 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -170,14 +193,14 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: invalid CIDR address: a")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -191,7 +214,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate does not match private key")
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -199,7 +222,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -212,7 +235,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	os.Remove(keyF.Name())
@@ -226,7 +249,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb))
+	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -268,7 +291,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
-	assert.Nil(t, signCert(args, ob, eb))
+	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -283,7 +306,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -291,14 +314,14 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb))
+	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing key: "+keyF.Name())
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -306,14 +329,83 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb))
+	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb), "refusing to overwrite existing cert: "+crtF.Name())
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
+
+	// create valid cert/key using encrypted CA key
+	os.Remove(caKeyF.Name())
+	os.Remove(caCrtF.Name())
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	ob.Reset()
+	eb.Reset()
+
+	caKeyF, err = ioutil.TempFile("", "sign-cert.key")
+	assert.Nil(t, err)
+	defer os.Remove(caKeyF.Name())
+
+	caCrtF, err = ioutil.TempFile("", "sign-cert.crt")
+	assert.Nil(t, err)
+	defer os.Remove(caCrtF.Name())
+
+	// generate the encrypted key
+	caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader)
+	kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3)
+	b, _ = cert.EncryptAndMarshalEd25519PrivateKey(caPriv, passphrase, kdfParams)
+	caKeyF.Write(b)
+
+	ca = cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:      "ca",
+			NotBefore: time.Now(),
+			NotAfter:  time.Now().Add(time.Minute * 200),
+			PublicKey: caPub,
+			IsCA:      true,
+		},
+	}
+	b, _ = ca.MarshalToPEM()
+	caCrtF.Write(b)
+
+	// test with the proper password
+	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.Nil(t, signCert(args, ob, eb, testpw))
+	assert.Equal(t, "Enter passphrase: ", ob.String())
+	assert.Empty(t, eb.String())
+
+	// test with the wrong password
+	ob.Reset()
+	eb.Reset()
+
+	testpw.password = []byte("invalid password")
+	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.Error(t, signCert(args, ob, eb, testpw))
+	assert.Equal(t, "Enter passphrase: ", ob.String())
+	assert.Empty(t, eb.String())
+
+	// test with the user not entering a password
+	ob.Reset()
+	eb.Reset()
+
+	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.Error(t, signCert(args, ob, eb, nopw))
+	// normally the user hitting enter on the prompt would add newlines between these
+	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
+	assert.Empty(t, eb.String())
+
+	// test an error condition
+	ob.Reset()
+	eb.Reset()
+
+	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.Error(t, signCert(args, ob, eb, errpw))
+	assert.Equal(t, "Enter passphrase: ", ob.String())
+	assert.Empty(t, eb.String())
 }

+ 18 - 8
connection_manager.go

@@ -183,12 +183,6 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 		return
 	}
 
-	if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
-		// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
-		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
-		return
-	}
-
 	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
 		// We have already sent a test packet and nothing was returned, this hostinfo is dead
 		hostinfo.logger(n.l).
@@ -205,12 +199,28 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 		Debug("Tunnel status")
 
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
+		if !outTraffic {
+			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
+			// Just maintain NAT state if configured to do so.
+			n.sendPunch(hostinfo)
+			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+			return
+
+		}
+
 		if n.punchy.GetTargetEverything() {
-			// Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all
-			// known remotes, go ahead and do that AND send a test packet
+			// This is similar to the old punchy behavior with a slight optimization.
+			// We aren't receiving traffic but we are sending it, punch on all known
+			// ips in case we need to re-prime NAT state
 			n.sendPunch(hostinfo)
 		}
 
+		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
+			// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
+			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+			return
+		}
+
 		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
 

+ 2 - 0
connection_manager_test.go

@@ -98,6 +98,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 
 	// Do another traffic check tick, this host should be pending deletion now
+	nc.Out(hostinfo.localIndexId)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
@@ -175,6 +176,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 
 	// Do another traffic check tick, this host should be pending deletion now
+	nc.Out(hostinfo.localIndexId)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)

+ 2 - 1
connection_state.go

@@ -9,6 +9,7 @@ import (
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/noiseutil"
 )
 
 const ReplayWindow = 1024
@@ -28,7 +29,7 @@ type ConnectionState struct {
 }
 
 func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
-	cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
+	cs := noise.NewCipherSuite(noise.DH25519, noiseutil.CipherAESGCM, noise.HashSHA256)
 	if f.cipher == "chachapoly" {
 		cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
 	}

+ 7 - 3
examples/config.yml

@@ -107,7 +107,7 @@ lighthouse:
 # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
 # however using port 0 will dynamically assign a port and is recommended for roaming nodes.
 listen:
-  # To listen on both any ipv4 and ipv6 use "[::]"
+  # To listen on both any ipv4 and ipv6 use "::"
   host: 0.0.0.0
   port: 4242
   # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
@@ -204,20 +204,24 @@ tun:
   tx_queue: 500
   # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic
   mtu: 1300
+
   # Route based MTU overrides, you have known vpn ip paths that can support larger MTUs you can increase/decrease them here
   routes:
     #- mtu: 8800
     #  route: 10.0.0.0/16
+
   # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
   # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
   # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
-  # `mtu` will default to tun mtu if this option is not specified
-  # `metric` will default to 0 if this option is not specified
+  # `mtu`: will default to tun mtu if this option is not specified
+  # `metric`: will default to 0 if this option is not specified
+  # `install`: will default to true, controls whether this route is installed in the systems routing table.
   unsafe_routes:
     #- route: 172.16.1.0/24
     #  via: 192.168.100.99
     #  mtu: 1300
     #  metric: 100
+    #  install: true
 
   # EXPERIMENTAL: This option may change or disappear in the future.
   # Multiport spreads outgoing UDP packets across multiple UDP send ports,

+ 1 - 1
go.mod

@@ -24,6 +24,7 @@ require (
 	golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0
 	golang.org/x/net v0.8.0
 	golang.org/x/sys v0.6.0
+	golang.org/x/term v0.6.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	google.golang.org/protobuf v1.29.0
@@ -42,7 +43,6 @@ require (
 	github.com/prometheus/procfs v0.9.0 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	golang.org/x/mod v0.9.0 // indirect
-	golang.org/x/term v0.6.0 // indirect
 	golang.org/x/tools v0.7.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 0 - 2
go.sum

@@ -153,8 +153,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
 golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
-golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI=
-golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
 golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw=
 golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=

+ 1 - 1
handshake.go

@@ -5,7 +5,7 @@ import (
 	"github.com/slackhq/nebula/udp"
 )
 
-func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) {
+func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
 	// First remote allow list check before we know the vpnIp
 	if addr != nil {
 		if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {

+ 11 - 14
handshake_ix.go

@@ -77,7 +77,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hostinfo.handshakeStart = time.Now()
 }
 
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
 	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
@@ -282,14 +282,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				}
 				return
 			} else {
-				via2 := via.(*ViaSender)
-				if via2 == nil {
+				if via == nil {
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 				}
-				hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
-				f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp).
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 				return
@@ -364,14 +363,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				Info("Handshake message sent")
 		}
 	} else {
-		via2 := via.(*ViaSender)
-		if via2 == nil {
+		if via == nil {
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 		}
-		hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
-		f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp).
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
+		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -387,7 +385,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
 	if hostinfo == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -551,8 +549,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	if addr != nil {
 		hostinfo.SetRemote(addr)
 	} else {
-		via2 := via.(*ViaSender)
-		hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
 	}
 
 	// Build up the radix for the firewall if we have subnets in the cert

+ 4 - 13
handshake_manager.go

@@ -56,10 +56,6 @@ type HandshakeManager struct {
 	multiPort MultiPortConfig
 	udpRaw    *udp.RawConn
 
-	// vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead
-	// this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking
-	vpnIps map[iputil.VpnIp]struct{}
-
 	// can be used to trigger outbound handshake for the given vpnIp
 	trigger chan iputil.VpnIp
 }
@@ -73,7 +69,6 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 		config:                 config,
 		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
-		vpnIps:                 map[iputil.VpnIp]struct{}{},
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -81,7 +76,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 
@@ -97,7 +92,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
 		vpnIp, has := c.OutboundHandshakeTimer.Purge()
@@ -108,10 +103,9 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
+func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
 	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
-		delete(c.vpnIps, vpnIp)
 		return
 	}
 	hostinfo.Lock()
@@ -324,10 +318,7 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
 	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
 
 	if created {
-		if _, ok := c.vpnIps[vpnIp]; !ok {
-			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
-		}
-		c.vpnIps[vpnIp] = struct{}{}
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
 		c.metricInitiated.Inc(1)
 	}
 

+ 1 - 1
handshake_manager_test.go

@@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	return
 }
 
-func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 	return
 }
 

+ 20 - 8
inside.go

@@ -6,6 +6,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/noiseutil"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -247,15 +248,17 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
 // nb is a buffer used to store the nonce value, re-used for performance reasons.
 // out is a buffer used to store the result of the Encrypt operation
 // q indicates which writer to use to send the packet.
-func (f *Interface) SendVia(viaIfc interface{},
-	relayIfc interface{},
+func (f *Interface) SendVia(via *HostInfo,
+	relay *Relay,
 	ad,
 	nb,
 	out []byte,
 	nocopy bool,
 ) {
-	via := viaIfc.(*HostInfo)
-	relay := relayIfc.(*Relay)
+	if noiseutil.EncryptLockNeeded {
+		// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
+		via.ConnectionState.writeLock.Lock()
+	}
 	c := via.ConnectionState.messageCounter.Add(1)
 
 	out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
@@ -264,6 +267,9 @@ func (f *Interface) SendVia(viaIfc interface{},
 	// Authenticate the header and payload, but do not encrypt for this message type.
 	// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
 	if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) {
+		if noiseutil.EncryptLockNeeded {
+			via.ConnectionState.writeLock.Unlock()
+		}
 		via.logger(f.l).
 			WithField("outCap", cap(out)).
 			WithField("payloadLen", len(ad)).
@@ -285,6 +291,9 @@ func (f *Interface) SendVia(viaIfc interface{},
 
 	var err error
 	out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb)
+	if noiseutil.EncryptLockNeeded {
+		via.ConnectionState.writeLock.Unlock()
+	}
 	if err != nil {
 		via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
 		return
@@ -330,8 +339,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		out = out[header.Len:]
 	}
 
-	//TODO: enable if we do more than 1 tun queue
-	//ci.writeLock.Lock()
+	if noiseutil.EncryptLockNeeded {
+		// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
+		ci.writeLock.Lock()
+	}
 	c := ci.messageCounter.Add(1)
 
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
@@ -352,8 +363,9 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 
 	var err error
 	out, err = ci.eKey.EncryptDanger(out, out, p, c, nb)
-	//TODO: see above note on lock
-	//ci.writeLock.Unlock()
+	if noiseutil.EncryptLockNeeded {
+		ci.writeLock.Unlock()
+	}
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).
 			WithField("udpAddr", remote).WithField("counter", c).

+ 17 - 1
interface.go

@@ -16,6 +16,7 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
@@ -101,6 +102,18 @@ type MultiPortConfig struct {
 	TxHandshakeDelay int
 }
 
+type EncWriter interface {
+	SendVia(via *HostInfo,
+		relay *Relay,
+		ad,
+		nb,
+		out []byte,
+		nocopy bool,
+	)
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	Handshake(vpnIp iputil.VpnIp)
+}
+
 type sendRecvErrorConfig uint8
 
 const (
@@ -252,7 +265,7 @@ func (f *Interface) listenOut(i int) {
 
 	lhh := f.lightHouse.NewRequestHandler()
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
+	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -396,6 +409,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 
 	var rawStats func()
 
+	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -410,6 +425,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 				}
 				rawStats()
 			}
+			certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
 		}
 	}
 }

+ 14 - 8
lighthouse.go

@@ -65,7 +65,7 @@ type LightHouse struct {
 	interval        atomic.Int64
 	updateCancel    context.CancelFunc
 	updateParentCtx context.Context
-	updateUdp       udp.EncWriter
+	updateUdp       EncWriter
 	nebulaPort      uint32 // 32 bits because protobuf does not have a uint16
 
 	advertiseAddrs atomic.Pointer[[]netIpAndPort]
@@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 }
 
-func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 	}
@@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
 }
 
 // This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
 	if lh.amLighthouse {
 		return
 	}
@@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
-func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
+func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
 	lh.updateParentCtx = ctx
 	lh.updateUdp = f
 
@@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
 	}
 }
 
-func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
+func (lh *LightHouse) SendUpdate(f EncWriter) {
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 
@@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 }
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
+func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
+	return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+		lhh.HandleRequest(rAddr, vpnIp, p, f)
+	}
+}
+
+func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	if err != nil {
@@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Unlock()
 }
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}

+ 1 - 1
lighthouse_test.go

@@ -372,7 +372,7 @@ type testEncWriter struct {
 	metaFilter *NebulaMeta_MessageType
 }
 
-func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
+func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 }
 func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
 }

+ 14 - 1
main.go

@@ -151,8 +151,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	port := c.GetInt("listen.port", 0)
 
 	if !configTest {
+		rawListenHost := c.GetString("listen.host", "0.0.0.0")
+		var listenHost *net.IPAddr
+		if rawListenHost == "[::]" {
+			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
+			listenHost = &net.IPAddr{IP: net.IPv6zero}
+
+		} else {
+			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+			if err != nil {
+				return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
+			}
+		}
+
 		for i := 0; i < routines; i++ {
-			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
+			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}

+ 80 - 0
noiseutil/boring.go

@@ -0,0 +1,80 @@
+//go:build boringcrypto
+// +build boringcrypto
+
+package noiseutil
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"encoding/binary"
+
+	// unsafe needed for go:linkname
+	_ "unsafe"
+
+	"github.com/flynn/noise"
+)
+
+// EncryptLockNeeded indicates if calls to Encrypt need a lock
+// This is true for boringcrypto because the Seal function verifies that the
+// nonce is strictly increasing.
+const EncryptLockNeeded = true
+
+// NewGCMTLS is no longer exposed in go1.19+, so we need to link it in
+// See: https://github.com/golang/go/issues/56326
+//
+// NewGCMTLS is the internal method used with boringcrypto that provices a
+// validated mode of AES-GCM which enforces the nonce is strictly
+// monotonically increasing.  This is the TLS 1.2 specification for nonce
+// generation (which also matches the method used by the Noise Protocol)
+//
+// - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522
+// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237
+// - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250
+// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381
+// - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093
+//
+//go:linkname newGCMTLS crypto/internal/boring.NewGCMTLS
+func newGCMTLS(c cipher.Block) (cipher.AEAD, error)
+
+type cipherFn struct {
+	fn   func([32]byte) noise.Cipher
+	name string
+}
+
+func (c cipherFn) Cipher(k [32]byte) noise.Cipher { return c.fn(k) }
+func (c cipherFn) CipherName() string             { return c.name }
+
+// CipherAESGCM is the AES256-GCM AEAD cipher (using NewGCMTLS when GoBoring is present)
+var CipherAESGCM noise.CipherFunc = cipherFn{cipherAESGCMBoring, "AESGCM"}
+
+func cipherAESGCMBoring(k [32]byte) noise.Cipher {
+	c, err := aes.NewCipher(k[:])
+	if err != nil {
+		panic(err)
+	}
+	gcm, err := newGCMTLS(c)
+	if err != nil {
+		panic(err)
+	}
+	return aeadCipher{
+		gcm,
+		func(n uint64) []byte {
+			var nonce [12]byte
+			binary.BigEndian.PutUint64(nonce[4:], n)
+			return nonce[:]
+		},
+	}
+}
+
+type aeadCipher struct {
+	cipher.AEAD
+	nonce func(uint64) []byte
+}
+
+func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte {
+	return c.Seal(out, c.nonce(n), plaintext, ad)
+}
+
+func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) {
+	return c.Open(out, c.nonce(n), ciphertext, ad)
+}

+ 39 - 0
noiseutil/boring_test.go

@@ -0,0 +1,39 @@
+//go:build boringcrypto
+// +build boringcrypto
+
+package noiseutil
+
+import (
+	"encoding/hex"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// Ensure NewGCMTLS validates the nonce is non-repeating
+func TestNewGCMTLS(t *testing.T) {
+	// Test Case 16 from GCM Spec:
+	//  - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf
+	//  - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418
+	key, _ := hex.DecodeString("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308")
+	iv, _ := hex.DecodeString("cafebabefacedbaddecaf888")
+	plaintext, _ := hex.DecodeString("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39")
+	aad, _ := hex.DecodeString("feedfacedeadbeeffeedfacedeadbeefabaddad2")
+	expected, _ := hex.DecodeString("522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662")
+	expectedTag, _ := hex.DecodeString("76fc6ece0f4e1768cddf8853bb2d551b")
+
+	expected = append(expected, expectedTag...)
+
+	var keyArray [32]byte
+	copy(keyArray[:], key)
+	c := CipherAESGCM.Cipher(keyArray)
+	aead := c.(aeadCipher).AEAD
+
+	dst := aead.Seal([]byte{}, iv, plaintext, aad)
+	assert.Equal(t, expected, dst)
+
+	// We expect this to fail since we are re-encrypting with a repeat IV
+	assert.PanicsWithError(t, "boringcrypto: EVP_AEAD_CTX_seal failed", func() {
+		dst = aead.Seal([]byte{}, iv, plaintext, aad)
+	})
+}

+ 14 - 0
noiseutil/notboring.go

@@ -0,0 +1,14 @@
+//go:build !boringcrypto
+// +build !boringcrypto
+
+package noiseutil
+
+import (
+	"github.com/flynn/noise"
+)
+
+// EncryptLockNeeded indicates if calls to Encrypt need a lock
+const EncryptLockNeeded = false
+
+// CipherAESGCM is the standard noise.CipherAESGCM when boringcrypto is not enabled
+var CipherAESGCM noise.CipherFunc = noise.CipherAESGCM

+ 15 - 0
noiseutil/notboring_test.go

@@ -0,0 +1,15 @@
+//go:build !boringcrypto
+// +build !boringcrypto
+
+package noiseutil
+
+import (
+	// NOTE: We have to force these imports here or boring_test.go fails to
+	// compile correctly. This seems to be a Go bug:
+	//
+	//     $ GOEXPERIMENT=boringcrypto go test ./noiseutil
+	//     # github.com/slackhq/nebula/noiseutil
+	//     boring_test.go:10:2: cannot find package
+
+	_ "github.com/stretchr/testify/assert"
+)

+ 18 - 2
outside.go

@@ -21,7 +21,23 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func readOutsidePackets(f *Interface) udp.EncReader {
+	return func(
+		addr *udp.Addr,
+		out []byte,
+		packet []byte,
+		header *header.H,
+		fwPacket *firewall.Packet,
+		lhh udp.LightHouseHandlerFunc,
+		nb []byte,
+		q int,
+		localCache firewall.ConntrackCache,
+	) {
+		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
+	}
+}
+
+func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// TODO: best if we return this and let caller log
@@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
 			return
 		}
 
-		lhf(addr, hostinfo.vpnIp, d, f)
+		lhf(addr, hostinfo.vpnIp, d)
 
 		// Fallthrough to the bottom to record incoming traffic
 

+ 20 - 8
overlay/route.go

@@ -14,10 +14,11 @@ import (
 )
 
 type Route struct {
-	MTU    int
-	Metric int
-	Cidr   *net.IPNet
-	Via    *iputil.VpnIp
+	MTU     int
+	Metric  int
+	Cidr    *net.IPNet
+	Via     *iputil.VpnIp
+	Install bool
 }
 
 func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
@@ -81,7 +82,8 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 		}
 
 		r := Route{
-			MTU: mtu,
+			Install: true,
+			MTU:     mtu,
 		}
 
 		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
@@ -182,10 +184,20 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 
 		viaVpnIp := iputil.Ip2VpnIp(nVia)
 
+		install := true
+		rInstall, ok := m["install"]
+		if ok {
+			install, err = strconv.ParseBool(fmt.Sprintf("%v", rInstall))
+			if err != nil {
+				return nil, fmt.Errorf("entry %v.install in tun.unsafe_routes is not a boolean: %v", i+1, err)
+			}
+		}
+
 		r := Route{
-			Via:    &viaVpnIp,
-			MTU:    mtu,
-			Metric: metric,
+			Via:     &viaVpnIp,
+			MTU:     mtu,
+			Metric:  metric,
+			Install: install,
 		}
 
 		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))

+ 17 - 5
overlay/route_test.go

@@ -92,6 +92,8 @@ func Test_parseRoutes(t *testing.T) {
 
 	tested := 0
 	for _, r := range routes {
+		assert.True(t, r.Install)
+
 		if r.MTU == 8000 {
 			assert.Equal(t, "10.0.0.1/32", r.Cidr.String())
 			tested++
@@ -205,35 +207,45 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
+	// bad install
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
+	routes, err = parseUnsafeRoutes(c, n)
+	assert.Nil(t, routes)
+	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
+
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29"},
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
+		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
+		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
+		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, err)
-	assert.Len(t, routes, 3)
+	assert.Len(t, routes, 4)
 
 	tested := 0
 	for _, r := range routes {
 		if r.MTU == 8000 {
 			assert.Equal(t, "1.0.0.1/32", r.Cidr.String())
+			assert.False(t, r.Install)
 			tested++
 		} else if r.MTU == 9000 {
 			assert.Equal(t, 9000, r.MTU)
 			assert.Equal(t, "1.0.0.0/29", r.Cidr.String())
+			assert.True(t, r.Install)
 			tested++
 		} else {
 			assert.Equal(t, 1500, r.MTU)
 			assert.Equal(t, 1234, r.Metric)
 			assert.Equal(t, "1.0.0.2/32", r.Cidr.String())
+			assert.True(t, r.Install)
 			tested++
 		}
 	}
 
-	if tested != 3 {
-		t.Fatal("Did not see both unsafe_routes")
+	if tested != 4 {
+		t.Fatal("Did not see all unsafe_routes")
 	}
 }
 

+ 1 - 1
overlay/tun_darwin.go

@@ -287,7 +287,7 @@ func (t *tun) Activate() error {
 
 	// Unsafe path routes
 	for _, r := range t.Routes {
-		if r.Via == nil {
+		if r.Via == nil || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 1 - 1
overlay/tun_freebsd.go

@@ -86,7 +86,7 @@ func (t *tun) Activate() error {
 	}
 	// Unsafe path routes
 	for _, r := range t.Routes {
-		if r.Via == nil {
+		if r.Via == nil || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 4 - 0
overlay/tun_linux.go

@@ -279,6 +279,10 @@ func (t tun) Activate() error {
 
 	// Path routes
 	for _, r := range t.Routes {
+		if !r.Install {
+			continue
+		}
+
 		nr := netlink.Route{
 			LinkIndex: link.Attrs().Index,
 			Dst:       r.Cidr,

+ 1 - 1
overlay/tun_water_windows.go

@@ -80,7 +80,7 @@ func (t *waterTun) Activate() error {
 	}
 
 	for _, r := range t.Routes {
-		if r.Via == nil {
+		if r.Via == nil || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 1 - 1
overlay/tun_wintun_windows.go

@@ -92,7 +92,7 @@ func (t *winTun) Activate() error {
 	routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
 
 	for _, r := range t.Routes {
-		if r.Via == nil {
+		if r.Via == nil || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 0 - 1
udp/conn.go

@@ -9,7 +9,6 @@ const MTU = 9001
 
 type EncReader func(
 	addr *Addr,
-	via interface{},
 	out []byte,
 	packet []byte,
 	header *header.H,

+ 1 - 14
udp/temp.go

@@ -1,22 +1,9 @@
 package udp
 
 import (
-	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 )
 
-type EncWriter interface {
-	SendVia(via interface{},
-		relay interface{},
-		ad,
-		nb,
-		out []byte,
-		nocopy bool,
-	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
-	Handshake(vpnIp iputil.VpnIp)
-}
-
 //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
 
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
+type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)

+ 3 - 3
udp/udp_generic.go

@@ -23,9 +23,9 @@ type Conn struct {
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
 	lc := NewListenConfig(multi)
-	pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
+	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	if err != nil {
 		return nil, err
 	}
@@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 
 		udpAddr.IP = rua.IP
 		udpAddr.Port = uint16(rua.Port)
-		r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }

+ 3 - 3
udp/udp_linux.go

@@ -45,7 +45,7 @@ const (
 
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
 	syscall.ForkLock.RLock()
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	if err == nil {
@@ -59,7 +59,7 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (
 	}
 
 	var lip [16]byte
-	copy(lip[:], net.ParseIP(ip))
+	copy(lip[:], ip.To16())
 
 	if multi {
 		if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
@@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 		for i := 0; i < n; i++ {
 			udpAddr.IP = names[i][8:24]
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 		}
 	}
 }

+ 3 - 3
udp/udp_tester.go

@@ -45,9 +45,9 @@ type Conn struct {
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
+func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) {
 	return &Conn{
-		Addr:      &Addr{net.ParseIP(ip), uint16(port)},
+		Addr:      &Addr{ip, uint16(port)},
 		RxPackets: make(chan *Packet, 10),
 		TxPackets: make(chan *Packet, 10),
 		l:         l,
@@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 		}
 		ua.Port = p.FromPort
 		copy(ua.IP, p.FromIp.To16())
-		r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }