Explorar el Código

PSK support for v2

Nate Brown hace 8 meses
padre
commit
2c9cc63c1a
Se han modificado 9 ficheros con 451 adiciones y 33 borrados
  1. 7 8
      connection_state.go
  2. 132 0
      e2e/handshakes_test.go
  3. 3 4
      e2e/router/router.go
  4. 32 1
      examples/config.yml
  5. 39 20
      handshake_ix.go
  6. 5 0
      handshake_manager_test.go
  7. 12 0
      pki.go
  8. 150 0
      psk.go
  9. 71 0
      psk_test.go

+ 7 - 8
connection_state.go

@@ -27,7 +27,7 @@ type ConnectionState struct {
 	writeLock      sync.Mutex
 	writeLock      sync.Mutex
 }
 }
 
 
-func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
 	var dhFunc noise.DHFunc
 	switch crt.Curve() {
 	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 	case cert.Curve_CURVE25519:
@@ -56,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
 	b.Update(l, 0)
 	b.Update(l, 0)
 
 
 	hs, err := noise.NewHandshakeState(noise.Config{
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:   ncs,
-		Random:        rand.Reader,
-		Pattern:       pattern,
-		Initiator:     initiator,
-		StaticKeypair: static,
-		//NOTE: These should come from CertState (pki.go) when we finally implement it
-		PresharedKey:          []byte{},
+		CipherSuite:           ncs,
+		Random:                rand.Reader,
+		Pattern:               pattern,
+		Initiator:             initiator,
+		StaticKeypair:         static,
+		PresharedKey:          psk,
 		PresharedKeyPlacement: 0,
 		PresharedKeyPlacement: 0,
 	})
 	})
 	if err != nil {
 	if err != nil {

+ 132 - 0
e2e/handshakes_test.go

@@ -1224,3 +1224,135 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()
 }
 }
+
+func TestPSK(t *testing.T) {
+	tests := []struct {
+		name         string
+		myPskMode    nebula.PskMode
+		theirPskMode nebula.PskMode
+	}{
+		// All accepting
+		{
+			name:         "both accepting",
+			myPskMode:    nebula.PskAccepting,
+			theirPskMode: nebula.PskAccepting,
+		},
+
+		// accepting and sending both ways
+		{
+			name:         "accepting to sending",
+			myPskMode:    nebula.PskAccepting,
+			theirPskMode: nebula.PskSending,
+		},
+		{
+			name:         "sending to accepting",
+			myPskMode:    nebula.PskSending,
+			theirPskMode: nebula.PskAccepting,
+		},
+
+		// All sending
+		{
+			name:         "sending to sending",
+			myPskMode:    nebula.PskSending,
+			theirPskMode: nebula.PskSending,
+		},
+
+		// enforced and sending both ways
+		{
+			name:         "enforced to sending",
+			myPskMode:    nebula.PskEnforced,
+			theirPskMode: nebula.PskSending,
+		},
+		{
+			name:         "sending to enforced",
+			myPskMode:    nebula.PskSending,
+			theirPskMode: nebula.PskEnforced,
+		},
+
+		// All enforced
+		{
+			name:         "both enforced",
+			myPskMode:    nebula.PskEnforced,
+			theirPskMode: nebula.PskEnforced,
+		},
+
+		// Enforced can technically handshake with an accepting node, but it is bad to be in this state
+		{
+			name:         "enforced to accepting",
+			myPskMode:    nebula.PskEnforced,
+			theirPskMode: nebula.PskAccepting,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var myPskSettings, theirPskSettings m
+
+			switch test.myPskMode {
+			case nebula.PskAccepting:
+				myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}}
+			case nebula.PskSending:
+				myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}}
+			case nebula.PskEnforced:
+				myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}}
+			}
+
+			switch test.theirPskMode {
+			case nebula.PskAccepting:
+				theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}}
+			case nebula.PskSending:
+				theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}}
+			case nebula.PskEnforced:
+				theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}}
+			}
+
+			ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+			myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings)
+			theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings)
+
+			myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr)
+			r := router.NewR(t, myControl, theirControl)
+
+			// Start the servers
+			myControl.Start()
+			theirControl.Start()
+
+			t.Log("Route until we see our cached packet flow")
+			myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me"))
+			r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
+				h := &header.H{}
+				err := h.Parse(p.Data)
+				if err != nil {
+					panic(err)
+				}
+
+				// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
+				// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
+				if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending {
+					if h.Type == 0 && h.MessageCounter == 1 {
+						assert.NotContains(t, string(p.Data), "test me")
+					}
+				}
+
+				if p.To == theirUdpAddr && h.Type == 1 {
+					return router.RouteAndExit
+				}
+
+				return router.KeepRouting
+			})
+
+			t.Log("My cached packet should be received by them")
+			myCachedPacket := theirControl.GetFromTun(true)
+			assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80)
+
+			t.Log("Test the tunnel with them")
+			assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
+			assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r)
+
+			myControl.Stop()
+			theirControl.Stop()
+			//TODO: assert hostmaps
+		})
+	}
+
+}

