Browse Source

Don't use a global logger (#423)

Nathan Brown 4 years ago
parent
commit
3ea7e1b75f
45 changed files with 591 additions and 471 deletions
  1. 2 2
      bits.go
  2. 79 75
      bits_test.go
  3. 2 1
      cert.go
  4. 4 3
      cmd/nebula-service/main.go
  5. 4 3
      cmd/nebula-service/service.go
  6. 4 3
      cmd/nebula/main.go
  7. 10 8
      config.go
  8. 17 10
      config_test.go
  9. 11 9
      connection_manager.go
  10. 12 8
      connection_manager_test.go
  11. 3 2
      connection_state.go
  12. 2 1
      control_test.go
  13. 16 11
      dns_server.go
  14. 17 14
      firewall.go
  15. 73 76
      firewall_test.go
  16. 1 1
      handshake.go
  17. 33 33
      handshake_ix.go
  18. 11 9
      handshake_manager.go
  19. 15 11
      handshake_manager_test.go
  20. 29 25
      hostmap.go
  21. 6 3
      hostmap_test.go
  22. 20 20
      inside.go
  23. 22 18
      interface.go
  24. 19 17
      lighthouse.go
  25. 11 8
      lighthouse_test.go
  26. 15 14
      main.go
  27. 29 0
      main_test.go
  28. 25 25
      outside.go
  29. 2 1
      punchy_test.go
  30. 12 8
      ssh.go
  31. 6 5
      stats.go
  32. 4 1
      tun_android.go
  33. 5 3
      tun_darwin.go
  34. 14 15
      tun_disabled.go
  35. 10 6
      tun_freebsd.go
  36. 8 4
      tun_linux.go
  37. 4 2
      tun_test.go
  38. 5 2
      tun_windows.go
  39. 2 0
      udp_android.go
  40. 2 0
      udp_darwin.go
  41. 2 0
      udp_freebsd.go
  42. 8 4
      udp_generic.go
  43. 13 10
      udp_linux.go
  44. 1 0
      udp_linux_32.go
  45. 1 0
      udp_linux_64.go

+ 2 - 2
bits.go

@@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
 	}
 }
 
-func (b *Bits) Check(i uint64) bool {
+func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
 	// If i is the next number, return true.
 	if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
 		return true
@@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
 	return false
 }
 
