Browse Source

Add tests for the successful psk mode matrix

Nate Brown 4 years ago
parent
commit
c1ed78ffc7
2 changed files with 138 additions and 9 deletions
  1. 117 7
      e2e/handshakes_test.go
  2. 21 2
      e2e/helpers_test.go

+ 117 - 7
e2e/handshakes_test.go

@@ -18,8 +18,8 @@ import (
 
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
@@ -70,9 +70,9 @@ func TestWrongResponderHandshake(t *testing.T) {
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
-	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
+	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 
 	// Add their real udp addr, which should be tried after evil.
 	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
@@ -130,8 +130,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 func Test_Case1_Stage1Race(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse and vice versa
 	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
@@ -183,3 +183,113 @@ func Test_Case1_Stage1Race(t *testing.T) {
 }
 
 //TODO: add a test with many lies
+
+func TestPSK(t *testing.T) {
+	tests := []struct {
+		name         string
+		myPskMode    nebula.PskMode
+		theirPskMode nebula.PskMode
+	}{
+		{
+			name:         "none to transitional",
+			myPskMode:    nebula.PskNone,
+			theirPskMode: nebula.PskTransitional,
+		},
+		{
+			name:         "transitional to none",
+			myPskMode:    nebula.PskTransitional,
+			theirPskMode: nebula.PskNone,
+		},
+		{
+			name:         "both transitional",
+			myPskMode:    nebula.PskTransitional,
+			theirPskMode: nebula.PskTransitional,
+		},
+
+		{
+			name:         "enforced to transitional",
+			myPskMode:    nebula.PskEnforced,
+			theirPskMode: nebula.PskTransitional,
+		},
+		{
+			name:         "transitional to enforced",
+			myPskMode:    nebula.PskTransitional,
+			theirPskMode: nebula.PskEnforced,
+		},
+		{
+			name:         "both enforced",
+			myPskMode:    nebula.PskEnforced,
+			theirPskMode: nebula.PskEnforced,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var myPskSettings, theirPskSettings *m
+
+			switch test.myPskMode {
+			case nebula.PskNone:
+				myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}}
+			case nebula.PskTransitional:
+				myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional", "keys": []string{"this is a key"}}}}
+			case nebula.PskEnforced:
+				myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}}
+			}
+
+			switch test.theirPskMode {
+			case nebula.PskNone:
+				theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}}
+			case nebula.PskTransitional:
+				theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional", "keys": []string{"this is a key"}}}}
+			case nebula.PskEnforced:
+				theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}}
+			}
+
+			ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+			myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, myPskSettings)
+			theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, theirPskSettings)
+
+			myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+			r := router.NewR(myControl, theirControl)
+
+			// Start the servers
+			myControl.Start()
+			theirControl.Start()
+
+			t.Log("Route until we see our cached packet flow")
+			myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+			r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
+				h := &header.H{}
+				err := h.Parse(p.Data)
+				if err != nil {
+					panic(err)
+				}
+
+				// If this is the stage 1 handshake packet and I am configured to enforce psk, my cert name should not appear.
+				// It would likely be more obvious to unmarshal the payload
+				if test.myPskMode == nebula.PskEnforced && h.Type == 0 && h.MessageCounter == 1 {
+					assert.NotContains(t, string(p.Data), "test me")
+				}
+
+				if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+					return router.RouteAndExit
+				}
+
+				return router.KeepRouting
+			})
+
+			t.Log("My cached packet should be received by them")
+			myCachedPacket := theirControl.GetFromTun(true)
+			assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+
+			t.Log("Test the tunnel with them")
+			assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
+			assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+
+			myControl.Stop()
+			theirControl.Stop()
+			//TODO: assert hostmaps
+		})
+	}
+
+}

+ 21 - 2
e2e/helpers_test.go

@@ -15,6 +15,7 @@ import (
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
+	"github.com/imdario/mergo"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
@@ -30,7 +31,7 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP) (*nebula.Control, net.IP, *net.UDPAddr) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, customConfig *m) (*nebula.Control, net.IP, *net.UDPAddr) {
 	l := NewTestLogger()
 
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@@ -40,7 +41,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		IP:   udpIp,
 		Port: 4242,
 	}
-	_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, "test "+name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
@@ -86,6 +87,24 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 	c := config.NewC(l)
 	c.LoadString(string(cb))
 
+	if customConfig != nil {
+		ccb, err := yaml.Marshal(customConfig)
+		if err != nil {
+			panic(err)
+		}
+
+		ccm := map[interface{}]interface{}{}
+		err = yaml.Unmarshal(ccb, &ccm)
+		if err != nil {
+			panic(err)
+		}
+
+		err = mergo.Merge(&c.Settings, ccm, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+	}
+
 	control, err := nebula.Main(c, false, "e2e-test", l, nil)
 
 	if err != nil {