+ 3 - 4
e2e/router/router.go

@@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
 func NewR(t testing.TB, controls ...*nebula.Control) *R {
 func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	ctx, cancel := context.WithCancel(context.Background())
 	ctx, cancel := context.WithCancel(context.Background())
 
 
-	if err := os.MkdirAll("mermaid", 0755); err != nil {
-		panic(err)
-	}
-
 	r := &R{
 	r := &R{
 		controls:     make(map[netip.AddrPort]*nebula.Control),
 		controls:     make(map[netip.AddrPort]*nebula.Control),
 		vpnControls:  make(map[netip.Addr]*nebula.Control),
 		vpnControls:  make(map[netip.Addr]*nebula.Control),
@@ -194,6 +190,9 @@ func (r *R) renderFlow() {
 		return
 		return
 	}
 	}
 
 
+	if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil {
+		panic(err)
+	}
 	f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
 	f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)

+ 32 - 1
examples/config.yml

@@ -19,6 +19,38 @@ pki:
   # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
   # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
   # default_version: 1
   # default_version: 1
 
 
+  # psk can be used to mask the contents of handshakes.
+  psk:
+    # `mode` defines how the pre shared keys can be used in a handshake.
+    # `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when
+    # receiving handshakes, including an empty key.
+    # `sending` will initiate handshakes with the first key provided and will try to use any keys provided when
+    # receiving handshakes, including an empty key.
+    # `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when
+    # responding to handshakes. An empty key will not be allowed.
+    #
+    # To change a mesh from not using a psk to enforcing psk:
+    # 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload.
+    # 2. Change `mode` to `sending` on all nodes in the mesh and reload.
+    # 3. Change `mode` to `enforced` on all nodes in the mesh and reload.
+    #mode: accepting
+
+    # The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the
+    # correct byte length.
+    #
+    # Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
+    # incoming handshakes. This is to allow for psk rotation.
+    #
+    # To rotate a primary key:
+    # 1. Put the new key in the 2nd slot on every node in the mesh and reload.
+    # 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload.
+    # 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload.
+    #keys:
+    # - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes
+    # - this is a fallback key, received handshakes can use this
+    # - another fallback, received handshakes can use this one too
+    # - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
 # The syntax is:
@@ -313,7 +345,6 @@ logging:
   # after receiving the response for lighthouse queries
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
   #trigger_buffer: 64
 
 
-
 # Nebula security group configuration
 # Nebula security group configuration
 firewall:
 firewall:
   # Action to take when a packet is not allowed by the firewall rules.
   # Action to take when a packet is not allowed by the firewall rules.

+ 39 - 20
handshake_ix.go

@@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
 	}
 	}
 
 
-	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
@@ -104,34 +104,53 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			Error("Unable to handshake with host because no certificate is available")
 			Error("Unable to handshake with host because no certificate is available")
 	}
 	}
 
 
-	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
-	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			Error("Failed to create connection state")
-		return
-	}
+	var (
+		err error
+		ci  *ConnectionState
+		msg []byte
+	)
 
 
-	// Mark packet 1 as seen so it doesn't show up as missed
-	ci.window.Update(f.l, 1)
+	hs := &NebulaHandshake{}
 
 
-	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
-	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			Error("Failed to call noise.ReadMessage")
-		return
+	for _, psk := range cs.psk.keys {
+		ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk)
+		if err != nil {
+			//TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose.
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+				Error("Failed to create connection state")
+			continue
+		}
+
+		msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
+		if err != nil {
+			// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
+			continue
+		}
+
+		// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
+		// comes out clean as well
+		err = hs.Unmarshal(msg)
+		if err == nil {
+			// There was no error, we can continue with this handshake
+			break
+		}
+
+		// The unmarshal failed, try the next psk if we have one
 	}
 	}
 
 