-func (b *Bits) Update(i uint64) bool {
+func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 	// If i is the next number, return true and update current.
 	if i == b.current+1 {
 		// Report missed packets, we can only understand what was missed after the first window has been gone through

+ 79 - 75
bits_test.go

@@ -7,6 +7,7 @@ import (
 )
 
 func TestBits(t *testing.T) {
+	l := NewTestLogger()
 	b := NewBits(10)
 
 	// make sure it is the right size
@@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
 
 	// This is initialized to zero - receive one. This should work.
 
-	assert.True(t, b.Check(1))
-	u := b.Update(1)
+	assert.True(t, b.Check(l, 1))
+	u := b.Update(l, 1)
 	assert.True(t, u)
 	assert.EqualValues(t, 1, b.current)
 	g := []bool{false, true, false, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two
-	assert.True(t, b.Check(2))
-	u = b.Update(2)
+	assert.True(t, b.Check(l, 2))
+	u = b.Update(l, 2)
 	assert.True(t, u)
 	assert.EqualValues(t, 2, b.current)
 	g = []bool{false, true, true, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two again - it will fail
-	assert.False(t, b.Check(2))
-	u = b.Update(2)
+	assert.False(t, b.Check(l, 2))
+	u = b.Update(l, 2)
 	assert.False(t, u)
 	assert.EqualValues(t, 2, b.current)
 
 	// Jump ahead to 15, which should clear everything and set the 6th element
-	assert.True(t, b.Check(15))
-	u = b.Update(15)
+	assert.True(t, b.Check(l, 15))
+	u = b.Update(l, 15)
 	assert.True(t, u)
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, false, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 14, which is allowed because it is in the window
-	assert.True(t, b.Check(14))
-	u = b.Update(14)
+	assert.True(t, b.Check(l, 14))
+	u = b.Update(l, 14)
 	assert.True(t, u)
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 5, which is not allowed because it is not in the window
-	assert.False(t, b.Check(5))
-	u = b.Update(5)
+	assert.False(t, b.Check(l, 5))
+	u = b.Update(l, 5)
 	assert.False(t, u)
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
@@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
 
 	// make sure we handle wrapping around once to the current position
 	b = NewBits(10)
-	assert.True(t, b.Update(1))
-	assert.True(t, b.Update(11))
+	assert.True(t, b.Update(l, 1))
+	assert.True(t, b.Update(l, 11))
 	assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
 
 	// Walk through a few windows in order
 	b = NewBits(10)
 	for i := uint64(0); i <= 100; i++ {
-		assert.True(t, b.Check(i), "Error while checking %v", i)
-		assert.True(t, b.Update(i), "Error while updating %v", i)
+		assert.True(t, b.Check(l, i), "Error while checking %v", i)
+		assert.True(t, b.Update(l, i), "Error while updating %v", i)
 	}
 }
 
 func TestBitsDupeCounter(t *testing.T) {
+	l := NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	assert.True(t, b.Update(1))
+	assert.True(t, b.Update(l, 1))
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 
-	assert.False(t, b.Update(1))
+	assert.False(t, b.Update(l, 1))
 	assert.Equal(t, int64(1), b.dupeCounter.Count())
 
-	assert.True(t, b.Update(2))
+	assert.True(t, b.Update(l, 2))
 	assert.Equal(t, int64(1), b.dupeCounter.Count())
 
-	assert.True(t, b.Update(3))
+	assert.True(t, b.Update(l, 3))
 	assert.Equal(t, int64(1), b.dupeCounter.Count())
 
-	assert.False(t, b.Update(1))
+	assert.False(t, b.Update(l, 1))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
 	assert.Equal(t, int64(2), b.dupeCounter.Count())
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 }
 
 func TestBitsOutOfWindowCounter(t *testing.T) {
+	l := NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	assert.True(t, b.Update(20))
+	assert.True(t, b.Update(l, 20))
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 
-	assert.True(t, b.Update(21))
-	assert.True(t, b.Update(22))
-	assert.True(t, b.Update(23))
-	assert.True(t, b.Update(24))
-	assert.True(t, b.Update(25))
-	assert.True(t, b.Update(26))
-	assert.True(t, b.Update(27))
-	assert.True(t, b.Update(28))
-	assert.True(t, b.Update(29))
+	assert.True(t, b.Update(l, 21))
+	assert.True(t, b.Update(l, 22))
+	assert.True(t, b.Update(l, 23))
+	assert.True(t, b.Update(l, 24))
+	assert.True(t, b.Update(l, 25))
+	assert.True(t, b.Update(l, 26))
+	assert.True(t, b.Update(l, 27))
+	assert.True(t, b.Update(l, 28))
+	assert.True(t, b.Update(l, 29))
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 
-	assert.False(t, b.Update(0))
+	assert.False(t, b.Update(l, 0))
 	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
 
 	//tODO: make sure lostcounter doesn't increase in orderly increment
@@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
 }
 
 func TestBitsLostCounter(t *testing.T) {
+	l := NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
 	//assert.True(t, b.Update(0))
-	assert.True(t, b.Update(0))
-	assert.True(t, b.Update(20))
-	assert.True(t, b.Update(21))
-	assert.True(t, b.Update(22))
-	assert.True(t, b.Update(23))
-	assert.True(t, b.Update(24))
-	assert.True(t, b.Update(25))
-	assert.True(t, b.Update(26))
-	assert.True(t, b.Update(27))
-	assert.True(t, b.Update(28))
-	assert.True(t, b.Update(29))
+	assert.True(t, b.Update(l, 0))
+	assert.True(t, b.Update(l, 20))
+	assert.True(t, b.Update(l, 21))
+	assert.True(t, b.Update(l, 22))
+	assert.True(t, b.Update(l, 23))
+	assert.True(t, b.Update(l, 24))
+	assert.True(t, b.Update(l, 25))
+	assert.True(t, b.Update(l, 26))
+	assert.True(t, b.Update(l, 27))
+	assert.True(t, b.Update(l, 28))
+	assert.True(t, b.Update(l, 29))
 	assert.Equal(t, int64(20), b.lostCounter.Count())
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	assert.True(t, b.Update(0))
+	assert.True(t, b.Update(l, 0))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
-	assert.True(t, b.Update(9))
+	assert.True(t, b.Update(l, 9))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
 	// 10 will set 0 index, 0 was already set, no lost packets
-	assert.True(t, b.Update(10))
+	assert.True(t, b.Update(l, 10))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
 	// 11 will set 1 index, 1 was missed, we should see 1 packet lost
-	assert.True(t, b.Update(11))
+	assert.True(t, b.Update(l, 11))
 	assert.Equal(t, int64(1), b.lostCounter.Count())
 	// Now let's fill in the window, should end up with 8 lost packets
-	assert.True(t, b.Update(12))
-	assert.True(t, b.Update(13))
-	assert.True(t, b.Update(14))
-	assert.True(t, b.Update(15))
-	assert.True(t, b.Update(16))
-	assert.True(t, b.Update(17))
-	assert.True(t, b.Update(18))
-	assert.True(t, b.Update(19))
+	assert.True(t, b.Update(l, 12))
+	assert.True(t, b.Update(l, 13))
+	assert.True(t, b.Update(l, 14))
+	assert.True(t, b.Update(l, 15))
+	assert.True(t, b.Update(l, 16))
+	assert.True(t, b.Update(l, 17))
+	assert.True(t, b.Update(l, 18))
+	assert.True(t, b.Update(l, 19))
 	assert.Equal(t, int64(8), b.lostCounter.Count())
 
 	// Jump ahead by a window size
-	assert.True(t, b.Update(29))
+	assert.True(t, b.Update(l, 29))
 	assert.Equal(t, int64(8), b.lostCounter.Count())
 	// Now lets walk ahead normally through the window, the missed packets should fill in
-	assert.True(t, b.Update(30))
-	assert.True(t, b.Update(31))
-	assert.True(t, b.Update(32))
-	assert.True(t, b.Update(33))
-	assert.True(t, b.Update(34))
-	assert.True(t, b.Update(35))
-	assert.True(t, b.Update(36))
-	assert.True(t, b.Update(37))
-	assert.True(t, b.Update(38))
+	assert.True(t, b.Update(l, 30))
+	assert.True(t, b.Update(l, 31))
+	assert.True(t, b.Update(l, 32))
+	assert.True(t, b.Update(l, 33))
+	assert.True(t, b.Update(l, 34))
+	assert.True(t, b.Update(l, 35))
+	assert.True(t, b.Update(l, 36))
+	assert.True(t, b.Update(l, 37))
+	assert.True(t, b.Update(l, 38))
 	// 39 packets tracked, 22 seen, 17 lost
 	assert.Equal(t, int64(17), b.lostCounter.Count())
 
 	// Jump ahead by 2 windows, should have recording 1 full window missing
-	assert.True(t, b.Update(58))
+	assert.True(t, b.Update(l, 58))
 	assert.Equal(t, int64(27), b.lostCounter.Count())
 	// Now lets walk ahead normally through the window, the missed packets should fill in from this window
-	assert.True(t, b.Update(59))
-	assert.True(t, b.Update(60))
-	assert.True(t, b.Update(61))
-	assert.True(t, b.Update(62))
-	assert.True(t, b.Update(63))
-	assert.True(t, b.Update(64))
-	assert.True(t, b.Update(65))
-	assert.True(t, b.Update(66))
-	assert.True(t, b.Update(67))
+	assert.True(t, b.Update(l, 59))
+	assert.True(t, b.Update(l, 60))
+	assert.True(t, b.Update(l, 61))
+	assert.True(t, b.Update(l, 62))
+	assert.True(t, b.Update(l, 63))
+	assert.True(t, b.Update(l, 64))
+	assert.True(t, b.Update(l, 65))
+	assert.True(t, b.Update(l, 66))
+	assert.True(t, b.Update(l, 67))
 	// 68 packets tracked, 32 seen, 36 missed
 	assert.Equal(t, int64(36), b.lostCounter.Count())
 	assert.Equal(t, int64(0), b.dupeCounter.Count())

+ 2 - 1
cert.go

@@ -7,6 +7,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
 	return NewCertState(nebulaCert, rawKey)
 }
 
-func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
+func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
 	var rawCA []byte
 	var err error
 

+ 4 - 3
cmd/nebula-service/main.go

@@ -46,15 +46,16 @@ func main() {
 		os.Exit(1)
 	}
 
-	config := nebula.NewConfig()
+	l := logrus.New()
+	l.Out = os.Stdout
+
+	config := nebula.NewConfig(l)
 	err := config.Load(*configPath)
 	if err != nil {
 		fmt.Printf("failed to load config: %s", err)
 		os.Exit(1)
 	}
 
-	l := logrus.New()
-	l.Out = os.Stdout
 	c, err := nebula.Main(config, *configTest, Build, l, nil)
 
 	switch v := err.(type) {

+ 4 - 3
cmd/nebula-service/service.go

@@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
 	// Start should not block.
 	logger.Info("Nebula service starting.")
 
-	config := nebula.NewConfig()
+	l := logrus.New()
+	l.Out = os.Stdout
+
+	config := nebula.NewConfig(l)
 	err := config.Load(*p.configPath)
 	if err != nil {
 		return fmt.Errorf("failed to load config: %s", err)
 	}
 
-	l := logrus.New()
-	l.Out = os.Stdout
 	p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
 	if err != nil {
 		return err

+ 4 - 3
cmd/nebula/main.go

@@ -40,15 +40,16 @@ func main() {
 		os.Exit(1)
 	}
 
-	config := nebula.NewConfig()
+	l := logrus.New()
+	l.Out = os.Stdout
+
+	config := nebula.NewConfig(l)
 	err := config.Load(*configPath)
 	if err != nil {
 		fmt.Printf("failed to load config: %s", err)
 		os.Exit(1)
 	}
 
-	l := logrus.New()
-	l.Out = os.Stdout
 	c, err := nebula.Main(config, *configTest, Build, l, nil)
 
 	switch v := err.(type) {

+ 10 - 8
config.go

@@ -26,11 +26,13 @@ type Config struct {
 	Settings    map[interface{}]interface{}
 	oldSettings map[interface{}]interface{}
 	callbacks   []func(*Config)
+	l           *logrus.Logger
 }
 
-func NewConfig() *Config {
+func NewConfig(l *logrus.Logger) *Config {
 	return &Config{
 		Settings: make(map[interface{}]interface{}),
+		l:        l,
 	}
 }
 
@@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
 
 	newVals, err := yaml.Marshal(nv)
 	if err != nil {
-		l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
 	}
 
 	oldVals, err := yaml.Marshal(ov)
 	if err != nil {
-		l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
 	}
 
 	return string(newVals) != string(oldVals)
@@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
 
 	go func() {
 		for range ch {
-			l.Info("Caught HUP, reloading config")
+			c.l.Info("Caught HUP, reloading config")
 			c.ReloadConfig()
 		}
 	}()
@@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
 
 	err := c.Load(c.path)
 	if err != nil {
-		l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
+		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
 		return
 	}
 
@@ -500,7 +502,7 @@ func configLogger(c *Config) error {
 	if err != nil {
 		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
 	}
-	l.SetLevel(logLevel)
+	c.l.SetLevel(logLevel)
 
 	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
 	timestampFormat := c.GetString("logging.timestamp_format", "")
@@ -512,13 +514,13 @@ func configLogger(c *Config) error {
 	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
 	switch logFormat {
 	case "text":
-		l.Formatter = &logrus.TextFormatter{
+		c.l.Formatter = &logrus.TextFormatter{
 			TimestampFormat:  timestampFormat,
 			FullTimestamp:    fullTimestamp,
 			DisableTimestamp: disableTimestamp,
 		}
 	case "json":
-		l.Formatter = &logrus.JSONFormatter{
+		c.l.Formatter = &logrus.JSONFormatter{
 			TimestampFormat:  timestampFormat,
 			DisableTimestamp: disableTimestamp,
 		}

+ 17 - 10
config_test.go

@@ -11,14 +11,15 @@ import (
 )
 
 func TestConfig_Load(t *testing.T) {
+	l := NewTestLogger()
 	dir, err := ioutil.TempDir("", "config-test")
 	// invalid yaml
-	c := NewConfig()
+	c := NewConfig(l)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 	// simple multi config merge
-	c = NewConfig()
+	c = NewConfig(l)
 	os.RemoveAll(dir)
 	os.Mkdir(dir, 0755)
 
@@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
 }
 
 func TestConfig_Get(t *testing.T) {
+	l := NewTestLogger()
 	// test simple type
-	c := NewConfig()
+	c := NewConfig(l)
 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 
@@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
 }
 
 func TestConfig_GetStringSlice(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	c.Settings["slice"] = []interface{}{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 
 func TestConfig_GetBool(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	c.Settings["bool"] = true
 	assert.Equal(t, true, c.GetBool("bool", false))
 
@@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
 }
 
 func TestConfig_GetAllowList(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 		"192.168.0.0": true,
 	}
@@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
 }
 
 func TestConfig_HasChanged(t *testing.T) {
+	l := NewTestLogger()
 	// No reload has occurred, return false
-	c := NewConfig()
+	c := NewConfig(l)
 	c.Settings["test"] = "hi"
 	assert.False(t, c.HasChanged(""))
 
 	// Test key change
-	c = NewConfig()
+	c = NewConfig(l)
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "no"}
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged(""))
 
 	// No key change
-	c = NewConfig()
+	c = NewConfig(l)
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
 	assert.False(t, c.HasChanged("test"))
@@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
 }
 
 func TestConfig_ReloadConfig(t *testing.T) {
+	l := NewTestLogger()
 	done := make(chan bool, 1)
 	dir, err := ioutil.TempDir("", "config-test")
 	assert.Nil(t, err)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
-	c := NewConfig()
+	c := NewConfig(l)
 	assert.Nil(t, c.Load(dir))
 
 	assert.False(t, c.HasChanged("outer.inner"))

+ 11 - 9
connection_manager.go

@@ -28,10 +28,11 @@ type connectionManager struct {
 	checkInterval           int
 	pendingDeletionInterval int
 
+	l *logrus.Logger
 	// I wanted to call one matLock
 }
 
-func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
+func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
 	nc := &connectionManager{
 		hostMap:                 intf.hostMap,
 		in:                      make(map[uint32]struct{}),
@@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
 		pendingDeletionTimer:    NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
+		l:                       l,
 	}
 	nc.Start()
 	return nc
@@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 
 		// If we saw incoming packets from this ip, just return
 		if traf {
-			if l.Level >= logrus.DebugLevel {
-				l.WithField("vpnIp", IntIp(vpnIP)).
+			if n.l.Level >= logrus.DebugLevel {
+				n.l.WithField("vpnIp", IntIp(vpnIP)).
 					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 					Debug("Tunnel status")
 			}
@@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 		// If we didn't we may need to probe or destroy the conn
 		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
 		if err != nil {
-			l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
 			n.ClearIP(vpnIP)
 			n.ClearPendingDeletion(vpnIP)
 			continue
 		}
 
-		hostinfo.logger().
+		hostinfo.logger(n.l).
 			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 			Debug("Tunnel status")
 
@@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
 
 		} else {
-			hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
+			hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
 		}
 		n.AddPendingDeletion(vpnIP)
 	}
@@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 		// If we saw incoming packets from this ip, just return
 		traf := n.CheckIn(vpnIP)
 		if traf {
-			l.WithField("vpnIp", IntIp(vpnIP)).
+			n.l.WithField("vpnIp", IntIp(vpnIP)).
 				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
 				Debug("Tunnel status")
 			n.ClearIP(vpnIP)
@@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 		if err != nil {
 			n.ClearIP(vpnIP)
 			n.ClearPendingDeletion(vpnIP)
-			l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
 			continue
 		}
 
@@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
 				cn = hostinfo.ConnectionState.peerCert.Details.Name
 			}
-			hostinfo.logger().
+			hostinfo.logger(n.l).
 				WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 				WithField("certName", cn).
 				Info("Tunnel status")

+ 12 - 8
connection_manager_test.go

@@ -13,6 +13,7 @@ import (
 var vpnIP uint32
 
 func Test_NewConnectionManagerTest(t *testing.T) {
+	l := NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap("test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
 	cs := &CertState{
 		rawCertificate:      []byte{},
 		privateKey:          []byte{},
@@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &Tun{},
@@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		l:                l,
 	}
 	now := time.Now()
 
 	// Create manager
-	nc := newConnectionManager(ifce, 5, 10)
+	nc := newConnectionManager(l, ifce, 5, 10)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
+	l := NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap("test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
 	cs := &CertState{
 		rawCertificate:      []byte{},
 		privateKey:          []byte{},
@@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &Tun{},
@@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		l:                l,
 	}
 	now := time.Now()
 
 	// Create manager
-	nc := newConnectionManager(ifce, 5, 10)
+	nc := newConnectionManager(l, ifce, 5, 10)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)

+ 3 - 2
connection_state.go

@@ -7,6 +7,7 @@ import (
 	"sync/atomic"
 
 	"github.com/flynn/noise"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -26,7 +27,7 @@ type ConnectionState struct {
 	ready                bool
 }
 
-func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
 	cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
 	if f.cipher == "chachapoly" {
 		cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
@@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
 
 	b := NewBits(ReplayWindow)
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
-	b.Update(0)
+	b.Update(l, 0)
 
 	hs, err := noise.NewHandshakeState(noise.Config{
 		CipherSuite:           cs,

+ 2 - 1
control_test.go

@@ -13,9 +13,10 @@ import (
 )
 
 func TestControl_GetHostInfoByVpnIP(t *testing.T) {
+	l := NewTestLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
+	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
 	remote1 := NewUDPAddr(int2ip(100), 4444)
 	remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{

+ 16 - 11
dns_server.go

@@ -7,6 +7,7 @@ import (
 	"sync"
 
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 )
 
 // This whole thing should be rewritten to use context
@@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
 	d.Unlock()
 }
 
-func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
+func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 		switch q.Qtype {
 		case dns.TypeA:
@@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	}
 }
 
-func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
+func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.Compress = false
 
 	switch r.Opcode {
 	case dns.OpcodeQuery:
-		parseQuery(m, w)
+		parseQuery(l, m, w)
 	}
 
 	w.WriteMsg(m)
 }
 
-func dnsMain(hostMap *HostMap, c *Config) {
+func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
 	dnsR = newDnsRecords(hostMap)
 
 	// attach request handler func
-	dns.HandleFunc(".", handleDnsRequest)
-
-	c.RegisterReloadCallback(reloadDns)
-	startDns(c)
+	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
+		handleDnsRequest(l, w, r)
+	})
+
+	c.RegisterReloadCallback(func(c *Config) {
+		reloadDns(l, c)
+	})
+	startDns(l, c)
 }
 
 func getDnsServerAddr(c *Config) string {
 	return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
 }
 
-func startDns(c *Config) {
+func startDns(l *logrus.Logger, c *Config) {
 	dnsAddr = getDnsServerAddr(c)
 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 	l.Debugf("Starting DNS responder at %s\n", dnsAddr)
@@ -133,7 +138,7 @@ func startDns(c *Config) {
 	}
 }
 
-func reloadDns(c *Config) {
+func reloadDns(l *logrus.Logger, c *Config) {
 	if dnsAddr == getDnsServerAddr(c) {
 		l.Debug("No DNS server config change detected")
 		return
@@ -141,5 +146,5 @@ func reloadDns(c *Config) {
 
 	l.Debug("Restarting DNS server")
 	dnsServer.Shutdown()
-	go startDns(c)
+	go startDns(l, c)
 }

+ 17 - 14
firewall.go

@@ -70,6 +70,7 @@ type Firewall struct {
 
 	trackTCPRTT  bool
 	metricTCPRTT metrics.Histogram
+	l            *logrus.Logger
 }
 
 type FirewallConntrack struct {
@@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
-func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
+func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
 	//TODO: error on 0 duration
 	var min, max time.Duration
 
@@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 		DefaultTimeout: defaultTimeout,
 		localIps:       localIps,
 		metricTCPRTT:   metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
+		l:              l,
 	}
 }
 
-func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
 	fw := NewFirewall(
+		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
@@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
 		//TODO: max_connections
 	)
 
-	err := AddFirewallRulesFromConfig(false, c, fw)
+	err := AddFirewallRulesFromConfig(l, false, c, fw)
 	if err != nil {
 		return nil, err
 	}
 
-	err = AddFirewallRulesFromConfig(true, c, fw)
+	err = AddFirewallRulesFromConfig(l, true, c, fw)
 	if err != nil {
 		return nil, err
 	}
@@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	if !incoming {
 		direction = "outgoing"
 	}
-	l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
+	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
 		Info("Firewall rule added")
 
 	var (
@@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
 	return hex.EncodeToString(sum[:])
 }
 
-func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
+func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
 	var table string
 	if inbound {
 		table = "firewall.inbound"
@@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
 
 	for i, t := range rs {
 		var groups []string
-		r, err := convertRule(t, table, i)
+		r, err := convertRule(l, t, table, i)
 		if err != nil {
 			return fmt.Errorf("%s rule #%v; %s", table, i, err)
 		}
@@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 
 		// We now know which firewall table to check against
 		if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
-			if l.Level >= logrus.DebugLevel {
-				h.logger().
+			if f.l.Level >= logrus.DebugLevel {
+				h.logger(f.l).
 					WithField("fwPacket", fp).
 					WithField("incoming", c.incoming).
 					WithField("rulesVersion", f.rulesVersion).
@@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 			return false
 		}
 
-		if l.Level >= logrus.DebugLevel {
-			h.logger().
+		if f.l.Level >= logrus.DebugLevel {
+			h.logger(f.l).
 				WithField("fwPacket", fp).
 				WithField("incoming", c.incoming).
 				WithField("rulesVersion", f.rulesVersion).
@@ -795,7 +798,7 @@ type rule struct {
 	CASha  string
 }
 
-func convertRule(p interface{}, table string, i int) (rule, error) {
+func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
 	r := rule{}
 
 	m, ok := p.(map[interface{}]interface{})
@@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
 
 // Get checks if the cache ticker has moved to the next version before returning
 // the map. If it has moved, we reset the map.
-func (c *ConntrackCacheTicker) Get() ConntrackCache {
+func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
 	if c == nil {
 		return nil
 	}
 	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
 		c.cacheV = tick
 		if ll := len(c.cache); ll > 0 {
-			if l.GetLevel() == logrus.DebugLevel {
+			if l.Level == logrus.DebugLevel {
 				l.WithField("len", ll).Debug("resetting conntrack cache")
 			}
 			c.cache = make(ConntrackCache, ll)

+ 73 - 76
firewall_test.go

@@ -15,8 +15,9 @@ import (
 )
 
 func TestNewFirewall(t *testing.T) {
+	l := NewTestLogger()
 	c := &cert.NebulaCertificate{}
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
 	assert.NotNil(t, conntrack)
 	assert.NotNil(t, conntrack.Conns)
@@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
-	fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
+	fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
-	fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
+	fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
-	fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
+	fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
-	fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
+	fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
-	fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
+	fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 }
 
 func TestFirewall_AddRule(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	c := &cert.NebulaCertificate{}
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
 
@@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
@@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
 	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
@@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
 	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	// Set any and clear fields
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
@@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
 	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	// Test error conditions
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
 	assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
 	}
 	h.CreateRemoteCIDR(&c)
 
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
@@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) {
 	p.RemoteIP = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
@@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 }
 
 func TestFirewall_Drop2(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	}
 	h1.CreateRemoteCIDR(&c1)
 
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
@@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) {
 }
 
 func TestFirewall_Drop3(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) {
 	}
 	h3.CreateRemoteCIDR(&c3)
 
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
 	cp := cert.NewCAPool()
@@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) {
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	}
 	h.CreateRemoteCIDR(&c)
 
-	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
@@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
 
 	oldFw := fw
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
 
 	oldFw = fw
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
 }
 
 func TestNewFirewallFromConfig(t *testing.T) {
+	l := NewTestLogger()
 	// Test a bad rule definition
 	c := &cert.NebulaCertificate{}
-	conf := NewConfig()
+	conf := NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
-	_, err := NewFirewallFromConfig(c, conf)
+	_, err := NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 
 	// Test both group and groups
-	conf = NewConfig()
+	conf = NewConfig(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
-	_, err = NewFirewallFromConfig(c, conf)
+	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 
 func TestAddFirewallRulesFromConfig(t *testing.T) {
+	l := NewTestLogger()
 	// Test adding tcp rule
-	conf := NewConfig()
+	conf := NewConfig(l)
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding udp rule
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding icmp rule
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding any rule
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding rule with ca_sha
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
 
 	// Test single group
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 	// Test single groups
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 	// Test multiple AND groups
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
-	assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
+	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
 
 	// Test Add error
-	conf = NewConfig()
+	conf = NewConfig(l)
 	mf = &mockFirewall{}
 	mf.nextCallReturn = errors.New("test error")
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
-	assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
+	assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
 }
 
 func TestTCPRTTTracking(t *testing.T) {
@@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
 }
 
 func TestFirewall_convertRule(t *testing.T) {
+	l := NewTestLogger()
 	ob := &bytes.Buffer{}
-	out := l.Out
 	l.SetOutput(ob)
-	defer l.SetOutput(out)
 
 	// Ensure group array of 1 is converted and a warning is printed
 	c := map[interface{}]interface{}{
 		"group": []interface{}{"group1"},
 	}
 
-	r, err := convertRule(c, "test", 1)
+	r, err := convertRule(l, c, "test", 1)
 	assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
 	assert.Nil(t, err)
 	assert.Equal(t, "group1", r.Group)
@@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
 		"group": []interface{}{"group1", "group2"},
 	}
 
-	r, err = convertRule(c, "test", 1)
+	r, err = convertRule(l, c, "test", 1)
 	assert.Equal(t, "", ob.String())
 	assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
 
@@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
 		"group": "group1",
 	}
 
-	r, err = convertRule(c, "test", 1)
+	r, err = convertRule(l, c, "test", 1)
 	assert.Nil(t, err)
 	assert.Equal(t, "group1", r.Group)
 }

+ 1 - 1
handshake.go

@@ -7,7 +7,7 @@ const (
 
 func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
 	if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
-		l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return
 	}
 

+ 33 - 33
handshake_ix.go

@@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return
 	}
@@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hsBytes, err = proto.Marshal(hs)
 
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return
 	}
@@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 
 	msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	}
 
 	// We are sending handshake packet 1, so we don't expect to receive
 	// handshake packet 1 from the responder
-	ci.window.Update(1)
+	ci.window.Update(f.l, 1)
 
 	hostinfo.HandshakePacket[0] = msg
 	hostinfo.HandshakeReady = true
@@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 }
 
 func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
-	ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
+	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
-	ci.window.Update(1)
+	ci.window.Update(f.l, 1)
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
 	if err != nil {
-		l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
 		return
 	}
@@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
 	*/
 	if err != nil || hs.Details == nil {
-		l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 		return
 	}
 
 	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
 	if err != nil {
-		l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
 			Info("Invalid certificate from host")
 		return
@@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	fingerprint, _ := remoteCert.Sha256Sum()
 
 	if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
-		l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
 		return
 	}
 
-	myIndex, err := generateIndex()
+	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		HandshakePacket: make(map[uint8][]byte, 0),
 	}
 
-	l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	hsBytes, err := proto.Marshal(hs)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	// We are sending handshake packet 2, so we don't expect to receive
 	// handshake packet 2 from the initiator.
-	ci.window.Update(2)
+	ci.window.Update(f.l, 2)
 
 	ci.peerCert = remoteCert
 	ci.dKey = NewNebulaCipherState(dKey)
@@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
 			err := f.outside.WriteTo(msg, addr)
 			if err != nil {
-				l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithError(err).Error("Failed to send handshake message")
 			} else {
-				l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 			}
@@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and we didn't win
 			// handshake avoidance
-			l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
 	err = f.outside.WriteTo(msg, addr)
 	if err != nil {
-		l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithError(err).Error("Failed to send handshake")
 	} else {
-		l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			Info("Handshake message sent")
 	}
 
-	hostinfo.handshakeComplete()
+	hostinfo.handshakeComplete(f.l)
 
 	return
 }
@@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	defer hostinfo.Unlock()
 
 	if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
-		l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Info("Already seen this handshake packet")
 		return false
@@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 	ci := hostinfo.ConnectionState
 	// Mark packet 2 as seen so it doesn't show up as missed
-	ci.window.Update(2)
+	ci.window.Update(f.l, 2)
 
 	hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
 	copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
 
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 		return true
@@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hs := &NebulaHandshake{}
 	err = proto.Unmarshal(msg, hs)
 	if err != nil || hs.Details == nil {
-		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 		return true
 	}
 
 	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Invalid certificate from host")
 		return true
@@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	fingerprint, _ := remoteCert.Sha256Sum()
 
 	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
-	l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 	f.handshakeManager.Complete(hostinfo, f)
-	hostinfo.handshakeComplete()
+	hostinfo.handshakeComplete(f.l)
 	f.metricHandshakes.Update(duration)
 
 	return false

+ 11 - 9
handshake_manager.go

@@ -53,11 +53,12 @@ type HandshakeManager struct {
 	InboundHandshakeTimer  *SystemTimerWheel
 
 	messageMetrics *MessageMetrics
+	l              *logrus.Logger
 }
 
-func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
+func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
+		pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
 		mainHostMap:    mainHostMap,
 		lightHouse:     lightHouse,
 		outside:        outside,
@@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
 		InboundHandshakeTimer:  NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
 
 		messageMetrics: config.messageMetrics,
+		l:              l,
 	}
 }
 
@@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) {
 	for {
 		select {
 		case vpnIP := <-c.trigger:
-			l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
+			c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
 			c.handleOutbound(vpnIP, f, true)
 		case now := <-clockSource:
 			c.NextOutboundHandshakeTimerTick(now, f)
@@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 			c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 			err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
 			if err != nil {
-				hostinfo.logger().WithField("udpAddr", hostinfo.remote).
+				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
 					WithField("initiatorIndex", hostinfo.localIndexId).
 					WithField("remoteIndex", hostinfo.remoteIndexId).
 					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 			} else {
 				//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
 				// keep the real packet struct around for logging purposes
-				hostinfo.logger().WithField("udpAddr", hostinfo.remote).
+				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
 					WithField("initiatorIndex", hostinfo.localIndexId).
 					WithField("remoteIndex", hostinfo.remoteIndexId).
 					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger().
+		hostinfo.logger(c.l).
 			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
 			Info("New host shadows existing host remoteIndex")
 	}
@@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	if found && existingRemoteIndex != nil {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger().
+		hostinfo.logger(c.l).
 			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
 			Info("New host shadows existing host remoteIndex")
 	}
@@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
 	defer c.mainHostMap.RUnlock()
 
 	for i := 0; i < 32; i++ {
-		index, err := generateIndex()
+		index, err := generateIndex(c.l)
 		if err != nil {
 			return err
 		}
@@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() {
 
 // Utility functions below
 
-func generateIndex() (uint32, error) {
+func generateIndex(l *logrus.Logger) (uint32, error) {
 	b := make([]byte, 4)
 
 	// Let zero mean we don't know the ID, so don't generate zero

+ 15 - 11
handshake_manager_test.go

@@ -12,15 +12,15 @@ import (
 var ips []uint32
 
 func Test_NewHandshakeManagerIndex(t *testing.T) {
-
+	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
 	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap("test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
-	blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextInboundHandshakeTimerTick(now)
@@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
 }
 
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
+	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
-	mainHM := NewHostMap("test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
-	blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 }
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
+	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	ip := ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
-	mainHM := NewHostMap("test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 	lh := &LightHouse{}
 
-	blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
 }
 
 func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
+	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
-	mainHM := NewHostMap("test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
-	blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
@@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
 }
 
 func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
+	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap("test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
-	blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextInboundHandshakeTimerTick(now)

+ 29 - 25
hostmap.go

@@ -33,6 +33,7 @@ type HostMap struct {
 	defaultRoute    uint32
 	unsafeRoutes    *CIDRTree
 	metricsEnabled  bool
+	l               *logrus.Logger
 }
 
 type HostInfo struct {
@@ -83,7 +84,7 @@ type Probe struct {
 	Counter int
 }
 
-func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
+func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
 	h := map[uint32]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
@@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
 		vpnCIDR:         vpnCIDR,
 		defaultRoute:    0,
 		unsafeRoutes:    NewCIDRTree(),
+		l:               l,
 	}
 	return &m
 }
@@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
 	}
 	hm.Unlock()
 
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
 			Debug("Hostmap vpnIp deleted")
 	}
 }
@@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
 	hm.RemoteIndexes[index] = h
 	hm.Unlock()
 
-	if l.Level > logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
+	if hm.l.Level > logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
 			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
 			Debug("Hostmap remoteIndex added")
 	}
@@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
 	hm.RemoteIndexes[h.remoteIndexId] = h
 	hm.Unlock()
 
-	if l.Level > logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
+	if hm.l.Level > logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
 			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
 			Debug("Hostmap vpnIp added")
 	}
@@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) {
 	}
 	hm.Unlock()
 
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
 			Debug("Hostmap index deleted")
 	}
 }
@@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 	}
 	hm.Unlock()
 
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
 			Debug("Hostmap remote index deleted")
 	}
 }
@@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	}
 	hm.Unlock()
 
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
 			"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
@@ -313,8 +315,10 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
 		}
 		i.remote = i.Remotes[0].addr
 		hm.Hosts[vpnIp] = i
-		l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
-			Debug("Hostmap remote ip added")
+		if hm.l.Level >= logrus.DebugLevel {
+			hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
+				Debug("Hostmap remote ip added")
+		}
 	}
 	i.ForcePromoteBest(hm.preferredRanges)
 	hm.Unlock()
@@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
 			Debug("Hostmap vpnIp added")
 	}
@@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) {
 
 func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 	for _, r := range *routes {
-		l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
+		hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
 		hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
 	}
 }
@@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() {
 	i.remote = i.Remotes[0].addr
 }
 
-func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
+func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
 	//TODO: return the error so we can log with more context
 	if len(i.packetStore) < 100 {
 		tempPacket := make([]byte, len(packet))
@@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
 		//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
 		i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
 		if l.Level >= logrus.DebugLevel {
-			i.logger().
+			i.logger(l).
 				WithField("length", len(i.packetStore)).
 				WithField("stored", true).
 				Debugf("Packet store")
 		}
 
 	} else if l.Level >= logrus.DebugLevel {
-		i.logger().
+		i.logger(l).
 			WithField("length", len(i.packetStore)).
 			WithField("stored", false).
 			Debugf("Packet store")
@@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
 }
 
 // handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
-func (i *HostInfo) handshakeComplete() {
+func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
 	//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
 	//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
 	//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
@@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() {
 	atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
 
 	if l.Level >= logrus.DebugLevel {
-		i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
+		i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
 	}
 
 	if len(i.packetStore) > 0 {
@@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 	i.remoteCidr = remoteCidr
 }
 
-func (i *HostInfo) logger() *logrus.Entry {
+func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 	if i == nil {
 		return logrus.NewEntry(l)
 	}
@@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
 
 // Utility functions
 
-func localIps(allowList *AllowList) *[]net.IP {
+func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
 	//FIXME: This function is pretty garbage
 	var ips []net.IP
 	ifaces, _ := net.Interfaces()

+ 6 - 3
hostmap_test.go

@@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) {
 */
 
 func TestHostmap(t *testing.T) {
+	l := NewTestLogger()
 	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
 	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
 	myNets := []*net.IPNet{myNet}
 	preferredRanges := []*net.IPNet{localToMe}
 
-	m := NewHostMap("test", myNet, preferredRanges)
+	m := NewHostMap(l, "test", myNet, preferredRanges)
 
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
@@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) {
 }
 
 func TestHostmapdebug(t *testing.T) {
+	l := NewTestLogger()
 	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
 	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
 	preferredRanges := []*net.IPNet{localToMe}
-	m := NewHostMap("test", myNet, preferredRanges)
+	m := NewHostMap(l, "test", myNet, preferredRanges)
 
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
@@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) {
 }
 
 func BenchmarkHostmappromote2(b *testing.B) {
+	l := NewTestLogger()
 	for n := 0; n < b.N; n++ {
 		_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
 		_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
 		preferredRanges := []*net.IPNet{localToMe}
-		m := NewHostMap("test", myNet, preferredRanges)
+		m := NewHostMap(l, "test", myNet, preferredRanges)
 		y := NewUDPAddrFromString("10.128.0.3:11111")
 		a := NewUDPAddrFromString("10.127.0.3:11111")
 		g := NewUDPAddrFromString("1.0.0.1:22222")

+ 20 - 20
inside.go

@@ -10,7 +10,7 @@ import (
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
-		l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
+		f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
 		return
 	}
 
@@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 
 	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
 	if hostinfo == nil {
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
 				WithField("fwPacket", fwPacket).
 				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
 		}
@@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		// the packet queue.
 		ci.queueLock.Lock()
 		if !ci.ready {
-			hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
+			hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow)
 			ci.queueLock.Unlock()
 			return
 		}
@@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 			f.lightHouse.Query(fwPacket.RemoteIP, f)
 		}
 
-	} else if l.Level >= logrus.DebugLevel {
-		hostinfo.logger().
+	} else if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).
 			WithField("fwPacket", fwPacket).
 			WithField("reason", dropReason).
 			Debugln("dropping outbound packet")
@@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 
 	if ci == nil {
 		// if we don't have a connection state, then send a handshake initiation
-		ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
+		ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
 		// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
 		//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
 		hostinfo.ConnectionState = ci
@@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 	fp := &FirewallPacket{}
 	err := newPacket(p, false, fp)
 	if err != nil {
-		l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
+		f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
 		return
 	}
 
 	// check if packet is in outbound fw rules
 	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
 	if dropReason != nil {
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("fwPacket", fp).
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("fwPacket", fp).
 				WithField("reason", dropReason).
 				Debugln("dropping cached packet")
 		}
@@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
 	hostInfo := f.getOrHandshake(vpnIp)
 	if hostInfo == nil {
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("vpnIp", IntIp(vpnIp)).
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnIp", IntIp(vpnIp)).
 				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
 		}
 		return
@@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 		// the packet queue.
 		hostInfo.ConnectionState.queueLock.Lock()
 		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
+			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp)
 			hostInfo.ConnectionState.queueLock.Unlock()
 			return
 		}
@@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
 	hostInfo := f.getOrHandshake(vpnIp)
 	if hostInfo == nil {
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("vpnIp", IntIp(vpnIp)).
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnIp", IntIp(vpnIp)).
 				Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
 		}
 		return
@@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
 		// the packet queue.
 		hostInfo.ConnectionState.queueLock.Lock()
 		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
+			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
 			hostInfo.ConnectionState.queueLock.Unlock()
 			return
 		}
@@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
 		f.lightHouse.Query(hostinfo.hostId, f)
 		hostinfo.lastRebindCount = f.rebindCount
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 	}
 
@@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 	//TODO: see above note on lock
 	//ci.writeLock.Unlock()
 	if err != nil {
-		hostinfo.logger().WithError(err).
+		hostinfo.logger(f.l).WithError(err).
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
@@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 
 	err = f.writers[q].WriteTo(out, remote)
 	if err != nil {
-		hostinfo.logger().WithError(err).
+		hostinfo.logger(f.l).WithError(err).
 			WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 	}
 	return c

+ 22 - 18
interface.go

@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 )
 
 const mtu = 9001
@@ -42,6 +43,7 @@ type InterfaceConfig struct {
 	version                 string
 
 	ConntrackCacheTimeout time.Duration
+	l                     *logrus.Logger
 }
 
 type Interface struct {
@@ -73,6 +75,7 @@ type Interface struct {
 
 	metricHandshakes metrics.Histogram
 	messageMetrics   *MessageMetrics
+	l                *logrus.Logger
 }
 
 func NewInterface(c *InterfaceConfig) (*Interface, error) {
@@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
+		l:                c.l,
 	}
 
-	ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
+	ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
 
 	return ifce, nil
 }
@@ -125,10 +129,10 @@ func (f *Interface) run() {
 
 	addr, err := f.outside.LocalAddr()
 	if err != nil {
-		l.WithError(err).Error("Failed to get udp listen address")
+		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 
-	l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
+	f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		Info("Nebula interface is active")
 
@@ -140,14 +144,14 @@ func (f *Interface) run() {
 		if i > 0 {
 			reader, err = f.inside.NewMultiQueueReader()
 			if err != nil {
-				l.Fatal(err)
+				f.l.Fatal(err)
 			}
 		}
 		f.readers[i] = reader
 	}
 
 	if err := f.inside.Activate(); err != nil {
-		l.Fatal(err)
+		f.l.Fatal(err)
 	}
 
 	// Launch n queues to read packets from udp
@@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	for {
 		n, err := reader.Read(packet)
 		if err != nil {
-			l.WithError(err).Error("Error while reading outbound packet")
+			f.l.WithError(err).Error("Error while reading outbound packet")
 			// This only seems to happen when something fatal happens to the fd, so exit.
 			os.Exit(2)
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
+		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
 	}
 }
 
@@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
 func (f *Interface) reloadCA(c *Config) {
 	// reload and check regardless
 	// todo: need mutex?
-	newCAs, err := loadCAFromConfig(c)
+	newCAs, err := loadCAFromConfig(f.l, c)
 	if err != nil {
-		l.WithError(err).Error("Could not refresh trusted CA certificates")
+		f.l.WithError(err).Error("Could not refresh trusted CA certificates")
 		return
 	}
 
 	trustedCAs = newCAs
-	l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
+	f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
 }
 
 func (f *Interface) reloadCertKey(c *Config) {
 	// reload and check in all cases
 	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
-		l.WithError(err).Error("Could not refresh client cert")
+		f.l.WithError(err).Error("Could not refresh client cert")
 		return
 	}
 
@@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) {
 	oldIPs := f.certState.certificate.Details.Ips
 	newIPs := cs.certificate.Details.Ips
 	if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
-		l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
+		f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
 		return
 	}
 
 	f.certState = cs
-	l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
+	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
 }
 
 func (f *Interface) reloadFirewall(c *Config) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
-		l.Debug("No firewall config change detected")
+		f.l.Debug("No firewall config change detected")
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.certState.certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
 	if err != nil {
-		l.WithError(err).Error("Error while creating firewall during reload")
+		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
 	}
 
@@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) {
 	// If rulesVersion is back to zero, we have wrapped all the way around. Be
 	// safe and just reset conntrack in this case.
 	if fw.rulesVersion == 0 {
-		l.WithField("firewallHash", fw.GetRuleHash()).
+		f.l.WithField("firewallHash", fw.GetRuleHash()).
 			WithField("oldFirewallHash", oldFw.GetRuleHash()).
 			WithField("rulesVersion", fw.rulesVersion).
 			Warn("firewall rulesVersion has overflowed, resetting conntrack")
@@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) {
 	f.firewall = fw
 
 	oldFw.Destroy()
-	l.WithField("firewallHash", fw.GetRuleHash()).
+	f.l.WithField("firewallHash", fw.GetRuleHash()).
 		WithField("oldFirewallHash", oldFw.GetRuleHash()).
 		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")

+ 19 - 17
lighthouse.go

@@ -48,6 +48,7 @@ type LightHouse struct {
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
+	l                 *logrus.Logger
 }
 
 type EncWriter interface {
@@ -55,7 +56,7 @@ type EncWriter interface {
 	SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
 }
 
-func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
+func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
 	h := LightHouse{
 		amLighthouse: amLighthouse,
 		myIp:         myIp,
@@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
 		punchConn:    pc,
 		punchBack:    punchBack,
 		punchDelay:   punchDelay,
+		l:            l,
 	}
 
 	if metricsEnabled {
@@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 		// Send a query to the lighthouses and hope for the best next time
 		query, err := proto.Marshal(NewLhQueryByInt(ip))
 		if err != nil {
-			l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
+			lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
 			return
 		}
 
@@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
 	lh.Lock()
 	//l.Debugln(lh.addrMap)
 	delete(lh.addrMap, vpnIP)
-	l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
+	lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
 	lh.Unlock()
 }
 
@@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
 	}
 
 	allow := lh.remoteAllowList.Allow(toIp.IP)
-	l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
+	lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
 	if !allow {
 		return
 	}
@@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	var v4 []*IpAndPort
 	var v6 []*Ip6AndPort
 
-	for _, e := range *localIps(lh.localAllowList) {
+	for _, e := range *localIps(lh.l, lh.localAllowList) {
 		// Only add IPs that aren't my VPN/tun IP
 		if ip2int(e) != lh.myIp {
 			ipp := NewIpAndPort(e, lh.nebulaPort)
@@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	for vpnIp := range lh.lighthouses {
 		mm, err := proto.Marshal(m)
 		if err != nil {
-			l.Debugf("Invalid marshal to update")
+			lh.l.Debugf("Invalid marshal to update")
 		}
 		//l.Error("LIGHTHOUSE PACKET SEND", mm)
 		f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
@@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 	n := lhh.resetMeta()
 	err := proto.UnmarshalMerge(p, n)
 	if err != nil {
-		l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
 			Error("Failed to unmarshal lighthouse packet")
 		//TODO: send recv_error?
 		return
 	}
 
 	if n.Details == nil {
-		l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
 			Error("Invalid lighthouse update")
 		//TODO: send recv_error?
 		return
@@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 	case NebulaMeta_HostQuery:
 		// Exit if we don't answer queries
 		if !lh.amLighthouse {
-			l.Debugln("I don't answer queries, but received from: ", rAddr)
+			lh.l.Debugln("I don't answer queries, but received from: ", rAddr)
 			return
 		}
 
@@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 
 			reply, err := proto.Marshal(n)
 			if err != nil {
-				l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
+				lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
 				return
 			}
 			lh.metricTx(NebulaMeta_HostQueryReply, 1)
@@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 			// This signals the other side to punch some zero byte udp packets
 			ips, err = lh.Query(vpnIp, f)
 			if err != nil {
-				l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
+				lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
 				return
 			} else {
 				//l.Debugln("Notify host to punch", iap)
@@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 	case NebulaMeta_HostUpdateNotification:
 		//Simple check that the host sent this not someone else
 		if n.Details.VpnIp != vpnIp {
-			l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+			lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
 			return
 		}
 
@@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 
 			}()
 
-			if l.Level >= logrus.DebugLevel {
+			if lh.l.Level >= logrus.DebugLevel {
 				//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-				l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
+				lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
 			}
 		}
 
@@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 
 			}()
 
-			if l.Level >= logrus.DebugLevel {
+			if lh.l.Level >= logrus.DebugLevel {
 				//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-				l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
+				lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
 			}
 		}
 
@@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 		if lh.punchBack {
 			go func() {
 				time.Sleep(time.Second * 5)
-				l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
+				lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
 				// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
 				// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 				// managed by a channel.

+ 11 - 8
lighthouse_test.go

@@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) {
 }
 
 func Test_lhStaticMapping(t *testing.T) {
+	l := NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := NewListener("0.0.0.0", 0, true)
+	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
 
-	meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
 	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
 	err := meh.ValidateLHStaticEntries()
 	assert.Nil(t, err)
@@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) {
 	lh2 := "10.128.0.3"
 	lh2IP := net.ParseIP(lh2)
 
-	meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
+	meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
 	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
 	err = meh.ValidateLHStaticEntries()
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
+	l := NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := NewListener("0.0.0.0", 0, true)
+	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
 
-	lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
 
 	hAddr := NewUDPAddrFromString("4.5.6.7:12345")
 	hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
@@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 }
 
 func Test_lhRemoteAllowList(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
 		"10.20.0.0/12": false,
 	}
@@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) {
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := NewListener("0.0.0.0", 0, true)
+	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
 
-	lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
 	lh.SetRemoteAllowList(allowList)
 
 	remote1 := "10.20.0.3"

+ 15 - 14
main.go

@@ -11,13 +11,10 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
-// The caller should provide a real logger, we have one just in case
-var l = logrus.New()
-
 type m map[string]interface{}
 
 func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
-	l = logger
+	l := logger
 	l.Formatter = &logrus.TextFormatter{
 		FullTimestamp: true,
 	}
@@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	})
 
 	// trustedCAs is currently a global, so loadCA operates on that global directly
-	trustedCAs, err = loadCAFromConfig(config)
+	trustedCAs, err = loadCAFromConfig(l, config)
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
@@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
-	fw, err := NewFirewallFromConfig(cs.certificate, config)
+	fw, err := NewFirewallFromConfig(l, cs.certificate, config)
 	if err != nil {
 		return nil, NewContextualError("Error while loading firewall rules", nil, err)
 	}
@@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
-	wireSSHReload(ssh, config)
+	wireSSHReload(l, ssh, config)
 	if config.GetBool("sshd.enabled", false) {
-		err = configSSH(ssh, config)
+		err = configSSH(l, ssh, config)
 		if err != nil {
 			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 		}
@@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 			tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
 		case tunFd != nil:
 			tun, err = newTunFromFd(
+				l,
 				*tunFd,
 				tunCidr,
 				config.GetInt("tun.mtu", DEFAULT_MTU),
@@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 			)
 		default:
 			tun, err = newTun(
+				l,
 				config.GetString("tun.dev", ""),
 				tunCidr,
 				config.GetInt("tun.mtu", DEFAULT_MTU),
@@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	if !configTest {
 		for i := 0; i < routines; i++ {
-			udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
+			udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
 			if err != nil {
 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
@@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		}
 	}
 
-	hostMap := NewHostMap("main", tunCidr, preferredRanges)
+	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
 	hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
 	hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	lightHouse := NewLightHouse(
+		l,
 		amLighthouse,
 		ip2int(tunCidr.IP),
 		lighthouseHosts,
@@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		messageMetrics: messageMetrics,
 	}
 
-	handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
+	handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
 	lightHouse.handshakeTrigger = handshakeManager.trigger
 
 	//TODO: These will be reused for psk
@@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		version:                 buildVersion,
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		l:                     l,
 	}
 
 	switch ifConfig.Cipher {
@@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		go lightHouse.LhUpdateWorker(ifce)
 	}
 
-	err = startStats(config, configTest)
+	err = startStats(l, config, configTest)
 	if err != nil {
 		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 	}
@@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
 
-	attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
+	attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
 	if amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		go dnsMain(hostMap, config)
+		go dnsMain(l, hostMap, config)
 	}
 
 	return &Control{ifce, l}, nil

+ 29 - 0
main_test.go

@@ -1 +1,30 @@
 package nebula
+
+import (
+	"io/ioutil"
+	"os"
+
+	"github.com/sirupsen/logrus"
+)
+
+func NewTestLogger() *logrus.Logger {
+	l := logrus.New()
+
+	v := os.Getenv("TEST_LOGS")
+	if v == "" {
+		l.SetOutput(ioutil.Discard)
+		return l
+	}
+
+	switch v {
+	case "1":
+		// This is the default level but we are being explicit
+		l.SetLevel(logrus.InfoLevel)
+	case "2":
+		l.SetLevel(logrus.DebugLevel)
+	case "3":
+		l.SetLevel(logrus.TraceLevel)
+	}
+
+	return l
+}

+ 25 - 25
outside.go

@@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
-			l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
+			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
 		}
 		return
 	}
@@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 
 		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
 		if err != nil {
-			hostinfo.logger().WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 
@@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 
 		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
 		if err != nil {
-			hostinfo.logger().WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 
@@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 		}
 
-		hostinfo.logger().WithField("udpAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", addr).
 			Info("Close tunnel received, tearing down.")
 
 		f.closeTunnel(hostinfo)
@@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 
 	default:
 		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
 		return
 	}
 
@@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 	if hostDidRoam(hostinfo.remote, addr) {
 		if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
-			hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
+			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
 		if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
-			if l.Level >= logrus.DebugLevel {
-				hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+			if f.l.Level >= logrus.DebugLevel {
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		remoteCopy := *hostinfo.remote
@@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
 	// If connectionstate exists and the replay protector allows, process packet
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(header.MessageCounter) {
+	if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
 		f.sendRecvError(addr, header.RemoteIndex)
 		return false
 	}
@@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 		return nil, err
 	}
 
-	if !hostinfo.ConnectionState.window.Update(mc) {
-		hostinfo.logger().WithField("header", header).
+	if !hostinfo.ConnectionState.window.Update(f.l, mc) {
+		hostinfo.logger(f.l).WithField("header", header).
 			Debugln("dropping out of window packet")
 		return nil, errors.New("out of window packet")
 	}
@@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
 	if err != nil {
-		hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
+		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
 		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
 		return
@@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 	err = newPacket(out, true, fwPacket)
 	if err != nil {
-		hostinfo.logger().WithError(err).WithField("packet", out).
+		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
 		return
 	}
 
-	if !hostinfo.ConnectionState.window.Update(messageCounter) {
-		hostinfo.logger().WithField("fwPacket", fwPacket).
+	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			Debugln("dropping out of window packet")
 		return
 	}
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
 	if dropReason != nil {
-		if l.Level >= logrus.DebugLevel {
-			hostinfo.logger().WithField("fwPacket", fwPacket).
+		if f.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).
 				Debugln("dropping inbound packet")
 		}
@@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	f.connectionManager.In(hostinfo.hostId)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
-		l.WithError(err).Error("Failed to write to tun")
+		f.l.WithError(err).Error("Failed to write to tun")
 	}
 }
 
@@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
 	//TODO: this should be a signed message so we can trust that we should drop the index
 	b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
 	f.outside.WriteTo(b, endpoint)
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("index", index).
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("index", index).
 			WithField("udpAddr", endpoint).
 			Debug("Recv error sent")
 	}
 }
 
 func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
-	if l.Level >= logrus.DebugLevel {
-		l.WithField("index", h.RemoteIndex).
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("index", h.RemoteIndex).
 			WithField("udpAddr", addr).
 			Debug("Recv error received")
 	}
@@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
 
 	hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
 	if err != nil {
-		l.Debugln(err, ": ", h.RemoteIndex)
+		f.l.Debugln(err, ": ", h.RemoteIndex)
 		return
 	}
 
@@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
 		return
 	}
 	if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
-		l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
+		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return
 	}
 

+ 2 - 1
punchy_test.go

@@ -8,7 +8,8 @@ import (
 )
 
 func TestNewPunchyFromConfig(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 
 	// Test defaults
 	p := NewPunchyFromConfig(c)

+ 12 - 8
ssh.go

@@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct {
 	Address string
 }
 
-func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
+func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
 	c.RegisterReloadCallback(func(c *Config) {
 		if c.GetBool("sshd.enabled", false) {
-			err := configSSH(ssh, c)
+			err := configSSH(l, ssh, c)
 			if err != nil {
 				l.WithError(err).Error("Failed to reconfigure the sshd")
 				ssh.Stop()
@@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
 	})
 }
 
-func configSSH(ssh *sshd.SSHServer, c *Config) error {
+func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
 	//TODO conntrack list
 	//TODO print firewall rules or hash?
 
@@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
 	return nil
 }
 
-func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
+func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-hostmap",
 		ShortDescription: "List all known previously connected hosts",
@@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "log-level",
 		ShortDescription: "Gets or sets the current log level",
-		Callback:         sshLogLevel,
+		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+			return sshLogLevel(l, fs, a, w)
+		},
 	})
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "log-format",
 		ShortDescription: "Gets or sets the current log format",
-		Callback:         sshLogFormat,
+		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+			return sshLogFormat(l, fs, a, w)
+		},
 	})
 
 	ssh.RegisterCommand(&sshd.Command{
@@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
 	return err
 }
 
-func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
 	}
@@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
 	return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
 }
 
-func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
 	}

+ 6 - 5
stats.go

@@ -13,9 +13,10 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 )
 
-func startStats(c *Config, configTest bool) error {
+func startStats(l *logrus.Logger, c *Config, configTest bool) error {
 	mType := c.GetString("stats.type", "")
 	if mType == "" || mType == "none" {
 		return nil
@@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error {
 
 	switch mType {
 	case "graphite":
-		startGraphiteStats(interval, c, configTest)
+		startGraphiteStats(l, interval, c, configTest)
 	case "prometheus":
-		startPrometheusStats(interval, c, configTest)
+		startPrometheusStats(l, interval, c, configTest)
 	default:
 		return fmt.Errorf("stats.type was not understood: %s", mType)
 	}
@@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error {
 	return nil
 }
 
-func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
+func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
 	proto := c.GetString("stats.protocol", "tcp")
 	host := c.GetString("stats.host", "")
 	if host == "" {
@@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
 	return nil
 }
 
-func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
+func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
 	namespace := c.GetString("stats.namespace", "")
 	subsystem := c.GetString("stats.subsystem", "")
 

+ 4 - 1
tun_android.go

@@ -6,6 +6,7 @@ import (
 	"net"
 	"os"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
@@ -19,9 +20,10 @@ type Tun struct {
 	TXQueueLen   int
 	Routes       []route
 	UnsafeRoutes []route
+	l            *logrus.Logger
 }
 
-func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 	ifce = &Tun{
@@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		UnsafeRoutes:    unsafeRoutes,
+		l:               l,
 	}
 	return
 }

+ 5 - 3
tun_darwin.go

@@ -9,6 +9,7 @@ import (
 	"os/exec"
 	"strconv"
 
+	"github.com/sirupsen/logrus"
 	"github.com/songgao/water"
 )
 
@@ -17,11 +18,11 @@ type Tun struct {
 	Cidr         *net.IPNet
 	MTU          int
 	UnsafeRoutes []route
-
+	l            *logrus.Logger
 	*water.Interface
 }
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("route MTU not supported in Darwin")
 	}
@@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 		Cidr:         cidr,
 		MTU:          defaultMTU,
 		UnsafeRoutes: unsafeRoutes,
+		l:            l,
 	}, nil
 }
 
-func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 

+ 14 - 15
tun_disabled.go

@@ -9,24 +9,23 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
-	log "github.com/sirupsen/logrus"
 )
 
 type disabledTun struct {
-	read   chan []byte
-	cidr   *net.IPNet
-	logger *log.Logger
+	read chan []byte
+	cidr *net.IPNet
 
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
 	rx metrics.Counter
+	l  *logrus.Logger
 }
 
-func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
+func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
-		cidr:   cidr,
-		read:   make(chan []byte, queueLen),
-		logger: l,
+		cidr: cidr,
+		read: make(chan []byte, queueLen),
+		l:    l,
 	}
 
 	if metricsEnabled {
@@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
 	}
 
 	t.tx.Inc(1)
-	if l.Level >= logrus.DebugLevel {
-		t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
+	if t.l.Level >= logrus.DebugLevel {
+		t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
 	}
 
 	return copy(b, r), nil
@@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
 	select {
 	case t.read <- buf:
 	default:
-		t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
+		t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
 	}
 
 	return true
@@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 
 	// Check for ICMP Echo Request before spending time doing the full parsing
 	if t.handleICMPEchoRequest(b) {
-		if l.Level >= logrus.DebugLevel {
-			t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
+		if t.l.Level >= logrus.DebugLevel {
+			t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
 		}
-	} else if l.Level >= logrus.DebugLevel {
-		t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
+	} else if t.l.Level >= logrus.DebugLevel {
+		t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
 	}
 	return len(b), nil
 }

+ 10 - 6
tun_freebsd.go

@@ -9,6 +9,8 @@ import (
 	"regexp"
 	"strconv"
 	"strings"
+
+	"github.com/sirupsen/logrus"
 )
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -18,15 +20,16 @@ type Tun struct {
 	Cidr         *net.IPNet
 	MTU          int
 	UnsafeRoutes []route
+	l            *logrus.Logger
 
 	io.ReadWriteCloser
 }
 
-func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
 	}
@@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 		Cidr:         cidr,
 		MTU:          defaultMTU,
 		UnsafeRoutes: unsafeRoutes,
+		l:            l,
 	}, nil
 }
 
