Browse Source

Move util to test, contextual errors to util (#575)

Nate Brown 3 years ago
parent
commit
4453964e34

+ 2 - 2
allow_list_test.go

@@ -7,12 +7,12 @@ import (
 
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestNewAllowListFromConfig(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := config.NewC(l)
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"192.168.0.0": true,

+ 5 - 5
bits_test.go

@@ -3,12 +3,12 @@ package nebula
 import (
 	"testing"
 
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestBits(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	b := NewBits(10)
 
 	// make sure it is the right size
@@ -76,7 +76,7 @@ func TestBits(t *testing.T) {
 }
 
 func TestBitsDupeCounter(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
@@ -101,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
 }
 
 func TestBitsOutOfWindowCounter(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
@@ -131,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
 }
 
 func TestBitsLostCounter(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()

+ 2 - 2
cert/cert_test.go

@@ -9,7 +9,7 @@ import (
 	"time"
 
 	"github.com/golang/protobuf/proto"
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/ed25519"
@@ -752,7 +752,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
 	assert.Nil(t, err)
 	cc := c.Copy()
 
-	util.AssertDeepCopyEqual(t, c, cc)
+	test.AssertDeepCopyEqual(t, c, cc)
 }
 
 func TestUnmarshalNebulaCertificate(t *testing.T) {

+ 2 - 1
cmd/nebula-service/main.go

@@ -8,6 +8,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 )
 
 // A version string that can be set with
@@ -60,7 +61,7 @@ func main() {
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
-	case nebula.ContextualError:
+	case util.ContextualError:
 		v.Log(l)
 		os.Exit(1)
 	case error:

+ 2 - 1
cmd/nebula/main.go

@@ -8,6 +8,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 )
 
 // A version string that can be set with
@@ -54,7 +55,7 @@ func main() {
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
-	case nebula.ContextualError:
+	case util.ContextualError:
 		v.Log(l)
 		os.Exit(1)
 	case error:

+ 7 - 7
config/config_test.go

@@ -7,12 +7,12 @@ import (
 	"testing"
 	"time"
 
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestConfig_Load(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	dir, err := ioutil.TempDir("", "config-test")
 	// invalid yaml
 	c := NewC(l)
@@ -42,7 +42,7 @@ func TestConfig_Load(t *testing.T) {
 }
 
 func TestConfig_Get(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	// test simple type
 	c := NewC(l)
 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
@@ -58,14 +58,14 @@ func TestConfig_Get(t *testing.T) {
 }
 
 func TestConfig_GetStringSlice(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := NewC(l)
 	c.Settings["slice"] = []interface{}{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 
 func TestConfig_GetBool(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := NewC(l)
 	c.Settings["bool"] = true
 	assert.Equal(t, true, c.GetBool("bool", false))
@@ -93,7 +93,7 @@ func TestConfig_GetBool(t *testing.T) {
 }
 
 func TestConfig_HasChanged(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	// No reload has occurred, return false
 	c := NewC(l)
 	c.Settings["test"] = "hi"
@@ -115,7 +115,7 @@ func TestConfig_HasChanged(t *testing.T) {
 }
 
 func TestConfig_ReloadConfig(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	done := make(chan bool, 1)
 	dir, err := ioutil.TempDir("", "config-test")
 	assert.Nil(t, err)

+ 4 - 4
connection_manager_test.go

@@ -11,15 +11,15 @@ import (
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
-	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 var vpnIp iputil.VpnIp
 
 func Test_NewConnectionManagerTest(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -89,7 +89,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -164,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 // Disconnect only if disconnectInvalid: true is set.
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ipNet := net.IPNet{
 		IP:   net.IPv4(172, 1, 1, 2),
 		Mask: net.IPMask{255, 255, 255, 0},

+ 3 - 3
control_test.go

@@ -9,13 +9,13 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
-	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
@@ -94,7 +94,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 
 	// Make sure we don't have any unexpected fields
 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
-	util.AssertDeepCopyEqual(t, &expectedInfo, thi)
+	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {

+ 10 - 10
firewall_test.go

@@ -14,12 +14,12 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestNewFirewall(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
@@ -58,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
 }
 
 func TestFirewall_AddRule(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -133,7 +133,7 @@ func TestFirewall_AddRule(t *testing.T) {
 }
 
 func TestFirewall_Drop(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -308,7 +308,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 }
 
 func TestFirewall_Drop2(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -367,7 +367,7 @@ func TestFirewall_Drop2(t *testing.T) {
 }
 
 func TestFirewall_Drop3(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -453,7 +453,7 @@ func TestFirewall_Drop3(t *testing.T) {
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -635,7 +635,7 @@ func Test_parsePort(t *testing.T) {
 }
 
 func TestNewFirewallFromConfig(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	// Test a bad rule definition
 	c := &cert.NebulaCertificate{}
 	conf := config.NewC(l)
@@ -685,7 +685,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
 }
 
 func TestAddFirewallRulesFromConfig(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	// Test adding tcp rule
 	conf := config.NewC(l)
 	mf := &mockFirewall{}
@@ -849,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
 }
 
 func TestFirewall_convertRule(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 

+ 3 - 3
handshake_manager_test.go

@@ -7,13 +7,13 @@ import (
 
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
-	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 }
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")

+ 5 - 5
lighthouse_test.go

@@ -8,8 +8,8 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
-	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -46,7 +46,7 @@ func TestNewLhQuery(t *testing.T) {
 }
 
 func Test_lhStaticMapping(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
@@ -67,7 +67,7 @@ func Test_lhStaticMapping(t *testing.T) {
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
@@ -137,7 +137,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 }
 
 func TestLighthouse_Memory(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 
 	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
 	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
@@ -266,7 +266,7 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
 
 //TODO: this is a RemoteList test
 //func Test_lhRemoteAllowList(t *testing.T) {
-//	l := NewTestLogger()
+//	l := NewLogger()
 //	c := NewConfig(l)
 //	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
 //		"10.20.0.0/12": false,

+ 0 - 33
logger.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"errors"
 	"fmt"
 	"strings"
 	"time"
@@ -10,38 +9,6 @@ import (
 	"github.com/slackhq/nebula/config"
 )
 
-type ContextualError struct {
-	RealError error
-	Fields    map[string]interface{}
-	Context   string
-}
-
-func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
-	return ContextualError{Context: msg, Fields: fields, RealError: realError}
-}
-
-func (ce ContextualError) Error() string {
-	if ce.RealError == nil {
-		return ce.Context
-	}
-	return ce.RealError.Error()
-}
-
-func (ce ContextualError) Unwrap() error {
-	if ce.RealError == nil {
-		return errors.New(ce.Context)
-	}
-	return ce.RealError
-}
-
-func (ce *ContextualError) Log(lr *logrus.Logger) {
-	if ce.RealError != nil {
-		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
-	} else {
-		lr.WithFields(ce.Fields).Error(ce.Context)
-	}
-}
-
 func configLogger(l *logrus.Logger, c *config.C) error {
 	// set up our logging level
 	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))

+ 22 - 21
main.go

@@ -12,6 +12,7 @@ import (
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"gopkg.in/yaml.v2"
 )
 
@@ -44,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	err := configLogger(l, c)
 	if err != nil {
-		return nil, NewContextualError("Failed to configure the logger", nil, err)
+		return nil, util.NewContextualError("Failed to configure the logger", nil, err)
 	}
 
 	c.RegisterReloadCallback(func(c *config.C) {
@@ -57,20 +58,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	caPool, err := loadCAFromConfig(l, c)
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
-		return nil, NewContextualError("Failed to load ca from config", nil, err)
+		return nil, util.NewContextualError("Failed to load ca from config", nil, err)
 	}
 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
 	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
-		return nil, NewContextualError("Failed to load certificate from config", nil, err)
+		return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
 	fw, err := NewFirewallFromConfig(l, cs.certificate, c)
 	if err != nil {
-		return nil, NewContextualError("Error while loading firewall rules", nil, err)
+		return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
 	}
 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
 
@@ -78,11 +79,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	tunCidr := cs.certificate.Details.Ips[0]
 	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
-		return nil, NewContextualError("Could not parse tun.routes", nil, err)
+		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
 	if err != nil {
-		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+		return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -91,7 +92,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if c.GetBool("sshd.enabled", false) {
 		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
-			return nil, NewContextualError("Error while configuring the sshd", nil, err)
+			return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
 		}
 	}
 
@@ -167,7 +168,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 
 		if err != nil {
-			return nil, NewContextualError("Failed to get a tun/tap device", nil, err)
+			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
 		}
 	}
 
@@ -185,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		for i := 0; i < routines; i++ {
 			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
-				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
+				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			udpServer.ReloadConfig(c)
 			udpConns[i] = udpServer
@@ -194,7 +195,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			if port == 0 {
 				uPort, err := udpServer.LocalAddr()
 				if err != nil {
-					return nil, NewContextualError("Failed to get listening port", nil, err)
+					return nil, util.NewContextualError("Failed to get listening port", nil, err)
 				}
 				port = int(uPort.Port)
 			}
@@ -209,7 +210,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		for _, rawPreferredRange := range rawPreferredRanges {
 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
 			if err != nil {
-				return nil, NewContextualError("Failed to parse preferred ranges", nil, err)
+				return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
 			}
 			preferredRanges = append(preferredRanges, preferredRange)
 		}
@@ -222,7 +223,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
-			return nil, NewContextualError("Failed to parse local_range", nil, err)
+			return nil, util.NewContextualError("Failed to parse local_range", nil, err)
 		}
 
 		// Check if the entry for local_range was already specified in
@@ -261,7 +262,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	// fatal if am_lighthouse is enabled but we are using an ephemeral port
 	if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
-		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
+		return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
 	}
 
 	// warn if am_lighthouse is enabled but upstream lighthouses exists
@@ -274,10 +275,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	for i, host := range rawLighthouseHosts {
 		ip := net.ParseIP(host)
 		if ip == nil {
-			return nil, NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
 		}
 		if !tunCidr.Contains(ip) {
-			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+			return nil, util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 		}
 		lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
 	}
@@ -298,13 +299,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
 	if err != nil {
-		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
+		return nil, util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 	}
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 
 	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
 	if err != nil {
-		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
+		return nil, util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 	}
 	lightHouse.SetLocalAllowList(localAllowList)
 
@@ -313,21 +314,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		ip := net.ParseIP(fmt.Sprintf("%v", k))
 		vpnIp := iputil.Ip2VpnIp(ip)
 		if !tunCidr.Contains(ip) {
-			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
+			return nil, util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 		}
 		vals, ok := v.([]interface{})
 		if ok {
 			for _, v := range vals {
 				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 				if err != nil {
-					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
+					return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
 				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 			}
 		} else {
 			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 			if err != nil {
-				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
+				return nil, util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 			}
 			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 		}
@@ -426,7 +427,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 
 	if err != nil {
-		return nil, NewContextualError("Failed to start stats emitter", nil, err)
+		return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
 	}
 
 	if configTest {

+ 2 - 2
punchy_test.go

@@ -5,12 +5,12 @@ import (
 	"time"
 
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestNewPunchyFromConfig(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := config.NewC(l)
 
 	// Test defaults

+ 1 - 1
util/assert.go → test/assert.go

@@ -1,4 +1,4 @@
-package util
+package test
 
 import (
 	"fmt"

+ 2 - 2
util/main.go → test/logger.go

@@ -1,4 +1,4 @@
-package util
+package test
 
 import (
 	"io/ioutil"
@@ -7,7 +7,7 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-func NewTestLogger() *logrus.Logger {
+func NewLogger() *logrus.Logger {
 	l := logrus.New()
 
 	v := os.Getenv("TEST_LOGS")

+ 3 - 3
tun_test.go

@@ -6,12 +6,12 @@ import (
 	"testing"
 
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
 
 func Test_parseRoutes(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
@@ -107,7 +107,7 @@ func Test_parseRoutes(t *testing.T) {
 }
 
 func Test_parseUnsafeRoutes(t *testing.T) {
-	l := util.NewTestLogger()
+	l := test.NewLogger()
 	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 

+ 39 - 0
util/error.go

@@ -0,0 +1,39 @@
+package util
+
+import (
+	"errors"
+
+	"github.com/sirupsen/logrus"
+)
+
+type ContextualError struct {
+	RealError error
+	Fields    map[string]interface{}
+	Context   string
+}
+
+func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
+	return ContextualError{Context: msg, Fields: fields, RealError: realError}
+}
+
+func (ce ContextualError) Error() string {
+	if ce.RealError == nil {
+		return ce.Context
+	}
+	return ce.RealError.Error()
+}
+
+func (ce ContextualError) Unwrap() error {
+	if ce.RealError == nil {
+		return errors.New(ce.Context)
+	}
+	return ce.RealError
+}
+
+func (ce *ContextualError) Log(lr *logrus.Logger) {
+	if ce.RealError != nil {
+		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
+	} else {
+		lr.WithFields(ce.Fields).Error(ce.Context)
+	}
+}

+ 3 - 1
logger_test.go → util/error_test.go

@@ -1,4 +1,4 @@
-package nebula
+package util
 
 import (
 	"errors"
@@ -8,6 +8,8 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+type m map[string]interface{}
+
 type TestLogWriter struct {
 	Logs []string
 }