-	hs := &NebulaHandshake{}
-	err = hs.Unmarshal(msg)
+	// We finished with an error, log it and get out
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		// We aren't logging the error here because we can't be sure of the failure when using psk
+		f.l.WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			Error("Failed unmarshal handshake message")
+			Error("Was unable to decrypt the handshake")
 		return
 		return
 	}
 	}
 
 
+	// Mark packet 1 as seen so it doesn't show up as missed
+	ci.window.Update(f.l, 1)
+
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).

+ 5 - 0
handshake_manager_test.go

@@ -10,6 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 )
 
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
@@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
 
 
+	psk, err := NewPsk(PskAccepting, nil)
+	require.NoError(t, err)
+
 	cs := &CertState{
 	cs := &CertState{
 		defaultVersion:   cert.Version1,
 		defaultVersion:   cert.Version1,
 		privateKey:       []byte{},
 		privateKey:       []byte{},
 		v1Cert:           &dummyCert{version: cert.Version1},
 		v1Cert:           &dummyCert{version: cert.Version1},
 		v1HandshakeBytes: []byte{},
 		v1HandshakeBytes: []byte{},
+		psk:              psk,
 	}
 	}
 
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)

+ 12 - 0
pki.go

@@ -38,6 +38,8 @@ type CertState struct {
 	pkcs11Backed   bool
 	pkcs11Backed   bool
 	cipher         string
 	cipher         string
 
 
+	psk *Psk
+
 	myVpnNetworks            []netip.Prefix
 	myVpnNetworks            []netip.Prefix
 	myVpnNetworksTable       *bart.Table[struct{}]
 	myVpnNetworksTable       *bart.Table[struct{}]
 	myVpnAddrs               []netip.Addr
 	myVpnAddrs               []netip.Addr
@@ -171,6 +173,16 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 		}
 		}
 	}
 	}
 
 
+	psk, err := NewPskFromConfig(c)
+	if err != nil {
+		return util.NewContextualError("Failed to load psk from config", nil, err)
+	}
+	if len(psk.keys) > 0 {
+		p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)).
+			Info("pre shared keys are in use")
+	}
+	newState.psk = psk
+
 	p.cs.Store(newState)
 	p.cs.Store(newState)
 
 
 	//TODO: CERT-V2 newState needs a stringer that does json
 	//TODO: CERT-V2 newState needs a stringer that does json

+ 150 - 0
psk.go