@@ -52,21 +56,21 @@ func (c *Tun) Activate() error {
 	}
 
 	// TODO use syscalls instead of exec.Command
-	l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
+	c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
 	if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
-	l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
+	c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
 	if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 	}
-	l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
+	c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
 	if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	// Unsafe path routes
 	for _, r := range c.UnsafeRoutes {
-		l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
+		c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
 		if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
 			return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
 		}

+ 8 - 4
tun_linux.go

@@ -10,6 +10,7 @@ import (
 	"strings"
 	"unsafe"
 
+	"github.com/sirupsen/logrus"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 )
@@ -24,6 +25,7 @@ type Tun struct {
 	TXQueueLen   int
 	Routes       []route
 	UnsafeRoutes []route
+	l            *logrus.Logger
 }
 
 type ifReq struct {
@@ -78,7 +80,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
-func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
@@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		UnsafeRoutes:    unsafeRoutes,
+		l:               l,
 	}
 	return
 }
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
@@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		UnsafeRoutes:    unsafeRoutes,
+		l:               l,
 	}
 	return
 }
@@ -233,14 +237,14 @@ func (c Tun) Activate() error {
 	ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 		// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
-		l.WithError(err).Error("Failed to set tun mtu")
+		c.l.WithError(err).Error("Failed to set tun mtu")
 	}
 
 	// Set the transmit queue length
 	ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
 	if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
 		// If we can't set the queue length nebula will still work but it may lead to packet loss
