瀏覽代碼

More like a library (#279)

Nathan Brown 4 年之前
父節點
當前提交
68e3e84fdc
共有 19 個文件被更改,包括 608 次插入153 次删除
  1. 57 0
      cert/cert.go
  2. 34 20
      cert/cert_test.go
  3. 6 1
      cmd/nebula-service/main.go
  4. 10 10
      cmd/nebula-service/service.go
  5. 6 1
      cmd/nebula/main.go
  6. 169 0
      control.go
  7. 111 0
      control_test.go
  8. 8 2
      firewall.go
  9. 1 1
      go.mod
  10. 4 4
      go.sum
  11. 12 5
      interface.go
  12. 8 0
      logger.go
  13. 29 94
      main.go
  14. 2 2
      udp_android.go
  15. 2 2
      udp_freebsd.go
  16. 11 0
      udp_generic.go
  17. 6 9
      udp_linux.go
  18. 2 2
      udp_windows.go
  19. 130 0
      util/assert.go

+ 57 - 0
cert/cert.go

@@ -468,6 +468,63 @@ func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
 	return json.Marshal(jc)
 }
 
+//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
+//	r, err := nc.Marshal()
+//	if err != nil {
+//		//TODO
+//		return nil
+//	}
+//
+//	c, err := UnmarshalNebulaCertificate(r)
+//	return c
+//}
+
+func (nc *NebulaCertificate) Copy() *NebulaCertificate {
+	c := &NebulaCertificate{
+		Details: NebulaCertificateDetails{
+			Name:           nc.Details.Name,
+			Groups:         make([]string, len(nc.Details.Groups)),
+			Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
+			Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
+			NotBefore:      nc.Details.NotBefore,
+			NotAfter:       nc.Details.NotAfter,
+			PublicKey:      make([]byte, len(nc.Details.PublicKey)),
+			IsCA:           nc.Details.IsCA,
+			Issuer:         nc.Details.Issuer,
+			InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
+		},
+		Signature: make([]byte, len(nc.Signature)),
+	}
+
+	copy(c.Signature, nc.Signature)
+	copy(c.Details.Groups, nc.Details.Groups)
+	copy(c.Details.PublicKey, nc.Details.PublicKey)
+
+	for i, p := range nc.Details.Ips {
+		c.Details.Ips[i] = &net.IPNet{
+			IP:   make(net.IP, len(p.IP)),
+			Mask: make(net.IPMask, len(p.Mask)),
+		}
+		copy(c.Details.Ips[i].IP, p.IP)
+		copy(c.Details.Ips[i].Mask, p.Mask)
+	}
+
+	for i, p := range nc.Details.Subnets {
+		c.Details.Subnets[i] = &net.IPNet{
+			IP:   make(net.IP, len(p.IP)),
+			Mask: make(net.IPMask, len(p.Mask)),
+		}
+		copy(c.Details.Subnets[i].IP, p.IP)
+		copy(c.Details.Subnets[i].Mask, p.Mask)
+	}
+
+	for g := range nc.Details.InvertedGroups {
+		c.Details.InvertedGroups[g] = struct{}{}
+	}
+
+	return c
+}
+
 func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
 	for _, net := range rootIps {
 		if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {

+ 34 - 20
cert/cert_test.go

@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/ed25519"
@@ -487,6 +488,17 @@ func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
 	assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
 }
 
+func TestNebulaCertificate_Copy(t *testing.T) {
+	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+
+	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	assert.Nil(t, err)
+	cc := c.Copy()
+
+	util.AssertDeepCopyEqual(t, c, cc)
+}
+
 func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	if before.IsZero() {
@@ -498,11 +510,12 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 
 	nc := &NebulaCertificate{
 		Details: NebulaCertificateDetails{
-			Name:      "test ca",
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pub,
-			IsCA:      true,
+			Name:           "test ca",
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           true,
+			InvertedGroups: make(map[string]struct{}),
 		},
 	}
 
@@ -544,17 +557,17 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 
 	if len(ips) == 0 {
 		ips = []*net.IPNet{
-			{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-			{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
+			{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
+			{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
+			{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
 		}
 	}
 
 	if len(subnets) == 0 {
 		subnets = []*net.IPNet{
-			{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-			{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
+			{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
+			{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
+			{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
 		}
 	}
 
@@ -562,15 +575,16 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips
 
 	nc := &NebulaCertificate{
 		Details: NebulaCertificateDetails{
-			Name:      "testing",
-			Ips:       ips,
-			Subnets:   subnets,
-			Groups:    groups,
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pub,
-			IsCA:      false,
-			Issuer:    issuer,
+			Name:           "testing",
+			Ips:            ips,
+			Subnets:        subnets,
+			Groups:         groups,
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           false,
+			Issuer:         issuer,
+			InvertedGroups: make(map[string]struct{}),
 		},
 	}
 

+ 6 - 1
cmd/nebula-service/main.go

@@ -55,7 +55,7 @@ func main() {
 
 	l := logrus.New()
 	l.Out = os.Stdout
-	err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
+	c, err := nebula.Main(config, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
 	case nebula.ContextualError:
@@ -66,5 +66,10 @@ func main() {
 		os.Exit(1)
 	}
 
+	if !*configTest {
+		c.Start()
+		c.ShutdownBlock()
+	}
+
 	os.Exit(0)
 }

+ 10 - 10
cmd/nebula-service/service.go

@@ -14,21 +14,16 @@ import (
 var logger service.Logger
 
 type program struct {
-	exit       chan struct{}
 	configPath *string
 	configTest *bool
 	build      string
+	control    *nebula.Control
 }
 
 func (p *program) Start(s service.Service) error {
-	logger.Info("Nebula service starting.")
-	p.exit = make(chan struct{})
 	// Start should not block.
-	go p.run()
-	return nil
-}
+	logger.Info("Nebula service starting.")
 
-func (p *program) run() error {
 	config := nebula.NewConfig()
 	err := config.Load(*p.configPath)
 	if err != nil {
@@ -37,17 +32,22 @@ func (p *program) run() error {
 
 	l := logrus.New()
 	l.Out = os.Stdout
-	return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
+	p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
+	if err != nil {
+		return err
+	}
+
+	p.control.Start()
+	return nil
 }
 
 func (p *program) Stop(s service.Service) error {
 	logger.Info("Nebula service stopping.")
-	close(p.exit)
+	p.control.Stop()
 	return nil
 }
 
 func doService(configPath *string, configTest *bool, build string, serviceFlag *string) {
-
 	if *configPath == "" {
 		ex, err := os.Executable()
 		if err != nil {

+ 6 - 1
cmd/nebula/main.go

@@ -49,7 +49,7 @@ func main() {
 
 	l := logrus.New()
 	l.Out = os.Stdout
-	err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
+	c, err := nebula.Main(config, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
 	case nebula.ContextualError:
@@ -60,5 +60,10 @@ func main() {
 		os.Exit(1)
 	}
 
+	if !*configTest {
+		c.Start()
+		c.ShutdownBlock()
+	}
+
 	os.Exit(0)
 }

+ 169 - 0
control.go

@@ -0,0 +1,169 @@
+package nebula
+
+import (
+	"net"
+	"os"
+	"os/signal"
+	"syscall"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
+)
+
+// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
+// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
+
+type Control struct {
+	f *Interface
+	l *logrus.Logger
+}
+
+type ControlHostInfo struct {
+	VpnIP          net.IP                  `json:"vpnIp"`
+	LocalIndex     uint32                  `json:"localIndex"`
+	RemoteIndex    uint32                  `json:"remoteIndex"`
+	RemoteAddrs    []udpAddr               `json:"remoteAddrs"`
+	CachedPackets  int                     `json:"cachedPackets"`
+	Cert           *cert.NebulaCertificate `json:"cert"`
+	MessageCounter uint64                  `json:"messageCounter"`
+	CurrentRemote  udpAddr                 `json:"currentRemote"`
+}
+
+// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
+func (c *Control) Start() {
+	c.f.run()
+}
+
+// Stop signals nebula to shutdown, returns after the shutdown is complete
+func (c *Control) Stop() {
+	//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
+	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
+	c.f.hostMap.Lock()
+	for _, h := range c.f.hostMap.Hosts {
+		if h.ConnectionState.ready {
+			c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+			c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
+				Debug("Sending close tunnel message")
+		}
+	}
+	c.f.hostMap.Unlock()
+	c.l.Info("Goodbye")
+}
+
+// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
+func (c *Control) ShutdownBlock() {
+	sigChan := make(chan os.Signal)
+	signal.Notify(sigChan, syscall.SIGTERM)
+	signal.Notify(sigChan, syscall.SIGINT)
+
+	rawSig := <-sigChan
+	sig := rawSig.String()
+	c.l.WithField("signal", sig).Info("Caught signal, shutting down")
+	c.Stop()
+}
+
+// RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change
+func (c *Control) RebindUDPServer() {
+	_ = c.f.outside.Rebind()
+}
+
+// ListHostmap returns details about the actual or pending (handshaking) hostmap
+func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
+	var hm *HostMap
+	if pendingMap {
+		hm = c.f.handshakeManager.pendingHostMap
+	} else {
+		hm = c.f.hostMap
+	}
+
+	hm.RLock()
+	hosts := make([]ControlHostInfo, len(hm.Hosts))
+	i := 0
+	for _, v := range hm.Hosts {
+		hosts[i] = copyHostInfo(v)
+		i++
+	}
+	hm.RUnlock()
+
+	return hosts
+}
+
+// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
+func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
+	var hm *HostMap
+	if pending {
+		hm = c.f.handshakeManager.pendingHostMap
+	} else {
+		hm = c.f.hostMap
+	}
+
+	h, err := hm.QueryVpnIP(vpnIP)
+	if err != nil {
+		return nil
+	}
+
+	ch := copyHostInfo(h)
+	return &ch
+}
+
+// SetRemoteForTunnel forces a tunnel to use a specific remote
+func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
+	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+	if err != nil {
+		return nil
+	}
+
+	hostInfo.SetRemote(addr.Copy())
+	ch := copyHostInfo(hostInfo)
+	return &ch
+}
+
+// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
+func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
+	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+	if err != nil {
+		return false
+	}
+
+	if !localOnly {
+		c.f.send(
+			closeTunnel,
+			0,
+			hostInfo.ConnectionState,
+			hostInfo,
+			hostInfo.remote,
+			[]byte{},
+			make([]byte, 12, 12),
+			make([]byte, mtu),
+		)
+	}
+
+	c.f.closeTunnel(hostInfo)
+	return true
+}
+
+func copyHostInfo(h *HostInfo) ControlHostInfo {
+	addrs := h.RemoteUDPAddrs()
+	chi := ControlHostInfo{
+		VpnIP:          int2ip(h.hostId),
+		LocalIndex:     h.localIndexId,
+		RemoteIndex:    h.remoteIndexId,
+		RemoteAddrs:    make([]udpAddr, len(addrs), len(addrs)),
+		CachedPackets:  len(h.packetStore),
+		MessageCounter: *h.ConnectionState.messageCounter,
+	}
+
+	if c := h.GetCert(); c != nil {
+		chi.Cert = c.Copy()
+	}
+
+	if h.remote != nil {
+		chi.CurrentRemote = *h.remote
+	}
+
+	for i, addr := range addrs {
+		chi.RemoteAddrs[i] = addr.Copy()
+	}
+
+	return chi
+}

+ 111 - 0
control_test.go

@@ -0,0 +1,111 @@
+package nebula
+
+import (
+	"net"
+	"reflect"
+	"testing"
+	"time"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/util"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestControl_GetHostInfoByVpnIP(t *testing.T) {
+	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
+	// To properly ensure we are not exposing core memory to the caller
+	hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
+	remote1 := NewUDPAddr(100, 4444)
+	remote2 := NewUDPAddr(101, 4444)
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	ipNet2 := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 5),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	crt := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "test",
+			Ips:            []*net.IPNet{&ipNet},
+			Subnets:        []*net.IPNet{},
+			Groups:         []string{"default-group"},
+			NotBefore:      time.Unix(1, 0),
+			NotAfter:       time.Unix(2, 0),
+			PublicKey:      []byte{5, 6, 7, 8},
+			IsCA:           false,
+			Issuer:         "the-issuer",
+			InvertedGroups: map[string]struct{}{"default-group": {}},
+		},
+		Signature: []byte{1, 2, 1, 2, 1, 3},
+	}
+	counter := uint64(0)
+
+	remotes := []*HostInfoDest{NewHostInfoDest(remote1), NewHostInfoDest(remote2)}
+	hm.Add(ip2int(ipNet.IP), &HostInfo{
+		remote:  remote1,
+		Remotes: remotes,
+		ConnectionState: &ConnectionState{
+			peerCert:       crt,
+			messageCounter: &counter,
+		},
+		remoteIndexId: 200,
+		localIndexId:  201,
+		hostId:        ip2int(ipNet.IP),
+	})
+
+	hm.Add(ip2int(ipNet2.IP), &HostInfo{
+		remote:  remote1,
+		Remotes: remotes,
+		ConnectionState: &ConnectionState{
+			peerCert:       nil,
+			messageCounter: &counter,
+		},
+		remoteIndexId: 200,
+		localIndexId:  201,
+		hostId:        ip2int(ipNet2.IP),
+	})
+
+	c := Control{
+		f: &Interface{
+			hostMap: hm,
+		},
+		l: logrus.New(),
+	}
+
+	thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
+
+	expectedInfo := ControlHostInfo{
+		VpnIP:          net.IPv4(1, 2, 3, 4).To4(),
+		LocalIndex:     201,
+		RemoteIndex:    200,
+		RemoteAddrs:    []udpAddr{*remote1, *remote2},
+		CachedPackets:  0,
+		Cert:           crt.Copy(),
+		MessageCounter: 0,
+		CurrentRemote:  *NewUDPAddr(100, 4444),
+	}
+
+	// Make sure we don't have any unexpected fields
+	assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
+	util.AssertDeepCopyEqual(t, &expectedInfo, thi)
+
+	// Make sure we don't panic if the host info doesn't have a cert yet
+	assert.NotPanics(t, func() {
+		thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
+	})
+}
+
+func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
+	val := reflect.ValueOf(actualStruct).Elem()
+	fields := make([]string, val.NumField())
+	for i := 0; i < val.NumField(); i++ {
+		fields[i] = val.Type().Field(i).Name
+	}
+
+	assert.Equal(t, expected, fields)
+}

+ 8 - 2
firewall.go

@@ -221,11 +221,17 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
 
 // AddRule properly creates the in memory rule structure for a firewall table.
 func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
+	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
+	// https://github.com/golang/go/issues/14131
+	sIp := ""
+	if ip != nil {
+		sIp = ip.String()
+	}
 
 	// We need this rule string because we generate a hash. Removing this will break firewall reload.
 	ruleString := fmt.Sprintf(
 		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
-		incoming, proto, startPort, endPort, groups, host, ip, caName, caSha,
+		incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
 	)
 	f.rules += ruleString + "\n"
 
@@ -233,7 +239,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	if !incoming {
 		direction = "outgoing"
 	}
-	l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": ip, "caName": caName, "caSha": caSha}).
+	l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
 		Info("Firewall rule added")
 
 	var (

+ 1 - 1
go.mod

@@ -22,7 +22,7 @@ require (
 	github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563
 	github.com/sirupsen/logrus v1.4.2
 	github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b
-	github.com/stretchr/testify v1.4.0
+	github.com/stretchr/testify v1.6.1
 	github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a
 	github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
 	golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975

+ 4 - 4
go.sum

@@ -103,8 +103,8 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a h1:Bt1IVPhiCDMqwGrc2nnbIN4QKvJGx6SK2NzWBmW00ao=
 github.com/vishvananda/netlink v1.0.1-0.20190522153524-00009fb8606a/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
 github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k=
@@ -112,8 +112,6 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
-golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
-golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
 golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -154,3 +152,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
 gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 12 - 5
interface.go

@@ -35,7 +35,10 @@ type InterfaceConfig struct {
 	DropLocalBroadcast      bool
 	DropMulticast           bool
 	UDPBatchSize            int
+	udpQueues               int
+	tunQueues               int
 	MessageMetrics          *MessageMetrics
+	version                 string
 }
 
 type Interface struct {
@@ -54,6 +57,8 @@ type Interface struct {
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	udpBatchSize       int
+	udpQueues          int
+	tunQueues          int
 	version            string
 
 	metricHandshakes metrics.Histogram
@@ -89,6 +94,9 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		udpBatchSize:       c.UDPBatchSize,
+		udpQueues:          c.udpQueues,
+		tunQueues:          c.tunQueues,
+		version:            c.version,
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
@@ -99,29 +107,28 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 	return ifce, nil
 }
 
-func (f *Interface) Run(tunRoutines, udpRoutines int, buildVersion string) {
+func (f *Interface) run() {
 	// actually turn on tun dev
 	if err := f.inside.Activate(); err != nil {
 		l.Fatal(err)
 	}
 
-	f.version = buildVersion
 	addr, err := f.outside.LocalAddr()
 	if err != nil {
 		l.WithError(err).Error("Failed to get udp listen address")
 	}
 
 	l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
-		WithField("build", buildVersion).WithField("udpAddr", addr).
+		WithField("build", f.version).WithField("udpAddr", addr).
 		Info("Nebula interface is active")
 
 	// Launch n queues to read packets from udp
-	for i := 0; i < udpRoutines; i++ {
+	for i := 0; i < f.udpQueues; i++ {
 		go f.listenOut(i)
 	}
 
 	// Launch n queues to read packets from tun dev
-	for i := 0; i < tunRoutines; i++ {
+	for i := 0; i < f.tunQueues; i++ {
 		go f.listenIn(i)
 	}
 }

+ 8 - 0
logger.go

@@ -1,6 +1,8 @@
 package nebula
 
 import (
+	"errors"
+
 	"github.com/sirupsen/logrus"
 )
 
@@ -15,10 +17,16 @@ func NewContextualError(msg string, fields map[string]interface{}, realError err
 }
 
 func (ce ContextualError) Error() string {
+	if ce.RealError == nil {
+		return ce.Context
+	}
 	return ce.RealError.Error()
 }
 
 func (ce ContextualError) Unwrap() error {
+	if ce.RealError == nil {
+		return errors.New(ce.Context)
+	}
 	return ce.RealError
 }
 

+ 29 - 94
main.go

@@ -4,11 +4,8 @@ import (
 	"encoding/binary"
 	"fmt"
 	"net"
-	"os"
-	"os/signal"
 	"strconv"
 	"strings"
-	"syscall"
 	"time"
 
 	"github.com/sirupsen/logrus"
@@ -21,12 +18,7 @@ var l = logrus.New()
 
 type m map[string]interface{}
 
-type CommandRequest struct {
-	Command  string
-	Callback chan error
-}
-
-func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
+func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
 	l = logger
 	l.Formatter = &logrus.TextFormatter{
 		FullTimestamp: true,
@@ -36,7 +28,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	if configTest {
 		b, err := yaml.Marshal(config.Settings)
 		if err != nil {
-			return err
+			return nil, err
 		}
 
 		// Print the final config
@@ -45,7 +37,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 
 	err := configLogger(config)
 	if err != nil {
-		return NewContextualError("Failed to configure the logger", nil, err)
+		return nil, NewContextualError("Failed to configure the logger", nil, err)
 	}
 
 	config.RegisterReloadCallback(func(c *Config) {
@@ -59,20 +51,20 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	trustedCAs, err = loadCAFromConfig(config)
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
-		return NewContextualError("Failed to load ca from config", nil, err)
+		return nil, NewContextualError("Failed to load ca from config", nil, err)
 	}
 	l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
 
 	cs, err := NewCertStateFromConfig(config)
 	if err != nil {
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
-		return NewContextualError("Failed to load certificate from config", nil, err)
+		return nil, NewContextualError("Failed to load certificate from config", nil, err)
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
 	fw, err := NewFirewallFromConfig(cs.certificate, config)
 	if err != nil {
-		return NewContextualError("Error while loading firewall rules", nil, err)
+		return nil, NewContextualError("Error while loading firewall rules", nil, err)
 	}
 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
 
@@ -80,11 +72,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	tunCidr := cs.certificate.Details.Ips[0]
 	routes, err := parseRoutes(config, tunCidr)
 	if err != nil {
-		return NewContextualError("Could not parse tun.routes", nil, err)
+		return nil, NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
 	if err != nil {
-		return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -92,7 +84,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	if config.GetBool("sshd.enabled", false) {
 		err = configSSH(ssh, config)
 		if err != nil {
-			return NewContextualError("Error while configuring the sshd", nil, err)
+			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 		}
 	}
 
@@ -129,7 +121,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 		}
 
 		if err != nil {
-			return NewContextualError("Failed to get a tun/tap device", nil, err)
+			return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
 		}
 	}
 
@@ -140,28 +132,11 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	if !configTest {
 		udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
 		if err != nil {
-			return NewContextualError("Failed to open udp listener", nil, err)
+			return nil, NewContextualError("Failed to open udp listener", nil, err)
 		}
 		udpServer.reloadConfig(config)
 	}
 
-	sigChan := make(chan os.Signal)
-	killChan := make(chan CommandRequest)
-	if commandChan != nil {
-		go func() {
-			cmd := CommandRequest{}
-			for {
-				cmd = <-commandChan
-				switch cmd.Command {
-				case "rebind":
-					udpServer.Rebind()
-				case "exit":
-					killChan <- cmd
-				}
-			}
-		}()
-	}
-
 	// Set up my internal host map
 	var preferredRanges []*net.IPNet
 	rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
@@ -170,7 +145,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 		for _, rawPreferredRange := range rawPreferredRanges {
 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
 			if err != nil {
-				return NewContextualError("Failed to parse preferred ranges", nil, err)
+				return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
 			}
 			preferredRanges = append(preferredRanges, preferredRange)
 		}
@@ -183,7 +158,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
-			return NewContextualError("Failed to parse local_range", nil, err)
+			return nil, NewContextualError("Failed to parse local_range", nil, err)
 		}
 
 		// Check if the entry for local_range was already specified in
@@ -223,7 +198,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	if port == 0 && !configTest {
 		uPort, err := udpServer.LocalAddr()
 		if err != nil {
-			return NewContextualError("Failed to get listening port", nil, err)
+			return nil, NewContextualError("Failed to get listening port", nil, err)
 		}
 		port = int(uPort.Port)
 	}
@@ -240,10 +215,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	for i, host := range rawLighthouseHosts {
 		ip := net.ParseIP(host)
 		if ip == nil {
-			return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+			return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
 		}
 		if !tunCidr.Contains(ip) {
-			return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 		}
 		lighthouseHosts[i] = ip2int(ip)
 	}
@@ -263,13 +238,13 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 
 	remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
 	if err != nil {
-		return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
+		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 	}
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 
 	localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
 	if err != nil {
-		return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
+		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 	}
 	lightHouse.SetLocalAllowList(localAllowList)
 
@@ -277,7 +252,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
 		vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
 		if !tunCidr.Contains(vpnIp) {
-			return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
+			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 		}
 		vals, ok := v.([]interface{})
 		if ok {
@@ -288,7 +263,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 					ip := addr.IP
 					port, err := strconv.Atoi(parts[1])
 					if err != nil {
-						return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
+						return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 					}
 					lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
 				}
@@ -301,7 +276,7 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 				ip := addr.IP
 				port, err := strconv.Atoi(parts[1])
 				if err != nil {
-					return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
+					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
 				lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
 			}
@@ -354,7 +329,10 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 		DropLocalBroadcast:      config.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:           config.GetBool("tun.drop_multicast", false),
 		UDPBatchSize:            config.GetInt("listen.batch", 64),
+		udpQueues:               udpQueues,
+		tunQueues:               config.GetInt("tun.routines", 1),
 		MessageMetrics:          messageMetrics,
+		version:                 buildVersion,
 	}
 
 	switch ifConfig.Cipher {
@@ -363,14 +341,14 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 	case "chachapoly":
 		noiseEndianness = binary.LittleEndian
 	default:
-		return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
+		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
 	}
 
 	var ifce *Interface
 	if !configTest {
 		ifce, err = NewInterface(ifConfig)
 		if err != nil {
-			return fmt.Errorf("failed to initialize interface: %s", err)
+			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 		}
 
 		ifce.RegisterConfigChangeCallbacks(config)
@@ -381,18 +359,17 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 
 	err = startStats(config, configTest)
 	if err != nil {
-		return NewContextualError("Failed to start stats emitter", nil, err)
+		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 	}
 
 	if configTest {
-		return nil
+		return nil, nil
 	}
 
 	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
 
 	attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
-	ifce.Run(config.GetInt("tun.routines", 1), udpQueues, buildVersion)
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
 	if amLighthouse && serveDns {
@@ -400,47 +377,5 @@ func Main(config *Config, configTest bool, block bool, buildVersion string, logg
 		go dnsMain(hostMap, config)
 	}
 
-	if block {
-		// Just sit here and be friendly, main thread.
-		shutdownBlock(ifce, sigChan, killChan)
-	} else {
-		// Even though we aren't blocking we still want to shutdown gracefully
-		go shutdownBlock(ifce, sigChan, killChan)
-	}
-	return nil
-}
-
-func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
-	var cmd CommandRequest
-	var sig string
-
-	signal.Notify(sigChan, syscall.SIGTERM)
-	signal.Notify(sigChan, syscall.SIGINT)
-
-	select {
-	case rawSig := <-sigChan:
-		sig = rawSig.String()
-	case cmd = <-killChan:
-		sig = "controlling app"
-	}
-
-	l.WithField("signal", sig).Info("Caught signal, shutting down")
-
-	//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
-	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	ifce.hostMap.Lock()
-	for _, h := range ifce.hostMap.Hosts {
-		if h.ConnectionState.ready {
-			ifce.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
-			l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
-				Debug("Sending close tunnel message")
-		}
-	}
-	ifce.hostMap.Unlock()
-
-	l.WithField("signal", sig).Info("Goodbye")
-	select {
-	case cmd.Callback <- nil:
-	default:
-	}
+	return &Control{ifce, l}, nil
 }

+ 2 - 2
udp_android.go

@@ -31,6 +31,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() {
-	return
+func (u *udpConn) Rebind() error {
+	return nil
 }

+ 2 - 2
udp_freebsd.go

@@ -33,6 +33,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() {
-	return
+func (u *udpConn) Rebind() error {
+	return nil
 }

+ 11 - 0
udp_generic.go

@@ -65,6 +65,17 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
 	return ua.IP.Equal(t.IP) && ua.Port == t.Port
 }
 
+func (ua *udpAddr) Copy() udpAddr {
+	nu := udpAddr{net.UDPAddr{
+		Port: ua.Port,
+		Zone: ua.Zone,
+		IP:   make(net.IP, len(ua.IP)),
+	}}
+
+	copy(nu.IP, ua.IP)
+	return nu
+}
+
 func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
 	_, err := uc.UDPConn.WriteToUDP(b, &addr.UDPAddr)
 	return err

+ 6 - 9
udp_linux.go

@@ -89,8 +89,12 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
 	return &udpConn{sysFd: fd}, err
 }
 
-func (u *udpConn) Rebind() {
-	return
+func (u *udpConn) Rebind() error {
+	return nil
+}
+
+func (ua *udpAddr) Copy() udpAddr {
+	return *ua
 }
 
 func (u *udpConn) SetRecvBuffer(n int) error {
@@ -282,13 +286,6 @@ func (ua *udpAddr) Equals(t *udpAddr) bool {
 	return ua.IP == t.IP && ua.Port == t.Port
 }
 
-func (ua *udpAddr) Copy() *udpAddr {
-	return &udpAddr{
-		Port: ua.Port,
-		IP:   ua.IP,
-	}
-}
-
 func (ua *udpAddr) String() string {
 	return fmt.Sprintf("%s:%v", int2ip(ua.IP), ua.Port)
 }

+ 2 - 2
udp_windows.go

@@ -21,6 +21,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() {
-	return
+func (u *udpConn) Rebind() error {
+	return nil
 }

+ 130 - 0
util/assert.go

@@ -0,0 +1,130 @@
+package util
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+	"time"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
+// There is currently a special case for `time.loc` (as this code traverses into unexported fields)
+func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
+	v1 := reflect.ValueOf(a)
+	v2 := reflect.ValueOf(b)
+
+	if !assert.Equal(t, v1.Type(), v2.Type()) {
+		return
+	}
+
+	traverseDeepCopy(t, v1, v2, v1.Type().String())
+}
+
+func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool {
+	switch v1.Kind() {
+	case reflect.Array:
+		for i := 0; i < v1.Len(); i++ {
+			if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
+				return false
+			}
+		}
+		return true
+
+	case reflect.Slice:
+		if v1.IsNil() || v2.IsNil() {
+			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2)
+		}
+
+		if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) {
+			return false
+		}
+
+		// A slice with cap 0
+		if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) {
+			return false
+		}
+
+		v1c := v1.Cap()
+		v2c := v2.Cap()
+		if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() {
+			return assert.Fail(t, "", "%s share some underlying memory", name)
+		}
+
+		for i := 0; i < v1.Len(); i++ {
+			if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) {
+				return false
+			}
+		}
+		return true
+
+	case reflect.Interface:
+		if v1.IsNil() || v2.IsNil() {
+			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
+		}
+		return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
+
+	case reflect.Ptr:
+		local := reflect.ValueOf(time.Local).Pointer()
+		if local == v1.Pointer() && local == v2.Pointer() {
+			return true
+		}
+
+		if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) {
+			return false
+		}
+
+		return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name)
+
+	case reflect.Struct:
+		for i, n := 0, v1.NumField(); i < n; i++ {
+			if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) {
+				return false
+			}
+		}
+		return true
+
+	case reflect.Map:
+		if v1.IsNil() || v2.IsNil() {
+			return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name)
+		}
+
+		if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) {
+			return false
+		}
+
+		if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) {
+			return false
+		}
+
+		for _, k := range v1.MapKeys() {
+			val1 := v1.MapIndex(k)
+			val2 := v2.MapIndex(k)
+			if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) {
+				return false
+			}
+
+			if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) {
+				return false
+			}
+
+			if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) {
+				return false
+			}
+		}
+
+		return true
+
+	default:
+		if v1.CanInterface() && v2.CanInterface() {
+			return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name)
+		}
+
+		e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface()
+		e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface()
+
+		return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name)
+	}
+}