@@ -0,0 +1,150 @@
+package nebula
+
+import (
+	"crypto/sha256"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
+	"golang.org/x/crypto/hkdf"
+)
+
+var ErrNotAPskMode = errors.New("not a psk mode")
+var ErrKeyTooShort = errors.New("key is too short")
+var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
+
+// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary
+const MinPskLength = 8
+
+type PskMode int
+
+const (
+	PskAccepting PskMode = 0
+	PskSending   PskMode = 1
+	PskEnforced  PskMode = 2
+)
+
+func NewPskMode(m string) (PskMode, error) {
+	switch m {
+	case "accepting":
+		return PskAccepting, nil
+	case "sending":
+		return PskSending, nil
+	case "enforced":
+		return PskEnforced, nil
+	}
+	return PskAccepting, ErrNotAPskMode
+}
+
+func (p PskMode) String() string {
+	switch p {
+	case PskAccepting:
+		return "accepting"
+	case PskSending:
+		return "sending"
+	case PskEnforced:
+		return "enforced"
+	}
+
+	return "unknown"
+}
+
+func (p PskMode) IsValid() bool {
+	switch p {
+	case PskAccepting, PskSending, PskEnforced:
+		return true
+	default:
+		return false
+	}
+}
+
+type Psk struct {
+	// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
+	mode PskMode
+
+	// primary is the key to use when sending, it may be nil
+	primary []byte
+
+	// keys holds all pre-computed psk hkdfs
+	// Handshakes iterate this directly
+	keys [][]byte
+}
+
+// NewPskFromConfig is a helper for initial boot and config reloading.
+func NewPskFromConfig(c *config.C) (*Psk, error) {
+	sMode := c.GetString("psk.mode", "accepting")
+	mode, err := NewPskMode(sMode)
+	if err != nil {
+		return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err)
+	}
+
+	return NewPsk(
+		mode,
+		c.GetStringSlice("psk.keys", nil),
+	)
+}
+
+// NewPsk creates a new Psk object and handles the caching of all accepted keys
+func NewPsk(mode PskMode, keys []string) (*Psk, error) {
+	if !mode.IsValid() {
+		return nil, ErrNotAPskMode
+	}
+
+	psk := &Psk{
+		mode: mode,
+	}
+
+	err := psk.cachePsks(keys)
+	if err != nil {
+		return nil, err
+	}
+
+	return psk, nil
+}
+
+// cachePsks generates all psks we accept and caches them to speed up handshaking
+func (p *Psk) cachePsks(keys []string) error {
+	if p.mode != PskAccepting && len(keys) < 1 {
+		return ErrNotEnoughPskKeys
+	}
+
+	p.keys = [][]byte{}
+
+	for i, rk := range keys {
+		k, err := sha256KdfFromString(rk)
+		if err != nil {
+			return fmt.Errorf("failed to generate key for position %v: %w", i, err)
+		}
+
+		p.keys = append(p.keys, k)
+	}
+
+	if p.mode != PskAccepting {
+		// We are either sending or enforcing, the primary key must the first slot
+		p.primary = p.keys[0]
+	}
+
+	if p.mode != PskEnforced {
+		// If we are not enforcing psk use then a nil psk is allowed
+		p.keys = append(p.keys, nil)
+	}
+
+	return nil
+}
+
+// sha256KdfFromString generates a useful key to use from a provided secret
+func sha256KdfFromString(secret string) ([]byte, error) {
+	if len(secret) < MinPskLength {
+		return nil, ErrKeyTooShort
+	}
+
+	hmacKey := make([]byte, sha256.Size)
+	_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey)
+	if err != nil {
+		return nil, err
+	}
+
+	return hmacKey, nil
+}

+ 71 - 0
psk_test.go

@@ -0,0 +1,71 @@
+package nebula
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewPsk(t *testing.T) {
+	t.Run("mode accepting", func(t *testing.T) {
+		p, err := NewPsk(PskAccepting, nil)
+		assert.NoError(t, err)
+		assert.Equal(t, PskAccepting, p.mode)
+		assert.Nil(t, p.keys[0])
+		assert.Nil(t, p.primary)
+
+		p, err = NewPsk(PskAccepting, []string{"1234567"})
+		assert.Error(t, ErrKeyTooShort)
+
+		p, err = NewPsk(PskAccepting, []string{"hi there friends"})
+		assert.NoError(t, err)
+		assert.Equal(t, PskAccepting, p.mode)
+		assert.Nil(t, p.primary)
+		assert.Len(t, p.keys, 2)
+		assert.Nil(t, p.keys[1])
+
+		expectedCache := []byte{
+			0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
+			0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
+		}
+		assert.Equal(t, expectedCache, p.keys[0])
+	})
+
+	t.Run("mode sending", func(t *testing.T) {
+		p, err := NewPsk(PskSending, nil)
+		assert.Error(t, ErrNotEnoughPskKeys, err)
+
+		p, err = NewPsk(PskSending, []string{"1234567"})
+		assert.Error(t, ErrKeyTooShort)
+
+		p, err = NewPsk(PskSending, []string{"hi there friends"})
+		assert.NoError(t, err)
+		assert.Equal(t, PskSending, p.mode)
+		assert.Len(t, p.keys, 2)
+		assert.Nil(t, p.keys[1])
+
+		expectedCache := []byte{
+			0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
+			0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
+		}
+		assert.Equal(t, expectedCache, p.keys[0])
+		assert.Equal(t, p.keys[0], p.primary)
+	})
+
+	t.Run("mode enforced", func(t *testing.T) {
+		p, err := NewPsk(PskEnforced, nil)
+		assert.Error(t, ErrNotEnoughPskKeys, err)
+
+		p, err = NewPsk(PskEnforced, []string{"hi there friends"})
+		assert.NoError(t, err)
+		assert.Equal(t, PskEnforced, p.mode)
+		assert.Len(t, p.keys, 1)
+
+		expectedCache := []byte{
+			0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65,
+			0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef,
+		}
+		assert.Equal(t, expectedCache, p.keys[0])
+		assert.Equal(t, p.keys[0], p.primary)
+	})
+}