-		l.WithError(err).Error("Failed to set tun tx queue length")
+		c.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 
 	// Bring up the interface

+ 4 - 2
tun_test.go

@@ -9,7 +9,8 @@ import (
 )
 
 func Test_parseRoutes(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 	// test no routes config
@@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
 }
 
 func Test_parseUnsafeRoutes(t *testing.T) {
-	c := NewConfig()
+	l := NewTestLogger()
+	c := NewConfig(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 	// test no routes config

+ 5 - 2
tun_windows.go

@@ -7,6 +7,7 @@ import (
 	"os/exec"
 	"strconv"
 
+	"github.com/sirupsen/logrus"
 	"github.com/songgao/water"
 )
 
@@ -15,15 +16,16 @@ type Tun struct {
 	Cidr         *net.IPNet
 	MTU          int
 	UnsafeRoutes []route
+	l            *logrus.Logger
 
 	*water.Interface
 }
 
-func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("route MTU not supported in Windows")
 	}
@@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 		Cidr:         cidr,
 		MTU:          defaultMTU,
 		UnsafeRoutes: unsafeRoutes,
+		l:            l,
 	}, nil
 }
 

+ 2 - 0
udp_android.go

@@ -1,3 +1,5 @@
+// +build !e2e_testing
+
 package nebula
 
 import (

+ 2 - 0
udp_darwin.go

@@ -1,3 +1,5 @@
+// +build !e2e_testing
+
 package nebula
 
 // Darwin support is primarily implemented in udp_generic, besides NewListenConfig

+ 2 - 0
udp_freebsd.go

@@ -1,3 +1,5 @@
+// +build !e2e_testing
+
 package nebula
 
 // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig

+ 8 - 4
udp_generic.go

@@ -1,4 +1,5 @@
 // +build !linux android
+// +build !e2e_testing
 
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.
@@ -9,20 +10,23 @@ import (
 	"context"
 	"fmt"
 	"net"
+
+	"github.com/sirupsen/logrus"
 )
 
 type udpConn struct {
 	*net.UDPConn
+	l *logrus.Logger
 }
 
-func NewListener(ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
 	if err != nil {
 		return nil, err
 	}
 	if uc, ok := pc.(*net.UDPConn); ok {
-		return &udpConn{UDPConn: uc}, nil
+		return &udpConn{UDPConn: uc, l: l}, nil
 	}
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
@@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
 		if err != nil {
-			l.WithError(err).Error("Failed to read packets")
+			f.l.WithError(err).Error("Failed to read packets")
 			continue
 		}
 
 		udpAddr.IP = rua.IP
 		udpAddr.Port = uint16(rua.Port)
-		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
+		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
 	}
 }
 

+ 13 - 10
udp_linux.go

@@ -1,4 +1,5 @@
 // +build !android
+// +build !e2e_testing
 
 package nebula
 
@@ -10,6 +11,7 @@ import (
 	"unsafe"
 
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
@@ -17,6 +19,7 @@ import (
 
 type udpConn struct {
 	sysFd int
+	l     *logrus.Logger
 }
 
 var x int
@@ -38,7 +41,7 @@ const (
 
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 
-func NewListener(ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
 	syscall.ForkLock.RLock()
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	if err == nil {
@@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 
-	return &udpConn{sysFd: fd}, err
+	return &udpConn{sysFd: fd, l: l}, err
 }
 
 func (u *udpConn) Rebind() error {
@@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 	for {
 		n, err := read(msgs)
 		if err != nil {
-			l.WithError(err).Error("Failed to read packets")
+			u.l.WithError(err).Error("Failed to read packets")
 			continue
 		}
 
@@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		for i := 0; i < n; i++ {
 			udpAddr.IP = names[i][8:24]
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
+			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
 		}
 	}
 }
@@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) {
 		if err == nil {
 			s, err := u.GetRecvBuffer()
 			if err == nil {
-				l.WithField("size", s).Info("listen.read_buffer was set")
+				u.l.WithField("size", s).Info("listen.read_buffer was set")
 			} else {
-				l.WithError(err).Warn("Failed to get listen.read_buffer")
+				u.l.WithError(err).Warn("Failed to get listen.read_buffer")
 			}
 		} else {
-			l.WithError(err).Error("Failed to set listen.read_buffer")
+			u.l.WithError(err).Error("Failed to set listen.read_buffer")
 		}
 	}
 
@@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) {
 		if err == nil {
 			s, err := u.GetSendBuffer()
 			if err == nil {
-				l.WithField("size", s).Info("listen.write_buffer was set")
+				u.l.WithField("size", s).Info("listen.write_buffer was set")
 			} else {
-				l.WithError(err).Warn("Failed to get listen.write_buffer")
+				u.l.WithError(err).Warn("Failed to get listen.write_buffer")
 			}
 		} else {
-			l.WithError(err).Error("Failed to set listen.write_buffer")
+			u.l.WithError(err).Error("Failed to set listen.write_buffer")
 		}
 	}
 }

+ 1 - 0
udp_linux_32.go

@@ -1,6 +1,7 @@
 // +build linux
 // +build 386 amd64p32 arm mips mipsle
 // +build !android
+// +build !e2e_testing
 
 package nebula
 

+ 1 - 0
udp_linux_64.go

@@ -1,6 +1,7 @@
 // +build linux
 // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
 // +build !android
+// +build !e2e_testing
 
 package nebula