
Rework some things into packages (#489)

Nate Brown 3 years ago
parent
commit
bcabcfdaca
73 changed files with 2527 additions and 2375 deletions
  1. allow_list.go (+228 -7)
  2. allow_list_test.go (+98 -9)
  3. bits_test.go (+5 -4)
  4. cert.go (+3 -2)
  5. cidr/parse.go (+10 -0)
  6. cidr/tree4.go (+20 -44)
  7. cidr/tree4_test.go (+153 -0)
  8. cidr/tree6.go (+20 -19)
  9. cidr/tree6_test.go (+25 -21)
  10. cidr_radix_test.go (+0 -157)
  11. cmd/nebula-service/main.go (+6 -5)
  12. cmd/nebula-service/service.go (+4 -3)
  13. cmd/nebula/main.go (+6 -5)
  14. config.go (+0 -611)
  15. config/config.go (+358 -0)
  16. config/config_test.go (+18 -103)
  17. connection_manager.go (+51 -49)
  18. connection_manager_test.go (+39 -36)
  19. control.go (+18 -15)
  20. control_test.go (+16 -14)
  21. control_tester.go (+15 -12)
  22. dns_server.go (+9 -7)
  23. e2e/handshakes_test.go (+8 -5)
  24. e2e/helpers_test.go (+11 -23)
  25. e2e/router/router.go (+12 -10)
  26. firewall.go (+45 -147)
  27. firewall/cache.go (+59 -0)
  28. firewall/packet.go (+62 -0)
  29. firewall_test.go (+101 -109)
  30. handshake.go (+5 -5)
  31. handshake_ix.go (+59 -56)
  32. handshake_manager.go (+36 -33)
  33. handshake_manager_test.go (+16 -12)
  34. header/header.go (+62 -66)
  35. header/header_test.go (+115 -0)
  36. header_test.go (+0 -119)
  37. hostmap.go (+64 -93)
  38. inside.go (+28 -23)
  39. interface.go (+26 -21)
  40. iputil/util.go (+66 -0)
  41. iputil/util_test.go (+17 -0)
  42. lighthouse.go (+82 -81)
  43. lighthouse_test.go (+72 -68)
  44. logger.go (+39 -0)
  45. main.go (+71 -68)
  46. message_metrics.go (+5 -2)
  47. outside.go (+61 -57)
  48. outside_test.go (+9 -7)
  49. punchy.go (+6 -2)
  50. punchy_test.go (+4 -2)
  51. remote_list.go (+31 -28)
  52. remote_list_test.go (+33 -32)
  53. ssh.go (+30 -26)
  54. stats.go (+4 -3)
  55. timeout.go (+7 -5)
  56. timeout_system.go (+4 -2)
  57. timeout_system_test.go (+4 -3)
  58. timeout_test.go (+4 -3)
  59. tun_common.go (+7 -5)
  60. tun_test.go (+6 -4)
  61. udp/conn.go (+20 -0)
  62. udp/temp.go (+14 -0)
  63. udp/udp_all.go (+15 -13)
  64. udp/udp_android.go (+2 -2)
  65. udp/udp_darwin.go (+2 -2)
  66. udp/udp_freebsd.go (+2 -2)
  67. udp/udp_generic.go (+20 -25)
  68. udp/udp_linux.go (+29 -33)
  69. udp/udp_linux_32.go (+3 -3)
  70. udp/udp_linux_64.go (+3 -3)
  71. udp/udp_tester.go (+39 -43)
  72. udp/udp_windows.go (+2 -2)
  73. util/main.go (+3 -4)

+ 228 - 7
allow_list.go

@@ -4,11 +4,15 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"regexp"
 	"regexp"
+
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type AllowList struct {
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *CIDR6Tree
+	cidrTree *cidr.Tree6
 }
 }
 
 
 type RemoteAllowList struct {
 type RemoteAllowList struct {
@@ -16,7 +20,7 @@ type RemoteAllowList struct {
 
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
 	// are *AllowList
-	insideAllowLists *CIDR6Tree
+	insideAllowLists *cidr.Tree6
 }
 }
 
 
 type LocalAllowList struct {
 type LocalAllowList struct {
@@ -31,6 +35,223 @@ type AllowListNameRule struct {
 	Allow bool
 	Allow bool
 }
 }
 
 
+func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
+	var nameRules []AllowListNameRule
+	handleKey := func(key string, value interface{}) (bool, error) {
+		if key == "interfaces" {
+			var err error
+			nameRules, err = getAllowListInterfaces(k, value)
+			if err != nil {
+				return false, err
+			}
+
+			return true, nil
+		}
+		return false, nil
+	}
+
+	al, err := newAllowListFromConfig(c, k, handleKey)
+	if err != nil {
+		return nil, err
+	}
+	return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
+}
+
+func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) {
+	al, err := newAllowListFromConfig(c, k, nil)
+	if err != nil {
+		return nil, err
+	}
+	remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey)
+	if err != nil {
+		return nil, err
+	}
+	return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
+}
+
+// If the handleKey func returns true, the rest of the parsing is skipped
+// for this key. This allows parsing of special values like `interfaces`.
+func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+	r := c.Get(k)
+	if r == nil {
+		return nil, nil
+	}
+
+	return newAllowList(k, r, handleKey)
+}
+
+// If the handleKey func returns true, the rest of the parsing is skipped
+// for this key. This allows parsing of special values like `interfaces`.
+func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+	rawMap, ok := raw.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
+	}
+
+	tree := cidr.NewTree6()
+
+	// Keep track of the rules we have added for both ipv4 and ipv6
+	type allowListRules struct {
+		firstValue     bool
+		allValuesMatch bool
+		defaultSet     bool
+		allValues      bool
+	}
+
+	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
+	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
+
+	for rawKey, rawValue := range rawMap {
+		rawCIDR, ok := rawKey.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+		}
+
+		if handleKey != nil {
+			handled, err := handleKey(rawCIDR, rawValue)
+			if err != nil {
+				return nil, err
+			}
+			if handled {
+				continue
+			}
+		}
+
+		value, ok := rawValue.(bool)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
+		}
+
+		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		if err != nil {
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+		}
+
+		// TODO: should we error on duplicate CIDRs in the config?
+		tree.AddCIDR(ipNet, value)
+
+		maskBits, maskSize := ipNet.Mask.Size()
+
+		var rules *allowListRules
+		if maskSize == 32 {
+			rules = &rules4
+		} else {
+			rules = &rules6
+		}
+
+		if rules.firstValue {
+			rules.allValues = value
+			rules.firstValue = false
+		} else {
+			if value != rules.allValues {
+				rules.allValuesMatch = false
+			}
+		}
+
+		// Check if this is 0.0.0.0/0 or ::/0
+		if maskBits == 0 {
+			rules.defaultSet = true
+		}
+	}
+
+	if !rules4.defaultSet {
+		if rules4.allValuesMatch {
+			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
+			tree.AddCIDR(zeroCIDR, !rules4.allValues)
+		} else {
+			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
+		}
+	}
+
+	if !rules6.defaultSet {
+		if rules6.allValuesMatch {
+			_, zeroCIDR, _ := net.ParseCIDR("::/0")
+			tree.AddCIDR(zeroCIDR, !rules6.allValues)
+		} else {
+			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
+		}
+	}
+
+	return &AllowList{cidrTree: tree}, nil
+}
+
+func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
+	var nameRules []AllowListNameRule
+
+	rawRules, ok := v.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
+	}
+
+	firstEntry := true
+	var allValues bool
+	for rawName, rawAllow := range rawRules {
+		name, ok := rawName.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
+		}
+		allow, ok := rawAllow.(bool)
+		if !ok {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
+		}
+
+		nameRE, err := regexp.Compile("^" + name + "$")
+		if err != nil {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
+		}
+
+		nameRules = append(nameRules, AllowListNameRule{
+			Name:  nameRE,
+			Allow: allow,
+		})
+
+		if firstEntry {
+			allValues = allow
+			firstEntry = false
+		} else {
+			if allow != allValues {
+				return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
+			}
+		}
+	}
+
+	return nameRules, nil
+}
+
+func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
+	value := c.Get(k)
+	if value == nil {
+		return nil, nil
+	}
+
+	remoteAllowRanges := cidr.NewTree6()
+
+	rawMap, ok := value.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
+	}
+	for rawKey, rawValue := range rawMap {
+		rawCIDR, ok := rawKey.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+		}
+
+		allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
+		if err != nil {
+			return nil, err
+		}
+
+		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		if err != nil {
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+		}
+
+		remoteAllowRanges.AddCIDR(ipNet, allowList)
+	}
+
+	return remoteAllowRanges, nil
+}
+
func (al *AllowList) Allow(ip net.IP) bool {
	if al == nil {
		return true
@@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool {
	}
}

-func (al *AllowList) AllowIpV4(ip uint32) bool {
+func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
	if al == nil {
		return true
	}
@@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
	return al.AllowList.Allow(ip)
}

-func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
	if !al.getInsideAllowList(vpnIp).Allow(ip) {
		return false
	}
	return al.AllowList.Allow(ip)
}

-func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
+func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
	if al == nil {
		return true
	}
@@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
	return al.AllowList.AllowIpV4(ip)
}

-func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
+func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
	if al == nil {
		return true
	}
@@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
	return al.AllowList.AllowIpV6(hi, lo)
}

-func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
	if al.insideAllowLists != nil {
		inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
		if inside != nil {

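A minimal usage sketch of the relocated constructor, driven the same way the test below drives it. The `allowlist` key, the in-memory settings, and the specific CIDR rules are illustrative assumptions, not part of this commit.

package main

import (
	"fmt"
	"net"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
	"github.com/slackhq/nebula/config"
)

func main() {
	c := config.NewC(logrus.New())

	// Mixed true/false rules must include an explicit default for 0.0.0.0/0
	// (and ::/0 once any IPv6 rule is present), per newAllowList above.
	c.Settings["allowlist"] = map[interface{}]interface{}{
		"0.0.0.0/0":     true,
		"10.0.0.0/8":    false,
		"10.42.42.0/24": true,
	}

	al, err := nebula.NewLocalAllowListFromConfig(c, "allowlist")
	if err != nil {
		logrus.Fatal(err)
	}

	// Most specific rule wins: the /24 re-allows part of the denied /8.
	fmt.Println(al.AllowList.Allow(net.ParseIP("10.42.42.42"))) // true
	fmt.Println(al.AllowList.Allow(net.ParseIP("10.1.2.3")))    // false
}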
+ 98 - 9
allow_list_test.go

@@ -5,21 +5,110 @@ import (
 	"regexp"
 	"regexp"
 	"testing"
 	"testing"
 
 
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
+func TestNewAllowListFromConfig(t *testing.T) {
+	l := util.NewTestLogger()
+	c := config.NewC(l)
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0": true,
+	}
+	r, err := newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+	assert.Nil(t, r)
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0/16": "abc",
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0/16": true,
+		"10.0.0.0/8":     false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":      true,
+		"10.0.0.0/8":     false,
+		"10.42.42.0/24":  true,
+		"fd00::/8":       true,
+		"fd00:fd00::/16": false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":     true,
+		"10.0.0.0/8":    false,
+		"10.42.42.0/24": true,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	if assert.NoError(t, err) {
+		assert.NotNil(t, r)
+	}
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":      true,
+		"10.0.0.0/8":     false,
+		"10.42.42.0/24":  true,
+		"::/0":           false,
+		"fd00::/8":       true,
+		"fd00:fd00::/16": false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	if assert.NoError(t, err) {
+		assert.NotNil(t, r)
+	}
+
+	// Test interface names
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: "foo",
+		},
+	}
+	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
+	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: false,
+			`eth.*`:    true,
+		},
+	}
+	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
+	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: false,
+		},
+	}
+	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
+	if assert.NoError(t, err) {
+		assert.NotNil(t, lr)
+	}
+}
+
func TestAllowList_Allow(t *testing.T) {
	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))

-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
-	tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
-	tree.AddCIDR(getCIDR("10.42.42.42/32"), true)
-	tree.AddCIDR(getCIDR("10.42.0.0/16"), true)
-	tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
-	tree.AddCIDR(getCIDR("10.42.42.0/24"), false)
-	tree.AddCIDR(getCIDR("::1/128"), true)
-	tree.AddCIDR(getCIDR("::2/128"), false)
+	tree := cidr.NewTree6()
+	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
+	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
+	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
+	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
+	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
+	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
+	tree.AddCIDR(cidr.Parse("::1/128"), true)
+	tree.AddCIDR(cidr.Parse("::2/128"), false)
 	al := &AllowList{cidrTree: tree}
 	al := &AllowList{cidrTree: tree}
 
 
 	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
 	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))

+ 5 - 4
bits_test.go

@@ -3,11 +3,12 @@ package nebula
import (
	"testing"

+	"github.com/slackhq/nebula/util"
	"github.com/stretchr/testify/assert"
)

func TestBits(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
	b := NewBits(10)

	// make sure it is the right size
@@ -75,7 +76,7 @@ func TestBits(t *testing.T) {
}

func TestBitsDupeCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
@@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
}

func TestBitsOutOfWindowCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
@@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
}

func TestBitsLostCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()

+ 3 - 2
cert.go

@@ -9,6 +9,7 @@ import (

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
)

type CertState struct {
@@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
	return cs, nil
}

-func NewCertStateFromConfig(c *Config) (*CertState, error) {
+func NewCertStateFromConfig(c *config.C) (*CertState, error) {
	var pemPrivateKey []byte
	var err error

@@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
	return NewCertState(nebulaCert, rawKey)
}

-func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
+func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
	var rawCA []byte
	var err error


+ 10 - 0
cidr/parse.go

@@ -0,0 +1,10 @@
+package cidr
+
+import "net"
+
+// Parse is a convenience function that returns only the IPNet
+// This function ignores errors since it is primarily a test helper, the result could be nil
+func Parse(s string) *net.IPNet {
+	_, c, _ := net.ParseCIDR(s)
+	return c
+}

+ 20 - 44
cidr_radix.go → cidr/tree4.go

@@ -1,39 +1,39 @@
-package nebula
+package cidr

import (
-	"encoding/binary"
-	"fmt"
	"net"
+
+	"github.com/slackhq/nebula/iputil"
)

-type CIDRNode struct {
-	left   *CIDRNode
-	right  *CIDRNode
-	parent *CIDRNode
+type Node struct {
+	left   *Node
+	right  *Node
+	parent *Node
	value  interface{}
}

-type CIDRTree struct {
-	root *CIDRNode
+type Tree4 struct {
+	root *Node
}

const (
-	startbit = uint32(0x80000000)
+	startbit = iputil.VpnIp(0x80000000)
)

-func NewCIDRTree() *CIDRTree {
-	tree := new(CIDRTree)
-	tree.root = &CIDRNode{}
+func NewTree4() *Tree4 {
+	tree := new(Tree4)
+	tree.root = &Node{}
	return tree
}

-func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
+func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
	bit := startbit
	node := tree.root
	next := tree.root

-	ip := ip2int(cidr.IP)
-	mask := ip2int(cidr.Mask)
+	ip := iputil.Ip2VpnIp(cidr.IP)
+	mask := iputil.Ip2VpnIp(cidr.Mask)

	// Find our last ancestor in the tree
	for bit&mask != 0 {
@@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {

	// Build up the rest of the tree we don't already have
	for bit&mask != 0 {
-		next = &CIDRNode{}
+		next = &Node{}
		next.parent = node

		if ip&bit != 0 {
@@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
}

// Finds the first match, which may be the least specific
-func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
+func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
	bit := startbit
	node := tree.root

@@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
}

// Finds the most specific match
-func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
+func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
	bit := startbit
	node := tree.root

@@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
}

// Finds the most specific match
-func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
+func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
	bit := startbit
	node := tree.root
	lastNode := node
@@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
	}
	return value
}
-
-// A helper type to avoid converting to IP when logging
-type IntIp uint32
-
-func (ip IntIp) String() string {
-	return fmt.Sprintf("%v", int2ip(uint32(ip)))
-}
-
-func (ip IntIp) MarshalJSON() ([]byte, error) {
-	return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
-}
-
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
-}

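A short sketch of the renamed Tree4 API showing the difference between Contains (first value found on the walk down, which may be the least specific entry) and MostSpecificContains (deepest match), as described by the comments above. The package paths come from this commit; the CIDRs and labels are illustrative assumptions.

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/cidr"
	"github.com/slackhq/nebula/iputil"
)

func main() {
	tree := cidr.NewTree4()
	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), "outer")
	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), "inner")

	ip := iputil.Ip2VpnIp(net.ParseIP("10.42.1.1"))

	// Contains returns the first value found while walking toward the leaf.
	fmt.Println(tree.Contains(ip)) // "outer"

	// MostSpecificContains keeps walking and reports the deepest match.
	fmt.Println(tree.MostSpecificContains(ip)) // "inner"
}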
+ 153 - 0
cidr/tree4_test.go

@@ -0,0 +1,153 @@
+package cidr
+
+import (
+	"net"
+	"testing"
+
+	"github.com/slackhq/nebula/iputil"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCIDRTree_Contains(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
+	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1", "1.0.0.0"},
+		{"1", "1.255.255.255"},
+		{"2", "2.1.0.0"},
+		{"2", "2.1.255.255"},
+		{"3", "3.1.1.0"},
+		{"3", "3.1.1.255"},
+		{"4a", "4.1.1.255"},
+		{"4a", "4.1.1.1"},
+		{"5", "240.0.0.0"},
+		{"5", "255.255.255.255"},
+		{nil, "239.0.0.0"},
+		{nil, "4.1.2.2"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func TestCIDRTree_MostSpecificContains(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1", "1.0.0.0"},
+		{"1", "1.255.255.255"},
+		{"2", "2.1.0.0"},
+		{"2", "2.1.255.255"},
+		{"3", "3.1.1.0"},
+		{"3", "3.1.1.255"},
+		{"4a", "4.1.1.255"},
+		{"4b", "4.1.1.2"},
+		{"4c", "4.1.1.1"},
+		{"5", "240.0.0.0"},
+		{"5", "255.255.255.255"},
+		{nil, "239.0.0.0"},
+		{nil, "4.1.2.2"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func TestCIDRTree_Match(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1a", "4.1.1.0"},
+		{"1b", "4.1.1.1"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func BenchmarkCIDRTree_Contains(b *testing.B) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
+
+	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
+	b.Run("found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Contains(ip)
+		}
+	})
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
+	b.Run("not found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Contains(ip)
+		}
+	})
+}
+
+func BenchmarkCIDRTree_Match(b *testing.B) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
+
+	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
+	b.Run("found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Match(ip)
+		}
+	})
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
+	b.Run("not found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Match(ip)
+		}
+	})
+}

+ 20 - 19
cidr6_radix.go → cidr/tree6.go

@@ -1,26 +1,27 @@
-package nebula
+package cidr

import (
-	"encoding/binary"
	"net"
+
+	"github.com/slackhq/nebula/iputil"
)

const startbit6 = uint64(1 << 63)

-type CIDR6Tree struct {
-	root4 *CIDRNode
-	root6 *CIDRNode
+type Tree6 struct {
+	root4 *Node
+	root6 *Node
}

-func NewCIDR6Tree() *CIDR6Tree {
-	tree := new(CIDR6Tree)
-	tree.root4 = &CIDRNode{}
-	tree.root6 = &CIDRNode{}
+func NewTree6() *Tree6 {
+	tree := new(Tree6)
+	tree.root4 = &Node{}
+	tree.root6 = &Node{}
	return tree
}

-func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
-	var node, next *CIDRNode
+func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
+	var node, next *Node

	cidrIP, ipv4 := isIPV4(cidr.IP)
	if ipv4 {
@@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
	}

	for i := 0; i < len(cidrIP); i += 4 {
-		ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
-		mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
+		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
+		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
		bit := startbit

		// Find our last ancestor in the tree
@@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {

		// Build up the rest of the tree we don't already have
		for bit&mask != 0 {
-			next = &CIDRNode{}
+			next = &Node{}
			next.parent = node

			if ip&bit != 0 {
@@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
}

// Finds the most specific match
-func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
-	var node *CIDRNode
+func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
+	var node *Node

	wholeIP, ipv4 := isIPV4(ip)
	if ipv4 {
@@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
	}

	for i := 0; i < len(wholeIP); i += 4 {
-		ip := ip2int(wholeIP[i : i+4])
+		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
		bit := startbit

		for node != nil {
@@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
	return value
}

-func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
+func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
	bit := startbit
	node := tree.root4

@@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
	return value
}

-func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
+func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
	ip := hi
	node := tree.root6


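A similar sketch for Tree6, which keeps separate roots for IPv4 and IPv6 and routes each lookup to the matching address family; the CIDRs and labels here are illustrative assumptions.

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/cidr"
)

func main() {
	tree := cidr.NewTree6()
	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), "v4 rule")
	tree.AddCIDR(cidr.Parse("fd00::/8"), "v6 rule")

	// MostSpecificContains takes a net.IP and picks root4 or root6 based on
	// the address family of the query.
	fmt.Println(tree.MostSpecificContains(net.ParseIP("10.1.2.3")))   // "v4 rule"
	fmt.Println(tree.MostSpecificContains(net.ParseIP("fd00::beef"))) // "v6 rule"
}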
+ 25 - 21
cidr6_radix_test.go → cidr/tree6_test.go

@@ -1,6 +1,7 @@
-package nebula
+package cidr

import (
+	"encoding/binary"
	"net"
	"testing"

@@ -8,17 +9,17 @@ import (
)

func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
+	tree := NewTree6()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")

	tests := []struct {
		Result interface{}
@@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
		assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
	}

-	tree = NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	tree.AddCIDR(getCIDR("::/0"), "cool6")
+	tree = NewTree6()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	tree.AddCIDR(Parse("::/0"), "cool6")
	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
@@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
}

func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
+	tree := NewTree6()
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")

	tests := []struct {
		Result interface{}
@@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
	}

	for _, tt := range tests {
-		ip := NewIp6AndPort(net.ParseIP(tt.IP), 0)
-		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo))
+		ip := net.ParseIP(tt.IP)
+		hi := binary.BigEndian.Uint64(ip[:8])
+		lo := binary.BigEndian.Uint64(ip[8:])
+
+		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
	}
}

+ 0 - 157
cidr_radix_test.go

@@ -1,157 +0,0 @@
-package nebula
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
-	tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4a", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func TestCIDRTree_Match(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1a", "4.1.1.0"},
-		{"1b", "4.1.1.1"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
-	tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
-
-	ip := ip2int(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-
-	ip = ip2int(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-}
-
-func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
-	tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
-
-	ip := ip2int(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-
-	ip = ip2int(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-}
-
-func getCIDR(s string) *net.IPNet {
-	_, c, _ := net.ParseCIDR(s)
-	return c
-}

+ 6 - 5
cmd/nebula-service/main.go

@@ -7,6 +7,7 @@ import (

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
)

// A version string that can be set with
@@ -49,14 +50,14 @@ func main() {
	l := logrus.New()
	l.Out = os.Stdout

-	config := nebula.NewConfig(l)
-	err := config.Load(*configPath)
+	c := config.NewC(l)
+	err := c.Load(*configPath)
	if err != nil {
		fmt.Printf("failed to load config: %s", err)
		os.Exit(1)
	}

-	c, err := nebula.Main(config, *configTest, Build, l, nil)
+	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

	switch v := err.(type) {
	case nebula.ContextualError:
@@ -68,8 +69,8 @@ func main() {
	}

	if !*configTest {
-		c.Start()
-		c.ShutdownBlock()
+		ctrl.Start()
+		ctrl.ShutdownBlock()
	}

	os.Exit(0)

+ 4 - 3
cmd/nebula-service/service.go

@@ -9,6 +9,7 @@ import (
 	"github.com/kardianos/service"
 	"github.com/kardianos/service"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
 )
 )
 
 
 var logger service.Logger
 var logger service.Logger
@@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error {
 	l := logrus.New()
 	l := logrus.New()
 	HookLogger(l)
 	HookLogger(l)
 
 
-	config := nebula.NewConfig(l)
-	err := config.Load(*p.configPath)
+	c := config.NewC(l)
+	err := c.Load(*p.configPath)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to load config: %s", err)
 		return fmt.Errorf("failed to load config: %s", err)
 	}
 	}
 
 
-	p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
+	p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 6 - 5
cmd/nebula/main.go

@@ -7,6 +7,7 @@ import (

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
)

// A version string that can be set with
@@ -43,14 +44,14 @@ func main() {
	l := logrus.New()
	l.Out = os.Stdout

-	config := nebula.NewConfig(l)
-	err := config.Load(*configPath)
+	c := config.NewC(l)
+	err := c.Load(*configPath)
	if err != nil {
		fmt.Printf("failed to load config: %s", err)
		os.Exit(1)
	}

-	c, err := nebula.Main(config, *configTest, Build, l, nil)
+	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)

	switch v := err.(type) {
	case nebula.ContextualError:
@@ -62,8 +63,8 @@ func main() {
	}

	if !*configTest {
-		c.Start()
-		c.ShutdownBlock()
+		ctrl.Start()
+		ctrl.ShutdownBlock()
	}

	os.Exit(0)

+ 0 - 611
config.go

@@ -1,611 +0,0 @@
-package nebula
-
-import (
-	"context"
-	"errors"
-	"fmt"
-	"io/ioutil"
-	"net"
-	"os"
-	"os/signal"
-	"path/filepath"
-	"regexp"
-	"sort"
-	"strconv"
-	"strings"
-	"syscall"
-	"time"
-
-	"github.com/imdario/mergo"
-	"github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v2"
-)
-
-type Config struct {
-	path        string
-	files       []string
-	Settings    map[interface{}]interface{}
-	oldSettings map[interface{}]interface{}
-	callbacks   []func(*Config)
-	l           *logrus.Logger
-}
-
-func NewConfig(l *logrus.Logger) *Config {
-	return &Config{
-		Settings: make(map[interface{}]interface{}),
-		l:        l,
-	}
-}
-
-// Load will find all yaml files within path and load them in lexical order
-func (c *Config) Load(path string) error {
-	c.path = path
-	c.files = make([]string, 0)
-
-	err := c.resolve(path, true)
-	if err != nil {
-		return err
-	}
-
-	if len(c.files) == 0 {
-		return fmt.Errorf("no config files found at %s", path)
-	}
-
-	sort.Strings(c.files)
-
-	err = c.parse()
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (c *Config) LoadString(raw string) error {
-	if raw == "" {
-		return errors.New("Empty configuration")
-	}
-	return c.parseRaw([]byte(raw))
-}
-
-// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
-// here should decide if they need to make a change to the current process before making the change. HasChanged can be
-// used to help decide if a change is necessary.
-// These functions should return quickly or spawn their own go routine if they will take a while
-func (c *Config) RegisterReloadCallback(f func(*Config)) {
-	c.callbacks = append(c.callbacks, f)
-}
-
-// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
-// k in both the old and new settings will be serialized, the result of the string comparison is returned.
-// If k is an empty string the entire config is tested.
-// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
-// there is change when there actually wasn't any.
-func (c *Config) HasChanged(k string) bool {
-	if c.oldSettings == nil {
-		return false
-	}
-
-	var (
-		nv interface{}
-		ov interface{}
-	)
-
-	if k == "" {
-		nv = c.Settings
-		ov = c.oldSettings
-		k = "all settings"
-	} else {
-		nv = c.get(k, c.Settings)
-		ov = c.get(k, c.oldSettings)
-	}
-
-	newVals, err := yaml.Marshal(nv)
-	if err != nil {
-		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
-	}
-
-	oldVals, err := yaml.Marshal(ov)
-	if err != nil {
-		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
-	}
-
-	return string(newVals) != string(oldVals)
-}
-
-// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
-// original path provided to Load. The old settings are shallow copied for change detection after the reload.
-func (c *Config) CatchHUP(ctx context.Context) {
-	ch := make(chan os.Signal, 1)
-	signal.Notify(ch, syscall.SIGHUP)
-
-	go func() {
-		for {
-			select {
-			case <-ctx.Done():
-				signal.Stop(ch)
-				close(ch)
-				return
-			case <-ch:
-				c.l.Info("Caught HUP, reloading config")
-				c.ReloadConfig()
-			}
-		}
-	}()
-}
-
-func (c *Config) ReloadConfig() {
-	c.oldSettings = make(map[interface{}]interface{})
-	for k, v := range c.Settings {
-		c.oldSettings[k] = v
-	}
-
-	err := c.Load(c.path)
-	if err != nil {
-		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
-		return
-	}
-
-	for _, v := range c.callbacks {
-		v(c)
-	}
-}
-
-// GetString will get the string for k or return the default d if not found or invalid
-func (c *Config) GetString(k, d string) string {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	return fmt.Sprintf("%v", r)
-}
-
-// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
-func (c *Config) GetStringSlice(k string, d []string) []string {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	rv, ok := r.([]interface{})
-	if !ok {
-		return d
-	}
-
-	v := make([]string, len(rv))
-	for i := 0; i < len(v); i++ {
-		v[i] = fmt.Sprintf("%v", rv[i])
-	}
-
-	return v
-}
-
-// GetMap will get the map for k or return the default d if not found or invalid
-func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	v, ok := r.(map[interface{}]interface{})
-	if !ok {
-		return d
-	}
-
-	return v
-}
-
-// GetInt will get the int for k or return the default d if not found or invalid
-func (c *Config) GetInt(k string, d int) int {
-	r := c.GetString(k, strconv.Itoa(d))
-	v, err := strconv.Atoi(r)
-	if err != nil {
-		return d
-	}
-
-	return v
-}
-
-// GetBool will get the bool for k or return the default d if not found or invalid
-func (c *Config) GetBool(k string, d bool) bool {
-	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
-	v, err := strconv.ParseBool(r)
-	if err != nil {
-		switch r {
-		case "y", "yes":
-			return true
-		case "n", "no":
-			return false
-		}
-		return d
-	}
-
-	return v
-}
-
-// GetDuration will get the duration for k or return the default d if not found or invalid
-func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
-	r := c.GetString(k, "")
-	v, err := time.ParseDuration(r)
-	if err != nil {
-		return d
-	}
-	return v
-}
-
-func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
-	var nameRules []AllowListNameRule
-	handleKey := func(key string, value interface{}) (bool, error) {
-		if key == "interfaces" {
-			var err error
-			nameRules, err = c.getAllowListInterfaces(k, value)
-			if err != nil {
-				return false, err
-			}
-
-			return true, nil
-		}
-		return false, nil
-	}
-
-	al, err := c.GetAllowList(k, handleKey)
-	if err != nil {
-		return nil, err
-	}
-	return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
-}
-
-func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
-	al, err := c.GetAllowList(k, nil)
-	if err != nil {
-		return nil, err
-	}
-	remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
-	if err != nil {
-		return nil, err
-	}
-	return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
-}
-
-func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
-	value := c.Get(k)
-	if value == nil {
-		return nil, nil
-	}
-
-	remoteAllowRanges := NewCIDR6Tree()
-
-	rawMap, ok := value.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
-	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
-		allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
-		if err != nil {
-			return nil, err
-		}
-
-		_, cidr, err := net.ParseCIDR(rawCIDR)
-		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
-		}
-
-		remoteAllowRanges.AddCIDR(cidr, allowList)
-	}
-
-	return remoteAllowRanges, nil
-}
-
-// If the handleKey func returns true, the rest of the parsing is skipped
-// for this key. This allows parsing of special values like `interfaces`.
-func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	r := c.Get(k)
-	if r == nil {
-		return nil, nil
-	}
-
-	return c.getAllowList(k, r, handleKey)
-}
-
-// If the handleKey func returns true, the rest of the parsing is skipped
-// for this key. This allows parsing of special values like `interfaces`.
-func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	rawMap, ok := raw.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
-	}
-
-	tree := NewCIDR6Tree()
-
-	// Keep track of the rules we have added for both ipv4 and ipv6
-	type allowListRules struct {
-		firstValue     bool
-		allValuesMatch bool
-		defaultSet     bool
-		allValues      bool
-	}
-	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
-	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
-
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
-		if handleKey != nil {
-			handled, err := handleKey(rawCIDR, rawValue)
-			if err != nil {
-				return nil, err
-			}
-			if handled {
-				continue
-			}
-		}
-
-		value, ok := rawValue.(bool)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
-		}
-
-		_, cidr, err := net.ParseCIDR(rawCIDR)
-		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
-		}
-
-		// TODO: should we error on duplicate CIDRs in the config?
-		tree.AddCIDR(cidr, value)
-
-		maskBits, maskSize := cidr.Mask.Size()
-
-		var rules *allowListRules
-		if maskSize == 32 {
-			rules = &rules4
-		} else {
-			rules = &rules6
-		}
-
-		if rules.firstValue {
-			rules.allValues = value
-			rules.firstValue = false
-		} else {
-			if value != rules.allValues {
-				rules.allValuesMatch = false
-			}
-		}
-
-		// Check if this is 0.0.0.0/0 or ::/0
-		if maskBits == 0 {
-			rules.defaultSet = true
-		}
-	}
-
-	if !rules4.defaultSet {
-		if rules4.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
-			tree.AddCIDR(zeroCIDR, !rules4.allValues)
-		} else {
-			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
-		}
-	}
-
-	if !rules6.defaultSet {
-		if rules6.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("::/0")
-			tree.AddCIDR(zeroCIDR, !rules6.allValues)
-		} else {
-			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
-		}
-	}
-
-	return &AllowList{cidrTree: tree}, nil
-}
-
-func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
-	var nameRules []AllowListNameRule
-
-	rawRules, ok := v.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
-	}
-
-	firstEntry := true
-	var allValues bool
-	for rawName, rawAllow := range rawRules {
-		name, ok := rawName.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
-		}
-		allow, ok := rawAllow.(bool)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
-		}
-
-		nameRE, err := regexp.Compile("^" + name + "$")
-		if err != nil {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
-		}
-
-		nameRules = append(nameRules, AllowListNameRule{
-			Name:  nameRE,
-			Allow: allow,
-		})
-
-		if firstEntry {
-			allValues = allow
-			firstEntry = false
-		} else {
-			if allow != allValues {
-				return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
-			}
-		}
-	}
-
-	return nameRules, nil
-}
-
-func (c *Config) Get(k string) interface{} {
-	return c.get(k, c.Settings)
-}
-
-func (c *Config) IsSet(k string) bool {
-	return c.get(k, c.Settings) != nil
-}
-
-func (c *Config) get(k string, v interface{}) interface{} {
-	parts := strings.Split(k, ".")
-	for _, p := range parts {
-		m, ok := v.(map[interface{}]interface{})
-		if !ok {
-			return nil
-		}
-
-		v, ok = m[p]
-		if !ok {
-			return nil
-		}
-	}
-
-	return v
-}
-
-// direct signifies if this is the config path directly specified by the user,
-// versus a file/dir found by recursing into that path
-func (c *Config) resolve(path string, direct bool) error {
-	i, err := os.Stat(path)
-	if err != nil {
-		return nil
-	}
-
-	if !i.IsDir() {
-		c.addFile(path, direct)
-		return nil
-	}
-
-	paths, err := readDirNames(path)
-	if err != nil {
-		return fmt.Errorf("problem while reading directory %s: %s", path, err)
-	}
-
-	for _, p := range paths {
-		err := c.resolve(filepath.Join(path, p), false)
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-func (c *Config) addFile(path string, direct bool) error {
-	ext := filepath.Ext(path)
-
-	if !direct && ext != ".yaml" && ext != ".yml" {
-		return nil
-	}
-
-	ap, err := filepath.Abs(path)
-	if err != nil {
-		return err
-	}
-
-	c.files = append(c.files, ap)
-	return nil
-}
-
-func (c *Config) parseRaw(b []byte) error {
-	var m map[interface{}]interface{}
-
-	err := yaml.Unmarshal(b, &m)
-	if err != nil {
-		return err
-	}
-
-	c.Settings = m
-	return nil
-}
-
-func (c *Config) parse() error {
-	var m map[interface{}]interface{}
-
-	for _, path := range c.files {
-		b, err := ioutil.ReadFile(path)
-		if err != nil {
-			return err
-		}
-
-		var nm map[interface{}]interface{}
-		err = yaml.Unmarshal(b, &nm)
-		if err != nil {
-			return err
-		}
-
-		// We need to use WithAppendSlice so that firewall rules in separate
-		// files are appended together
-		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
-		m = nm
-		if err != nil {
-			return err
-		}
-	}
-
-	c.Settings = m
-	return nil
-}
-
-func readDirNames(path string) ([]string, error) {
-	f, err := os.Open(path)
-	if err != nil {
-		return nil, err
-	}
-
-	paths, err := f.Readdirnames(-1)
-	f.Close()
-	if err != nil {
-		return nil, err
-	}
-
-	sort.Strings(paths)
-	return paths, nil
-}
-
-func configLogger(c *Config) error {
-	// set up our logging level
-	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
-	if err != nil {
-		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
-	}
-	c.l.SetLevel(logLevel)
-
-	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
-	timestampFormat := c.GetString("logging.timestamp_format", "")
-	fullTimestamp := (timestampFormat != "")
-	if timestampFormat == "" {
-		timestampFormat = time.RFC3339
-	}
-
-	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
-	switch logFormat {
-	case "text":
-		c.l.Formatter = &logrus.TextFormatter{
-			TimestampFormat:  timestampFormat,
-			FullTimestamp:    fullTimestamp,
-			DisableTimestamp: disableTimestamp,
-		}
-	case "json":
-		c.l.Formatter = &logrus.JSONFormatter{
-			TimestampFormat:  timestampFormat,
-			DisableTimestamp: disableTimestamp,
-		}
-	default:
-		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
-	}
-
-	return nil
-}

+ 358 - 0
config/config.go

@@ -0,0 +1,358 @@
+package config
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/imdario/mergo"
+	"github.com/sirupsen/logrus"
+	"gopkg.in/yaml.v2"
+)
+
+type C struct {
+	path        string
+	files       []string
+	Settings    map[interface{}]interface{}
+	oldSettings map[interface{}]interface{}
+	callbacks   []func(*C)
+	l           *logrus.Logger
+}
+
+func NewC(l *logrus.Logger) *C {
+	return &C{
+		Settings: make(map[interface{}]interface{}),
+		l:        l,
+	}
+}
+
+// Load will find all yaml files within path and load them in lexical order
+func (c *C) Load(path string) error {
+	c.path = path
+	c.files = make([]string, 0)
+
+	err := c.resolve(path, true)
+	if err != nil {
+		return err
+	}
+
+	if len(c.files) == 0 {
+		return fmt.Errorf("no config files found at %s", path)
+	}
+
+	sort.Strings(c.files)
+
+	err = c.parse()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *C) LoadString(raw string) error {
+	if raw == "" {
+		return errors.New("Empty configuration")
+	}
+	return c.parseRaw([]byte(raw))
+}
+
+// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
+// here should decide if they need to make a change to the current process before making the change. HasChanged can be
+// used to help decide if a change is necessary.
+// These functions should return quickly or spawn their own goroutine if they will take a while
+func (c *C) RegisterReloadCallback(f func(*C)) {
+	c.callbacks = append(c.callbacks, f)
+}
+
+// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
+// k in both the old and new settings will be serialized, the result of the string comparison is returned.
+// If k is an empty string the entire config is tested.
+// It's important to note that this is very rudimentary and susceptible to configuration ordering issues, which can
+// indicate a change when there actually wasn't one.
+func (c *C) HasChanged(k string) bool {
+	if c.oldSettings == nil {
+		return false
+	}
+
+	var (
+		nv interface{}
+		ov interface{}
+	)
+
+	if k == "" {
+		nv = c.Settings
+		ov = c.oldSettings
+		k = "all settings"
+	} else {
+		nv = c.get(k, c.Settings)
+		ov = c.get(k, c.oldSettings)
+	}
+
+	newVals, err := yaml.Marshal(nv)
+	if err != nil {
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
+	}
+
+	oldVals, err := yaml.Marshal(ov)
+	if err != nil {
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
+	}
+
+	return string(newVals) != string(oldVals)
+}
+
+// CatchHUP will listen for the HUP signal in a goroutine and reload all configs found in the
+// original path provided to Load. The old settings are shallow copied for change detection after the reload.
+func (c *C) CatchHUP(ctx context.Context) {
+	ch := make(chan os.Signal, 1)
+	signal.Notify(ch, syscall.SIGHUP)
+
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				signal.Stop(ch)
+				close(ch)
+				return
+			case <-ch:
+				c.l.Info("Caught HUP, reloading config")
+				c.ReloadConfig()
+			}
+		}
+	}()
+}
+
+func (c *C) ReloadConfig() {
+	c.oldSettings = make(map[interface{}]interface{})
+	for k, v := range c.Settings {
+		c.oldSettings[k] = v
+	}
+
+	err := c.Load(c.path)
+	if err != nil {
+		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
+		return
+	}
+
+	for _, v := range c.callbacks {
+		v(c)
+	}
+}
+
+// GetString will get the string for k or return the default d if not found or invalid
+func (c *C) GetString(k, d string) string {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	return fmt.Sprintf("%v", r)
+}
+
+// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
+func (c *C) GetStringSlice(k string, d []string) []string {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	rv, ok := r.([]interface{})
+	if !ok {
+		return d
+	}
+
+	v := make([]string, len(rv))
+	for i := 0; i < len(v); i++ {
+		v[i] = fmt.Sprintf("%v", rv[i])
+	}
+
+	return v
+}
+
+// GetMap will get the map for k or return the default d if not found or invalid
+func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	v, ok := r.(map[interface{}]interface{})
+	if !ok {
+		return d
+	}
+
+	return v
+}
+
+// GetInt will get the int for k or return the default d if not found or invalid
+func (c *C) GetInt(k string, d int) int {
+	r := c.GetString(k, strconv.Itoa(d))
+	v, err := strconv.Atoi(r)
+	if err != nil {
+		return d
+	}
+
+	return v
+}
+
+// GetBool will get the bool for k or return the default d if not found or invalid
+func (c *C) GetBool(k string, d bool) bool {
+	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
+	v, err := strconv.ParseBool(r)
+	if err != nil {
+		switch r {
+		case "y", "yes":
+			return true
+		case "n", "no":
+			return false
+		}
+		return d
+	}
+
+	return v
+}
+
+// GetDuration will get the duration for k or return the default d if not found or invalid
+func (c *C) GetDuration(k string, d time.Duration) time.Duration {
+	r := c.GetString(k, "")
+	v, err := time.ParseDuration(r)
+	if err != nil {
+		return d
+	}
+	return v
+}
+
+func (c *C) Get(k string) interface{} {
+	return c.get(k, c.Settings)
+}
+
+func (c *C) IsSet(k string) bool {
+	return c.get(k, c.Settings) != nil
+}
+
+func (c *C) get(k string, v interface{}) interface{} {
+	parts := strings.Split(k, ".")
+	for _, p := range parts {
+		m, ok := v.(map[interface{}]interface{})
+		if !ok {
+			return nil
+		}
+
+		v, ok = m[p]
+		if !ok {
+			return nil
+		}
+	}
+
+	return v
+}
+
+// direct signifies if this is the config path directly specified by the user,
+// versus a file/dir found by recursing into that path
+func (c *C) resolve(path string, direct bool) error {
+	i, err := os.Stat(path)
+	if err != nil {
+		return nil
+	}
+
+	if !i.IsDir() {
+		c.addFile(path, direct)
+		return nil
+	}
+
+	paths, err := readDirNames(path)
+	if err != nil {
+		return fmt.Errorf("problem while reading directory %s: %s", path, err)
+	}
+
+	for _, p := range paths {
+		err := c.resolve(filepath.Join(path, p), false)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (c *C) addFile(path string, direct bool) error {
+	ext := filepath.Ext(path)
+
+	if !direct && ext != ".yaml" && ext != ".yml" {
+		return nil
+	}
+
+	ap, err := filepath.Abs(path)
+	if err != nil {
+		return err
+	}
+
+	c.files = append(c.files, ap)
+	return nil
+}
+
+func (c *C) parseRaw(b []byte) error {
+	var m map[interface{}]interface{}
+
+	err := yaml.Unmarshal(b, &m)
+	if err != nil {
+		return err
+	}
+
+	c.Settings = m
+	return nil
+}
+
+func (c *C) parse() error {
+	var m map[interface{}]interface{}
+
+	for _, path := range c.files {
+		b, err := ioutil.ReadFile(path)
+		if err != nil {
+			return err
+		}
+
+		var nm map[interface{}]interface{}
+		err = yaml.Unmarshal(b, &nm)
+		if err != nil {
+			return err
+		}
+
+		// We need to use WithAppendSlice so that firewall rules in separate
+		// files are appended together
+		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
+		m = nm
+		if err != nil {
+			return err
+		}
+	}
+
+	c.Settings = m
+	return nil
+}
+
+func readDirNames(path string) ([]string, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+
+	paths, err := f.Readdirnames(-1)
+	f.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	sort.Strings(paths)
+	return paths, nil
+}
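
A short usage sketch of the new config.C type added above; the YAML content and key names below are illustrative only. Dotted keys walk the nested Settings maps, every getter falls back to its default, and GetBool also accepts the y/yes/n/no forms:

package main

import (
	"fmt"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
)

func main() {
	l := logrus.New()
	c := config.NewC(l)

	// LoadString parses a raw YAML document directly into c.Settings.
	err := c.LoadString("listen:\n  port: 4242\npki:\n  disconnect_invalid: \"yes\"\ntimers:\n  connection_alive_interval: 10s\n")
	if err != nil {
		l.WithError(err).Fatal("failed to load config")
	}

	fmt.Println(c.GetInt("listen.port", 0))                                     // 4242
	fmt.Println(c.GetBool("pki.disconnect_invalid", false))                     // true, via the yes/no handling
	fmt.Println(c.GetDuration("timers.connection_alive_interval", time.Minute)) // 10s
	fmt.Println(c.IsSet("lighthouse.hosts"))                                    // false

	// With a directory-based Load, RegisterReloadCallback plus CatchHUP(ctx)
	// wires up live reload on SIGHUP; HasChanged can then be used inside the
	// callback to decide whether anything actually needs to be applied.
	c.RegisterReloadCallback(func(c *config.C) {
		if c.HasChanged("listen.port") {
			l.Info("listen.port changed")
		}
	})
}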

+ 18 - 103
config_test.go → config/config_test.go

@@ -1,4 +1,4 @@
-package nebula
+package config
 
 
 import (
 import (
 	"io/ioutil"
 	"io/ioutil"
@@ -7,19 +7,20 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestConfig_Load(t *testing.T) {
 func TestConfig_Load(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	dir, err := ioutil.TempDir("", "config-test")
 	dir, err := ioutil.TempDir("", "config-test")
 	// invalid yaml
 	// invalid yaml
-	c := NewConfig(l)
+	c := NewC(l)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 
 	// simple multi config merge
 	// simple multi config merge
-	c = NewConfig(l)
+	c = NewC(l)
 	os.RemoveAll(dir)
 	os.RemoveAll(dir)
 	os.Mkdir(dir, 0755)
 	os.Mkdir(dir, 0755)
 
 
@@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) {
 }
 }
 
 
 func TestConfig_Get(t *testing.T) {
 func TestConfig_Get(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// test simple type
 	// test simple type
-	c := NewConfig(l)
+	c := NewC(l)
 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 
 
@@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) {
 }
 }
 
 
 func TestConfig_GetStringSlice(t *testing.T) {
 func TestConfig_GetStringSlice(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := NewC(l)
 	c.Settings["slice"] = []interface{}{"one", "two"}
 	c.Settings["slice"] = []interface{}{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 }
 
 
 func TestConfig_GetBool(t *testing.T) {
 func TestConfig_GetBool(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := NewC(l)
 	c.Settings["bool"] = true
 	c.Settings["bool"] = true
 	assert.Equal(t, true, c.GetBool("bool", false))
 	assert.Equal(t, true, c.GetBool("bool", false))
 
 
@@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) {
 	assert.Equal(t, false, c.GetBool("bool", true))
 	assert.Equal(t, false, c.GetBool("bool", true))
 }
 }
 
 
-func TestConfig_GetAllowList(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0": true,
-	}
-	r, err := c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
-	assert.Nil(t, r)
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0/16": "abc",
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0/16": true,
-		"10.0.0.0/8":     false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":      true,
-		"10.0.0.0/8":     false,
-		"10.42.42.0/24":  true,
-		"fd00::/8":       true,
-		"fd00:fd00::/16": false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":     true,
-		"10.0.0.0/8":    false,
-		"10.42.42.0/24": true,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	if assert.NoError(t, err) {
-		assert.NotNil(t, r)
-	}
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":      true,
-		"10.0.0.0/8":     false,
-		"10.42.42.0/24":  true,
-		"::/0":           false,
-		"fd00::/8":       true,
-		"fd00:fd00::/16": false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	if assert.NoError(t, err) {
-		assert.NotNil(t, r)
-	}
-
-	// Test interface names
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: "foo",
-		},
-	}
-	lr, err := c.GetLocalAllowList("allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: false,
-			`eth.*`:    true,
-		},
-	}
-	lr, err = c.GetLocalAllowList("allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: false,
-		},
-	}
-	lr, err = c.GetLocalAllowList("allowlist")
-	if assert.NoError(t, err) {
-		assert.NotNil(t, lr)
-	}
-}
-
 func TestConfig_HasChanged(t *testing.T) {
 func TestConfig_HasChanged(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// No reload has occurred, return false
 	// No reload has occurred, return false
-	c := NewConfig(l)
+	c := NewC(l)
 	c.Settings["test"] = "hi"
 	c.Settings["test"] = "hi"
 	assert.False(t, c.HasChanged(""))
 	assert.False(t, c.HasChanged(""))
 
 
 	// Test key change
 	// Test key change
-	c = NewConfig(l)
+	c = NewC(l)
 	c.Settings["test"] = "hi"
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "no"}
 	c.oldSettings = map[interface{}]interface{}{"test": "no"}
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged(""))
 	assert.True(t, c.HasChanged(""))
 
 
 	// No key change
 	// No key change
-	c = NewConfig(l)
+	c = NewC(l)
 	c.Settings["test"] = "hi"
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
 	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
 	assert.False(t, c.HasChanged("test"))
 	assert.False(t, c.HasChanged("test"))
@@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) {
 }
 }
 
 
 func TestConfig_ReloadConfig(t *testing.T) {
 func TestConfig_ReloadConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	done := make(chan bool, 1)
 	done := make(chan bool, 1)
 	dir, err := ioutil.TempDir("", "config-test")
 	dir, err := ioutil.TempDir("", "config-test")
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
 
-	c := NewConfig(l)
+	c := NewC(l)
 	assert.Nil(t, c.Load(dir))
 	assert.Nil(t, c.Load(dir))
 
 
 	assert.False(t, c.HasChanged("outer.inner"))
 	assert.False(t, c.HasChanged("outer.inner"))
@@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
 
 
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
 
 
-	c.RegisterReloadCallback(func(c *Config) {
+	c.RegisterReloadCallback(func(c *C) {
 		done <- true
 		done <- true
 	})
 	})
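
One behaviour of the new package worth calling out alongside these tests: parse() merges files with mergo.WithAppendSlice, so list values such as firewall rules spread across separate files are appended rather than replaced. A sketch in the same style, assumed to live in this test file (it reuses its imports); the file contents are illustrative:

func TestConfig_MultiFileMerge(t *testing.T) {
	l := util.NewTestLogger()
	dir, err := ioutil.TempDir("", "config-merge")
	assert.Nil(t, err)
	defer os.RemoveAll(dir)

	// Files load in lexical order and merge with WithAppendSlice.
	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("firewall:\n  outbound:\n    - port: 80\n"), 0644)
	ioutil.WriteFile(filepath.Join(dir, "02.yaml"), []byte("firewall:\n  outbound:\n    - port: 443\n"), 0644)

	c := NewC(l)
	assert.Nil(t, c.Load(dir))

	// Both rules survive the merge because slices are appended, not replaced.
	rules, ok := c.Get("firewall.outbound").([]interface{})
	assert.True(t, ok)
	assert.Len(t, rules, 2)
}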
 
 

+ 51 - 49
connection_manager.go

@@ -6,6 +6,8 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
 // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
@@ -13,16 +15,16 @@ import (
 
 
 type connectionManager struct {
 type connectionManager struct {
 	hostMap      *HostMap
 	hostMap      *HostMap
-	in           map[uint32]struct{}
+	in           map[iputil.VpnIp]struct{}
 	inLock       *sync.RWMutex
 	inLock       *sync.RWMutex
 	inCount      int
 	inCount      int
-	out          map[uint32]struct{}
+	out          map[iputil.VpnIp]struct{}
 	outLock      *sync.RWMutex
 	outLock      *sync.RWMutex
 	outCount     int
 	outCount     int
 	TrafficTimer *SystemTimerWheel
 	TrafficTimer *SystemTimerWheel
 	intf         *Interface
 	intf         *Interface
 
 
-	pendingDeletion      map[uint32]int
+	pendingDeletion      map[iputil.VpnIp]int
 	pendingDeletionLock  *sync.RWMutex
 	pendingDeletionLock  *sync.RWMutex
 	pendingDeletionTimer *SystemTimerWheel
 	pendingDeletionTimer *SystemTimerWheel
 
 
@@ -36,15 +38,15 @@ type connectionManager struct {
 func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
 func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
 	nc := &connectionManager{
 	nc := &connectionManager{
 		hostMap:                 intf.hostMap,
 		hostMap:                 intf.hostMap,
-		in:                      make(map[uint32]struct{}),
+		in:                      make(map[iputil.VpnIp]struct{}),
 		inLock:                  &sync.RWMutex{},
 		inLock:                  &sync.RWMutex{},
 		inCount:                 0,
 		inCount:                 0,
-		out:                     make(map[uint32]struct{}),
+		out:                     make(map[iputil.VpnIp]struct{}),
 		outLock:                 &sync.RWMutex{},
 		outLock:                 &sync.RWMutex{},
 		outCount:                0,
 		outCount:                0,
 		TrafficTimer:            NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		TrafficTimer:            NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		intf:                    intf,
 		intf:                    intf,
-		pendingDeletion:         make(map[uint32]int),
+		pendingDeletion:         make(map[iputil.VpnIp]int),
 		pendingDeletionLock:     &sync.RWMutex{},
 		pendingDeletionLock:     &sync.RWMutex{},
 		pendingDeletionTimer:    NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		pendingDeletionTimer:    NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		checkInterval:           checkInterval,
 		checkInterval:           checkInterval,
@@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 	return nc
 	return nc
 }
 }
 
 
-func (n *connectionManager) In(ip uint32) {
+func (n *connectionManager) In(ip iputil.VpnIp) {
 	n.inLock.RLock()
 	n.inLock.RLock()
 	// If this already exists, return
 	// If this already exists, return
 	if _, ok := n.in[ip]; ok {
 	if _, ok := n.in[ip]; ok {
@@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) {
 	n.inLock.Unlock()
 	n.inLock.Unlock()
 }
 }
 
 
-func (n *connectionManager) Out(ip uint32) {
+func (n *connectionManager) Out(ip iputil.VpnIp) {
 	n.outLock.RLock()
 	n.outLock.RLock()
 	// If this already exists, return
 	// If this already exists, return
 	if _, ok := n.out[ip]; ok {
 	if _, ok := n.out[ip]; ok {
@@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) {
 	n.outLock.Unlock()
 	n.outLock.Unlock()
 }
 }
 
 
-func (n *connectionManager) CheckIn(vpnIP uint32) bool {
+func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool {
 	n.inLock.RLock()
 	n.inLock.RLock()
-	if _, ok := n.in[vpnIP]; ok {
+	if _, ok := n.in[vpnIp]; ok {
 		n.inLock.RUnlock()
 		n.inLock.RUnlock()
 		return true
 		return true
 	}
 	}
@@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool {
 	return false
 	return false
 }
 }
 
 
-func (n *connectionManager) ClearIP(ip uint32) {
+func (n *connectionManager) ClearIP(ip iputil.VpnIp) {
 	n.inLock.Lock()
 	n.inLock.Lock()
 	n.outLock.Lock()
 	n.outLock.Lock()
 	delete(n.in, ip)
 	delete(n.in, ip)
@@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) {
 	n.outLock.Unlock()
 	n.outLock.Unlock()
 }
 }
 
 
-func (n *connectionManager) ClearPendingDeletion(ip uint32) {
+func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) {
 	n.pendingDeletionLock.Lock()
 	n.pendingDeletionLock.Lock()
 	delete(n.pendingDeletion, ip)
 	delete(n.pendingDeletion, ip)
 	n.pendingDeletionLock.Unlock()
 	n.pendingDeletionLock.Unlock()
 }
 }
 
 
-func (n *connectionManager) AddPendingDeletion(ip uint32) {
+func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) {
 	n.pendingDeletionLock.Lock()
 	n.pendingDeletionLock.Lock()
 	if _, ok := n.pendingDeletion[ip]; ok {
 	if _, ok := n.pendingDeletion[ip]; ok {
 		n.pendingDeletion[ip] += 1
 		n.pendingDeletion[ip] += 1
@@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) {
 	n.pendingDeletionLock.Unlock()
 	n.pendingDeletionLock.Unlock()
 }
 }
 
 
-func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
+func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool {
 	n.pendingDeletionLock.RLock()
 	n.pendingDeletionLock.RLock()
 	if _, ok := n.pendingDeletion[ip]; ok {
 	if _, ok := n.pendingDeletion[ip]; ok {
 
 
@@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
 	return false
 	return false
 }
 }
 
 
-func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
-	n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
+func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) {
+	n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds))
 }
 }
 
 
 func (n *connectionManager) Start(ctx context.Context) {
 func (n *connectionManager) Start(ctx context.Context) {
@@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			break
 			break
 		}
 		}
 
 
-		vpnIP := ep.(uint32)
+		vpnIp := ep.(iputil.VpnIp)
 
 
 		// Check for traffic coming back in from this host.
 		// Check for traffic coming back in from this host.
-		traf := n.CheckIn(vpnIP)
+		traf := n.CheckIn(vpnIp)
 
 
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 		if err != nil {
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", vpnIp)
 
 
 			if !n.intf.disconnectInvalid {
 			if !n.intf.disconnectInvalid {
-				n.ClearIP(vpnIP)
-				n.ClearPendingDeletion(vpnIP)
+				n.ClearIP(vpnIp)
+				n.ClearPendingDeletion(vpnIp)
 				continue
 				continue
 			}
 			}
 		}
 		}
 
 
-		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+		if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
 			continue
 			continue
 		}
 		}
 
 
@@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 		// expired, just ignore.
 		// expired, just ignore.
 		if traf {
 		if traf {
 			if n.l.Level >= logrus.DebugLevel {
 			if n.l.Level >= logrus.DebugLevel {
-				n.l.WithField("vpnIp", IntIp(vpnIP)).
+				n.l.WithField("vpnIp", vpnIp).
 					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 					Debug("Tunnel status")
 					Debug("Tunnel status")
 			}
 			}
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			continue
 			continue
 		}
 		}
 
 
@@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 
 
 		if hostinfo != nil && hostinfo.ConnectionState != nil {
 		if hostinfo != nil && hostinfo.ConnectionState != nil {
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
+			n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out)
 
 
 		} else {
 		} else {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
+			hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp)
 		}
 		}
-		n.AddPendingDeletion(vpnIP)
+		n.AddPendingDeletion(vpnIp)
 	}
 	}
 
 
 }
 }
@@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 			break
 			break
 		}
 		}
 
 
-		vpnIP := ep.(uint32)
+		vpnIp := ep.(iputil.VpnIp)
 
 
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 		if err != nil {
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", vpnIp)
 
 
 			if !n.intf.disconnectInvalid {
 			if !n.intf.disconnectInvalid {
-				n.ClearIP(vpnIP)
-				n.ClearPendingDeletion(vpnIP)
+				n.ClearIP(vpnIp)
+				n.ClearPendingDeletion(vpnIp)
 				continue
 				continue
 			}
 			}
 		}
 		}
 
 
-		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+		if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
 			continue
 			continue
 		}
 		}
 
 
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// expired, just ignore.
 		// expired, just ignore.
-		traf := n.CheckIn(vpnIP)
+		traf := n.CheckIn(vpnIp)
 		if traf {
 		if traf {
-			n.l.WithField("vpnIp", IntIp(vpnIP)).
+			n.l.WithField("vpnIp", vpnIp).
 				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
 				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
 				Debug("Tunnel status")
 				Debug("Tunnel status")
 
 
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			continue
 			continue
 		}
 		}
 
 
 		// If it comes around on deletion wheel and hasn't resolved itself, delete
 		// If it comes around on deletion wheel and hasn't resolved itself, delete
-		if n.checkPendingDeletion(vpnIP) {
+		if n.checkPendingDeletion(vpnIp) {
 			cn := ""
 			cn := ""
 			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
 			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
 				cn = hostinfo.ConnectionState.peerCert.Details.Name
 				cn = hostinfo.ConnectionState.peerCert.Details.Name
@@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 				WithField("certName", cn).
 				WithField("certName", cn).
 				Info("Tunnel status")
 				Info("Tunnel status")
 
 
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			// TODO: This is only here to let tests work. Should do proper mocking
 			// TODO: This is only here to let tests work. Should do proper mocking
 			if n.intf.lightHouse != nil {
 			if n.intf.lightHouse != nil {
-				n.intf.lightHouse.DeleteVpnIP(vpnIP)
+				n.intf.lightHouse.DeleteVpnIp(vpnIp)
 			}
 			}
 			n.hostMap.DeleteHostInfo(hostinfo)
 			n.hostMap.DeleteHostInfo(hostinfo)
 		} else {
 		} else {
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 		}
 		}
 	}
 	}
 }
 }
 
 
 // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
 // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
-func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
+func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool {
 	if !n.intf.disconnectInvalid {
 	if !n.intf.disconnectInvalid {
 		return false
 		return false
 	}
 	}
@@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
 	}
 	}
 
 
 	fingerprint, _ := remoteCert.Sha256Sum()
 	fingerprint, _ := remoteCert.Sha256Sum()
-	n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
+	n.l.WithField("vpnIp", vpnIp).WithError(err).
 		WithField("certName", remoteCert.Details.Name).
 		WithField("certName", remoteCert.Details.Name).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
@@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
 	n.intf.sendCloseTunnel(hostinfo)
 	n.intf.sendCloseTunnel(hostinfo)
 	n.intf.closeTunnel(hostinfo, false)
 	n.intf.closeTunnel(hostinfo, false)
 
 
-	n.ClearIP(vpnIP)
-	n.ClearPendingDeletion(vpnIP)
+	n.ClearIP(vpnIp)
+	n.ClearPendingDeletion(vpnIp)
 	return true
 	return true
 }
 }
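
The bare uint32 keys are replaced by iputil.VpnIp throughout the connection manager. A minimal sketch of the conversions involved, based on the calls visible in this change (Ip2VpnIp, ToIP, and logging the value directly instead of wrapping it in IntIp):

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/iputil"
)

func main() {
	ip := net.ParseIP("172.1.1.2")

	// net.IP -> VpnIp, the key type now used by the in/out/pendingDeletion maps.
	vpnIp := iputil.Ip2VpnIp(ip)

	// VpnIp -> net.IP for display and comparisons.
	fmt.Println(vpnIp.ToIP())

	// The value formats as the dotted-quad address, which is why the log
	// fields above can drop the old IntIp(vpnIP) wrapper.
	fmt.Printf("vpnIp=%s\n", vpnIp)
}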

+ 39 - 36
connection_manager_test.go

@@ -10,17 +10,20 @@ import (
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-var vpnIP uint32
+var vpnIp iputil.VpnIp
 
 
 func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
+	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
@@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 		rawCertificateNoKey: []byte{},
 	}
 	}
 
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 	ifce := &Interface{
 		hostMap:          hostMap,
 		hostMap:          hostMap,
 		inside:           &Tun{},
 		inside:           &Tun{},
-		outside:          &udpConn{},
+		outside:          &udp.Conn{},
 		certState:        cs,
 		certState:        cs,
 		firewall:         &Firewall{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 		l:                l,
 	}
 	}
 	now := time.Now()
 	now := time.Now()
@@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	out := make([]byte, mtu)
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		certState: cs,
 		H:         &noise.HandshakeState{},
 		H:         &noise.HandshakeState{},
 	}
 	}
 
 
-	// We saw traffic out to vpnIP
-	nc.Out(vpnIP)
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	// We saw traffic out to vpnIp
+	nc.Out(vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead 5s. Nothing should happen
 	// Move ahead 5s. Nothing should happen
 	next_tick := now.Add(5 * time.Second)
 	next_tick := now.Add(5 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
@@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	nc.HandleDeletionTick(next_tick)
 	// This host should now be up for deletion
 	// This host should now be up for deletion
-	assert.Contains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.Contains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead some more
 	// Move ahead some more
 	next_tick = now.Add(45 * time.Second)
 	next_tick = now.Add(45 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	nc.HandleDeletionTick(next_tick)
 	// The host should be evicted
 	// The host should be evicted
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.NotContains(t, nc.hostMap.Hosts, vpnIP)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, vpnIp)
 
 
 }
 }
 
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 		rawCertificateNoKey: []byte{},
 	}
 	}
 
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 	ifce := &Interface{
 		hostMap:          hostMap,
 		hostMap:          hostMap,
 		inside:           &Tun{},
 		inside:           &Tun{},
-		outside:          &udpConn{},
+		outside:          &udp.Conn{},
 		certState:        cs,
 		certState:        cs,
 		firewall:         &Firewall{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 		l:                l,
 	}
 	}
 	now := time.Now()
 	now := time.Now()
@@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	out := make([]byte, mtu)
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		certState: cs,
 		H:         &noise.HandshakeState{},
 		H:         &noise.HandshakeState{},
 	}
 	}
 
 
-	// We saw traffic out to vpnIP
-	nc.Out(vpnIP)
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	// We saw traffic out to vpnIp
+	nc.Out(vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead 5s. Nothing should happen
 	// Move ahead 5s. Nothing should happen
 	next_tick := now.Add(5 * time.Second)
 	next_tick := now.Add(5 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
@@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	nc.HandleDeletionTick(next_tick)
 	// This host should now be up for deletion
 	// This host should now be up for deletion
-	assert.Contains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.Contains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// We heard back this time
 	// We heard back this time
-	nc.In(vpnIP)
+	nc.In(vpnIp)
 	// Move ahead some more
 	// Move ahead some more
 	next_tick = now.Add(45 * time.Second)
 	next_tick = now.Add(45 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	nc.HandleDeletionTick(next_tick)
 	// The host should be evicted
 	// The host should be evicted
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 
 
 }
 }
 
 
@@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 // Disconnect only if disconnectInvalid: true is set.
 // Disconnect only if disconnectInvalid: true is set.
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
 	now := time.Now()
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ipNet := net.IPNet{
 	ipNet := net.IPNet{
 		IP:   net.IPv4(172, 1, 1, 2),
 		IP:   net.IPv4(172, 1, 1, 2),
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
@@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 		rawCertificateNoKey: []byte{},
 	}
 	}
 
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 	ifce := &Interface{
 		hostMap:           hostMap,
 		hostMap:           hostMap,
 		inside:            &Tun{},
 		inside:            &Tun{},
-		outside:           &udpConn{},
+		outside:           &udp.Conn{},
 		certState:         cs,
 		certState:         cs,
 		firewall:          &Firewall{},
 		firewall:          &Firewall{},
 		lightHouse:        lh,
 		lightHouse:        lh,
-		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                 l,
 		l:                 l,
 		disconnectInvalid: true,
 		disconnectInvalid: true,
 		caPool:            ncp,
 		caPool:            ncp,
@@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	defer cancel()
 	defer cancel()
 	nc := newConnectionManager(ctx, l, ifce, 5, 10)
 	nc := newConnectionManager(ctx, l, ifce, 5, 10)
 	ifce.connectionManager = nc
 	ifce.connectionManager = nc
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		certState: cs,
 		peerCert:  &peerCert,
 		peerCert:  &peerCert,
@@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Check if to disconnect with invalid certificate.
 	// Check if to disconnect with invalid certificate.
 	// Should be alive.
 	// Should be alive.
 	nextTick := now.Add(45 * time.Second)
 	nextTick := now.Add(45 * time.Second)
-	destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
 	assert.False(t, destroyed)
 	assert.False(t, destroyed)
 
 
 	// Move ahead 61s.
 	// Move ahead 61s.
 	// Check if to disconnect with invalid certificate.
 	// Check if to disconnect with invalid certificate.
 	// Should be disconnected.
 	// Should be disconnected.
 	nextTick = now.Add(61 * time.Second)
 	nextTick = now.Add(61 * time.Second)
-	destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
 	assert.True(t, destroyed)
 	assert.True(t, destroyed)
 }
 }
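
Condensed, the lifecycle these two tests walk through looks like the sketch below; it assumes it sits in the nebula package next to the tests, with the interface, hostmap and timer setup elided:

func sketchConnectionManagerLifecycle(nc *connectionManager, vpnIp iputil.VpnIp, p, nb, out []byte) {
	now := time.Now()

	// Outbound traffic registers the peer on the traffic timer wheel.
	nc.Out(vpnIp)

	// Once the timer fires, the monitor tick sends a header.Test probe and
	// marks the peer as pending deletion.
	nc.HandleMonitorTick(now.Add(10 * time.Second), p, nb, out)
	nc.HandleDeletionTick(now.Add(10 * time.Second))

	// Inbound traffic before the next deletion tick rescues the peer...
	nc.In(vpnIp)

	// ...otherwise that tick would evict it from the hostmap.
	nc.HandleDeletionTick(now.Add(45 * time.Second))
}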

+ 18 - 15
control.go

@@ -10,6 +10,9 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -25,14 +28,14 @@ type Control struct {
 }
 }
 
 
 type ControlHostInfo struct {
 type ControlHostInfo struct {
-	VpnIP          net.IP                  `json:"vpnIp"`
+	VpnIp          net.IP                  `json:"vpnIp"`
 	LocalIndex     uint32                  `json:"localIndex"`
 	LocalIndex     uint32                  `json:"localIndex"`
 	RemoteIndex    uint32                  `json:"remoteIndex"`
 	RemoteIndex    uint32                  `json:"remoteIndex"`
-	RemoteAddrs    []*udpAddr              `json:"remoteAddrs"`
+	RemoteAddrs    []*udp.Addr             `json:"remoteAddrs"`
 	CachedPackets  int                     `json:"cachedPackets"`
 	CachedPackets  int                     `json:"cachedPackets"`
 	Cert           *cert.NebulaCertificate `json:"cert"`
 	Cert           *cert.NebulaCertificate `json:"cert"`
 	MessageCounter uint64                  `json:"messageCounter"`
 	MessageCounter uint64                  `json:"messageCounter"`
-	CurrentRemote  *udpAddr                `json:"currentRemote"`
+	CurrentRemote  *udp.Addr               `json:"currentRemote"`
 }
 }
 
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
 	}
 	}
 }
 }
 
 
-// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
+// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
 	var hm *HostMap
 	var hm *HostMap
 	if pending {
 	if pending {
 		hm = c.f.handshakeManager.pendingHostMap
 		hm = c.f.handshakeManager.pendingHostMap
@@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 		hm = c.f.hostMap
 		hm = c.f.hostMap
 	}
 	}
 
 
-	h, err := hm.QueryVpnIP(vpnIP)
+	h, err := hm.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return nil
 		return nil
 	}
 	}
@@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 }
 }
 
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
-	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return nil
 		return nil
 	}
 	}
@@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
 }
 }
 
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
-	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return false
 		return false
 	}
 	}
 
 
 	if !localOnly {
 	if !localOnly {
 		c.f.send(
 		c.f.send(
-			closeTunnel,
+			header.CloseTunnel,
 			0,
 			0,
 			hostInfo.ConnectionState,
 			hostInfo.ConnectionState,
 			hostInfo,
 			hostInfo,
@@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, h := range c.f.hostMap.Hosts {
 	for _, h := range c.f.hostMap.Hosts {
 		if excludeLighthouses {
 		if excludeLighthouses {
-			if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok {
+			if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
 				continue
 				continue
 			}
 			}
 		}
 		}
 
 
 		if h.ConnectionState.ready {
 		if h.ConnectionState.ready {
-			c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+			c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 			c.f.closeTunnel(h, true)
 			c.f.closeTunnel(h, true)
 
 
-			c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
+			c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
 				Debug("Sending close tunnel message")
 				Debug("Sending close tunnel message")
 			closed++
 			closed++
 		}
 		}
@@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 
 
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	chi := ControlHostInfo{
 	chi := ControlHostInfo{
-		VpnIP:         int2ip(h.hostId),
+		VpnIp:         h.vpnIp.ToIP(),
 		LocalIndex:    h.localIndexId,
 		LocalIndex:    h.localIndexId,
 		RemoteIndex:   h.remoteIndexId,
 		RemoteIndex:   h.remoteIndexId,
 		RemoteAddrs:   h.remotes.CopyAddrs(preferredRanges),
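
From an embedding application, the renamed Control API is driven like this; a sketch, assuming the Control value comes from starting nebula elsewhere (not shown in this change):

package example

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula"
	"github.com/slackhq/nebula/iputil"
	"github.com/slackhq/nebula/udp"
)

func inspectTunnel(c *nebula.Control, peer net.IP) {
	vpnIp := iputil.Ip2VpnIp(peer)

	// nil means no established tunnel exists for this vpn ip.
	if hi := c.GetHostInfoByVpnIp(vpnIp, false); hi != nil {
		fmt.Printf("tunnel to %s via %s, %d cached packets\n", hi.VpnIp, hi.CurrentRemote, hi.CachedPackets)
	}

	// Force the tunnel onto a specific underlay address.
	c.SetRemoteForTunnel(vpnIp, *udp.NewAddr(net.ParseIP("192.0.2.1"), 4242))

	// Tear the tunnel down and notify the other side.
	c.CloseTunnel(vpnIp, false)
}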
 		RemoteAddrs:   h.remotes.CopyAddrs(preferredRanges),

+ 16 - 14
control_test.go

@@ -8,17 +8,19 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-func TestControl_GetHostInfoByVpnIP(t *testing.T) {
-	l := NewTestLogger()
+func TestControl_GetHostInfoByVpnIp(t *testing.T) {
+	l := util.NewTestLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	// To properly ensure we are not exposing core memory to the caller
 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
-	remote1 := NewUDPAddr(int2ip(100), 4444)
-	remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
+	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
+	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
 	ipNet := net.IPNet{
 		IP:   net.IPv4(1, 2, 3, 4),
 		IP:   net.IPv4(1, 2, 3, 4),
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
@@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 	remotes := NewRemoteList()
 	remotes := NewRemoteList()
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
-	hm.Add(ip2int(ipNet.IP), &HostInfo{
+	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
 		remote:  remote1,
 		remote:  remote1,
 		remotes: remotes,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
@@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		hostId:        ip2int(ipNet.IP),
+		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
 	})
 	})
 
 
-	hm.Add(ip2int(ipNet2.IP), &HostInfo{
+	hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
 		remote:  remote1,
 		remote:  remote1,
 		remotes: remotes,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
@@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		hostId:        ip2int(ipNet2.IP),
+		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
 	})
 	})
 
 
 	c := Control{
 	c := Control{
@@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		l: logrus.New(),
 		l: logrus.New(),
 	}
 	}
 
 
-	thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
+	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
 
 
 	expectedInfo := ControlHostInfo{
 	expectedInfo := ControlHostInfo{
-		VpnIP:          net.IPv4(1, 2, 3, 4).To4(),
+		VpnIp:          net.IPv4(1, 2, 3, 4).To4(),
 		LocalIndex:     201,
 		LocalIndex:     201,
 		RemoteIndex:    200,
 		RemoteIndex:    200,
-		RemoteAddrs:    []*udpAddr{remote2, remote1},
+		RemoteAddrs:    []*udp.Addr{remote2, remote1},
 		CachedPackets:  0,
 		CachedPackets:  0,
 		Cert:           crt.Copy(),
 		Cert:           crt.Copy(),
 		MessageCounter: 0,
 		MessageCounter: 0,
-		CurrentRemote:  NewUDPAddr(int2ip(100), 4444),
+		CurrentRemote:  udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
 	}
 	}
 
 
 	// Make sure we don't have any unexpected fields
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
+	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
 	util.AssertDeepCopyEqual(t, &expectedInfo, thi)
 	util.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
+		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
 	})
 	})
 }
 }
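
One detail this test makes visible: the json tags on ControlHostInfo are unchanged, so the serialized form still uses the existing key names even though the Go field is now VpnIp. A small sketch (field values are illustrative):

package example

import (
	"encoding/json"
	"fmt"
	"net"

	"github.com/slackhq/nebula"
)

func printHostInfo() {
	chi := nebula.ControlHostInfo{
		VpnIp:       net.IPv4(1, 2, 3, 4).To4(),
		LocalIndex:  201,
		RemoteIndex: 200,
	}

	// The struct tag keeps the external key as "vpnIp".
	b, _ := json.Marshal(chi)
	fmt.Println(string(b))
}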
 
 

+ 15 - 12
control_tester.go

@@ -8,12 +8,15 @@ import (
 
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
 // WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
 // returning after a message matching the criteria has been piped
 // returning after a message matching the criteria has been piped
-func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
-	h := &Header{}
+func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
+	h := &header.H{}
 	for {
 	for {
 		p := c.f.outside.Get(true)
 		p := c.f.outside.Get(true)
 		if err := h.Parse(p.Data); err != nil {
 		if err := h.Parse(p.Data); err != nil {
@@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu
 
 
 // WaitForTypeByIndex is similar to WaitForType except it adds an index check
 // WaitForTypeByIndex is similar to WaitForType except it adds an index check
 // Useful if you have many nodes communicating and want to wait to find a specific nodes packet
 // Useful if you have many nodes communicating and want to wait to find a specific nodes packet
-func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
-	h := &Header{}
+func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
+	h := &header.H{}
 	for {
 	for {
 		p := c.f.outside.Get(true)
 		p := c.f.outside.Get(true)
 		if err := h.Parse(p.Data); err != nil {
 		if err := h.Parse(p.Data); err != nil {
@@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
 func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
-	iVpnIp := ip2int(vpnIp)
+	iVpnIp := iputil.Ip2VpnIp(vpnIp)
 	if v4 := toAddr.IP.To4(); v4 != nil {
 	if v4 := toAddr.IP.To4(); v4 != nil {
 		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
 		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
 	} else {
 	} else {
@@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte {
 }
 }
 
 
 // GetFromUDP will pull a udp packet off the udp side of nebula
 // GetFromUDP will pull a udp packet off the udp side of nebula
-func (c *Control) GetFromUDP(block bool) *UdpPacket {
+func (c *Control) GetFromUDP(block bool) *udp.Packet {
 	return c.f.outside.Get(block)
 	return c.f.outside.Get(block)
 }
 }
 
 
-func (c *Control) GetUDPTxChan() <-chan *UdpPacket {
-	return c.f.outside.txPackets
+func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
+	return c.f.outside.TxPackets
 }
 }
 
 
 func (c *Control) GetTunTxChan() <-chan []byte {
 func (c *Control) GetTunTxChan() <-chan []byte {
@@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
 }
 }
 
 
 // InjectUDPPacket will inject a packet into the udp side of nebula
 // InjectUDPPacket will inject a packet into the udp side of nebula
-func (c *Control) InjectUDPPacket(p *UdpPacket) {
+func (c *Control) InjectUDPPacket(p *udp.Packet) {
 	c.f.outside.Send(p)
 	c.f.outside.Send(p)
 }
 }
 
 
@@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 }
 }
 
 
 func (c *Control) GetUDPAddr() string {
 func (c *Control) GetUDPAddr() string {
-	return c.f.outside.addr.String()
+	return c.f.outside.Addr.String()
 }
 }
 
 
 func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
 func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
+	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
 	if !ok {
 	if !ok {
 		return false
 		return false
 	}
 	}

+ 9 - 7
dns_server.go

@@ -8,6 +8,8 @@ import (
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 // This whole thing should be rewritten to use context
 // This whole thing should be rewritten to use context
@@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string {
 	if ip == nil {
 	if ip == nil {
 		return ""
 		return ""
 	}
 	}
-	iip := ip2int(ip)
-	hostinfo, err := d.hostMap.QueryVpnIP(iip)
+	iip := iputil.Ip2VpnIp(ip)
+	hostinfo, err := d.hostMap.QueryVpnIp(iip)
 	if err != nil {
 	if err != nil {
 		return ""
 		return ""
 	}
 	}
@@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
 	w.WriteMsg(m)
 	w.WriteMsg(m)
 }
 }
 
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
+func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
 	dnsR = newDnsRecords(hostMap)
 	dnsR = newDnsRecords(hostMap)
 
 
 	// attach request handler func
 	// attach request handler func
@@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
 		handleDnsRequest(l, w, r)
 		handleDnsRequest(l, w, r)
 	})
 	})
 
 
-	c.RegisterReloadCallback(func(c *Config) {
+	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)
 		reloadDns(l, c)
 	})
 	})
 
 
@@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
 	}
 	}
 }
 }
 
 
-func getDnsServerAddr(c *Config) string {
+func getDnsServerAddr(c *config.C) string {
 	return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
 	return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
 }
 }
 
 
-func startDns(l *logrus.Logger, c *Config) {
+func startDns(l *logrus.Logger, c *config.C) {
 	dnsAddr = getDnsServerAddr(c)
 	dnsAddr = getDnsServerAddr(c)
 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 	l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
 	l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
@@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) {
 	}
 	}
 }
 }
 
 
-func reloadDns(l *logrus.Logger, c *Config) {
+func reloadDns(l *logrus.Logger, c *config.C) {
 	if dnsAddr == getDnsServerAddr(c) {
 	if dnsAddr == getDnsServerAddr(c) {
 		l.Debug("No DNS server config change detected")
 		l.Debug("No DNS server config change detected")
 		return
 		return
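The dns_server.go changes swap *Config for the new *config.C throughout. A minimal sketch of the reload pattern it relies on, using only calls visible in this diff (NewC, RegisterReloadCallback, GetString, GetInt):

package main

import (
	"fmt"
	"strconv"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
)

func main() {
	l := logrus.New()
	c := config.NewC(l)

	// Reload callbacks receive the refreshed *config.C, mirroring how dnsMain
	// re-reads lighthouse.dns.host / lighthouse.dns.port above.
	c.RegisterReloadCallback(func(c *config.C) {
		addr := c.GetString("lighthouse.dns.host", "") + ":" +
			strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
		fmt.Println("dns listener would now be", addr)
	})
}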

+ 8 - 5
e2e/handshakes_test.go

@@ -10,6 +10,9 @@ import (
 
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) {
 	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
 	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
 	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
 	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
 	badPacket := stage1Packet.Copy()
 	badPacket := stage1Packet.Copy()
-	badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
+	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
 	myControl.InjectUDPPacket(badPacket)
 	myControl.InjectUDPPacket(badPacket)
 
 
 	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
 	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
@@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
-		h := &nebula.Header{}
+	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
+		h := &header.H{}
 		err := h.Parse(p.Data)
 		err := h.Parse(p.Data)
 		if err != nil {
 		if err != nil {
 			panic(err)
 			panic(err)
@@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 	r.FlushAll()
 	r.FlushAll()
 
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 
 	//TODO: assert hostmaps for everyone
 	//TODO: assert hostmaps for everyone

+ 11 - 23
e2e/helpers_test.go

@@ -5,7 +5,6 @@ package e2e
 
 
 import (
 import (
 	"crypto/rand"
 	"crypto/rand"
-	"encoding/binary"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
@@ -19,7 +18,9 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ed25519"
@@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	config := nebula.NewConfig(l)
-	config.LoadString(string(cb))
+	c := config.NewC(l)
+	c.LoadString(string(cb))
 
 
-	control, err := nebula.Main(config, false, "e2e-test", l, nil)
+	control, err := nebula.Main(c, false, "e2e-test", l, nil)
 
 
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
@@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) {
 	return pubkey, privkey
 	return pubkey, privkey
 }
 }
 
 
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
-}
-
 type doneCb func()
 type doneCb func()
 
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {
 func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
 
 
 func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
 func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
 	// Get both host infos
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
+	hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
 
 
-	hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
+	hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
 
 
 	// Check that both vpn and real addr are correct
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B")
+	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
+	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
 
 
 	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
 	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
 	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
 	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")

+ 12 - 10
e2e/router/router.go

@@ -11,6 +11,8 @@ import (
 	"sync"
 	"sync"
 
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 type R struct {
 type R struct {
@@ -41,7 +43,7 @@ const (
 	RouteAndExit ExitType = 2
 	RouteAndExit ExitType = 2
 )
 )
 
 
-type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
+type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
 
 
 func NewR(controls ...*nebula.Control) *R {
 func NewR(controls ...*nebula.Control) *R {
 	r := &R{
 	r := &R{
@@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
 // OnceFrom will route a single packet from sender then return
 // OnceFrom will route a single packet from sender then return
 // If the router doesn't have the nebula controller for that address, we panic
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) OnceFrom(sender *nebula.Control) {
 func (r *R) OnceFrom(sender *nebula.Control) {
-	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
+	r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
 		return RouteAndExit
 		return RouteAndExit
 	})
 	})
 }
 }
@@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 //   - routeAndExit: this call will return immediately after routing the last packet from sender
 //   - routeAndExit: this call will return immediately after routing the last packet from sender
 //   - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
 //   - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
 func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
-	h := &nebula.Header{}
+	h := &header.H{}
 	for {
 	for {
 		p := sender.GetFromUDP(true)
 		p := sender.GetFromUDP(true)
 		r.Lock()
 		r.Lock()
@@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 
 
 // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
 // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
 // If the router doesn't have the nebula controller for that address, we panic
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
-	h := &nebula.Header{}
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
+func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
+	h := &header.H{}
+	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
 		if err := h.Parse(p.Data); err != nil {
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 			panic(err)
 		}
 		}
@@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr
 		finish = RouteAndExit
 		finish = RouteAndExit
 	}
 	}
 
 
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
+	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
 		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
 		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
 			return finish
 			return finish
 		}
 		}
@@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
 		x, rx, _ := reflect.Select(sc)
 		x, rx, _ := reflect.Select(sc)
 		r.Lock()
 		r.Lock()
 
 
-		p := rx.Interface().(*nebula.UdpPacket)
+		p := rx.Interface().(*udp.Packet)
 
 
 		outAddr := cm[x].GetUDPAddr()
 		outAddr := cm[x].GetUDPAddr()
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@@ -277,7 +279,7 @@ func (r *R) FlushAll() {
 		}
 		}
 		r.Lock()
 		r.Lock()
 
 
-		p := rx.Interface().(*nebula.UdpPacket)
+		p := rx.Interface().(*udp.Packet)
 
 
 		outAddr := cm[x].GetUDPAddr()
 		outAddr := cm[x].GetUDPAddr()
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@@ -292,7 +294,7 @@ func (r *R) FlushAll() {
 
 
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
 // This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
+func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
 	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
 	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
 		p.FromIp = newAddr.IP
 		p.FromIp = newAddr.IP
 		p.FromPort = uint16(newAddr.Port)
 		p.FromPort = uint16(newAddr.Port)

+ 45 - 147
firewall.go

@@ -4,7 +4,6 @@ import (
 	"crypto/sha256"
 	"crypto/sha256"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
-	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
@@ -12,22 +11,14 @@ import (
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
-	"sync/atomic"
 	"time"
 	"time"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
-)
-
-const (
-	fwProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	fwProtoTCP  = 6
-	fwProtoUDP  = 17
-	fwProtoICMP = 1
-
-	fwPortAny      = 0  // Special value for matching `port: any`
-	fwPortFragment = -1 // Special value for matching `port: fragment`
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
 )
 )
 
 
 const tcpACK = 0x10
 const tcpACK = 0x10
@@ -63,7 +54,7 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 	DefaultTimeout time.Duration //linux: 600s
 
 
 	// Used to ensure we don't emit local packets for ips we don't own
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *CIDRTree
+	localIps *cidr.Tree4
 
 
 	rules        string
 	rules        string
 	rulesVersion uint16
 	rulesVersion uint16
@@ -85,7 +76,7 @@ type firewallMetrics struct {
 type FirewallConntrack struct {
 type FirewallConntrack struct {
 	sync.Mutex
 	sync.Mutex
 
 
-	Conns      map[FirewallPacket]*conn
+	Conns      map[firewall.Packet]*conn
 	TimerWheel *TimerWheel
 	TimerWheel *TimerWheel
 }
 }
 
 
@@ -116,55 +107,13 @@ type FirewallRule struct {
 	Any    bool
 	Any    bool
 	Hosts  map[string]struct{}
 	Hosts  map[string]struct{}
 	Groups [][]string
 	Groups [][]string
-	CIDR   *CIDRTree
+	CIDR   *cidr.Tree4
 }
 }
 
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Even though ports are uint16, int32 maps are faster for lookup
 // Plus we can use `-1` for fragment rules
 // Plus we can use `-1` for fragment rules
 type firewallPort map[int32]*FirewallCA
 type firewallPort map[int32]*FirewallCA
 
 
-type FirewallPacket struct {
-	LocalIP    uint32
-	RemoteIP   uint32
-	LocalPort  uint16
-	RemotePort uint16
-	Protocol   uint8
-	Fragment   bool
-}
-
-func (fp *FirewallPacket) Copy() *FirewallPacket {
-	return &FirewallPacket{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
-		LocalPort:  fp.LocalPort,
-		RemotePort: fp.RemotePort,
-		Protocol:   fp.Protocol,
-		Fragment:   fp.Fragment,
-	}
-}
-
-func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
-	var proto string
-	switch fp.Protocol {
-	case fwProtoTCP:
-		proto = "tcp"
-	case fwProtoICMP:
-		proto = "icmp"
-	case fwProtoUDP:
-		proto = "udp"
-	default:
-		proto = fmt.Sprintf("unknown %v", fp.Protocol)
-	}
-	return json.Marshal(m{
-		"LocalIP":    int2ip(fp.LocalIP).String(),
-		"RemoteIP":   int2ip(fp.RemoteIP).String(),
-		"LocalPort":  fp.LocalPort,
-		"RemotePort": fp.RemotePort,
-		"Protocol":   proto,
-		"Fragment":   fp.Fragment,
-	})
-}
-
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
 	//TODO: error on 0 duration
 	//TODO: error on 0 duration
@@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 		max = defaultTimeout
 	}
 	}
 
 
-	localIps := NewCIDRTree()
+	localIps := cidr.NewTree4()
 	for _, ip := range c.Details.Ips {
 	for _, ip := range c.Details.Ips {
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
 	}
@@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 
 
 	return &Firewall{
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 		Conntrack: &FirewallConntrack{
-			Conns:      make(map[FirewallPacket]*conn),
+			Conns:      make(map[firewall.Packet]*conn),
 			TimerWheel: NewTimerWheel(min, max),
 			TimerWheel: NewTimerWheel(min, max),
 		},
 		},
 		InRules:        newFirewallTable(),
 		InRules:        newFirewallTable(),
@@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 	}
 	}
 }
 }
 
 
-func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
 	fw := NewFirewall(
 	fw := NewFirewall(
 		l,
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	}
 	}
 
 
 	switch proto {
 	switch proto {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		fp = ft.TCP
 		fp = ft.TCP
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		fp = ft.UDP
 		fp = ft.UDP
-	case fwProtoICMP:
+	case firewall.ProtoICMP:
 		fp = ft.ICMP
 		fp = ft.ICMP
-	case fwProtoAny:
+	case firewall.ProtoAny:
 		fp = ft.AnyProto
 		fp = ft.AnyProto
 	default:
 	default:
 		return fmt.Errorf("unknown protocol %v", proto)
 		return fmt.Errorf("unknown protocol %v", proto)
@@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
 	return hex.EncodeToString(sum[:])
 	return hex.EncodeToString(sum[:])
 }
 }
 
 
-func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
+func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
 	var table string
 	var table string
 	if inbound {
 	if inbound {
 		table = "firewall.inbound"
 		table = "firewall.inbound"
@@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
 		table = "firewall.outbound"
 		table = "firewall.outbound"
 	}
 	}
 
 
-	r := config.Get(table)
+	r := c.Get(table)
 	if r == nil {
 	if r == nil {
 		return nil
 		return nil
 	}
 	}
@@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
 		var proto uint8
 		var proto uint8
 		switch r.Proto {
 		switch r.Proto {
 		case "any":
 		case "any":
-			proto = fwProtoAny
+			proto = firewall.ProtoAny
 		case "tcp":
 		case "tcp":
-			proto = fwProtoTCP
+			proto = firewall.ProtoTCP
 		case "udp":
 		case "udp":
-			proto = fwProtoUDP
+			proto = firewall.ProtoUDP
 		case "icmp":
 		case "icmp":
-			proto = fwProtoICMP
+			proto = firewall.ProtoICMP
 		default:
 		default:
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
 		}
@@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
+func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(packet, fp, incoming, h, caPool, localCache) {
 	if f.inConns(packet, fp, incoming, h, caPool, localCache) {
 		return nil
 		return nil
@@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
 		}
 		}
 	} else {
 	} else {
 		// Simple case: Certificate has one IP and no subnets
 		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.hostId {
+		if fp.RemoteIP != h.vpnIp {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
 		}
 		}
@@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 }
 
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
+func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
 	if localCache != nil {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 		if _, ok := localCache[fp]; ok {
 			return true
 			return true
@@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 	}
 	}
 
 
 	switch fp.Protocol {
 	switch fp.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		c.Expires = time.Now().Add(f.TCPTimeout)
 		c.Expires = time.Now().Add(f.TCPTimeout)
 		if incoming {
 		if incoming {
 			f.checkTCPRTT(c, packet)
 			f.checkTCPRTT(c, packet)
 		} else {
 		} else {
 			setTCPRTTTracking(c, packet)
 			setTCPRTTTracking(c, packet)
 		}
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		c.Expires = time.Now().Add(f.UDPTimeout)
 		c.Expires = time.Now().Add(f.UDPTimeout)
 	default:
 	default:
 		c.Expires = time.Now().Add(f.DefaultTimeout)
 		c.Expires = time.Now().Add(f.DefaultTimeout)
@@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 	return true
 	return true
 }
 }
 
 
-func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
+func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
 	var timeout time.Duration
 	var timeout time.Duration
 	c := &conn{}
 	c := &conn{}
 
 
 	switch fp.Protocol {
 	switch fp.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		timeout = f.TCPTimeout
 		timeout = f.TCPTimeout
 		if !incoming {
 		if !incoming {
 			setTCPRTTTracking(c, packet)
 			setTCPRTTTracking(c, packet)
 		}
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		timeout = f.UDPTimeout
 		timeout = f.UDPTimeout
 	default:
 	default:
 		timeout = f.DefaultTimeout
 		timeout = f.DefaultTimeout
@@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
 // Caller must own the connMutex lock!
-func (f *Firewall) evict(p FirewallPacket) {
+func (f *Firewall) evict(p firewall.Packet) {
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
 	conntrack := f.Conntrack
@@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
 	delete(conntrack.Conns, p)
 	delete(conntrack.Conns, p)
 }
 }
 
 
-func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 		return true
 		return true
 	}
 	}
 
 
 	switch p.Protocol {
 	switch p.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		if ft.TCP.match(p, incoming, c, caPool) {
 		if ft.TCP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		if ft.UDP.match(p, incoming, c, caPool) {
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
-	case fwProtoICMP:
+	case firewall.ProtoICMP:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
@@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 	return nil
 	return nil
 }
 }
 
 
-func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	// We don't have any allowed ports, bail
 	// We don't have any allowed ports, bail
 	if fp == nil {
 	if fp == nil {
 		return false
 		return false
@@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 	var port int32
 	var port int32
 
 
 	if p.Fragment {
 	if p.Fragment {
-		port = fwPortFragment
+		port = firewall.PortFragment
 	} else if incoming {
 	} else if incoming {
 		port = int32(p.LocalPort)
 		port = int32(p.LocalPort)
 	} else {
 	} else {
@@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 		return true
 		return true
 	}
 	}
 
 
-	return fp[fwPortAny].match(p, c, caPool)
+	return fp[firewall.PortAny].match(p, c, caPool)
 }
 }
 
 
 func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
 func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
@@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 		return &FirewallRule{
 		return &FirewallRule{
 			Hosts:  make(map[string]struct{}),
 			Hosts:  make(map[string]struct{}),
 			Groups: make([][]string, 0),
 			Groups: make([][]string, 0),
-			CIDR:   NewCIDRTree(),
+			CIDR:   cidr.NewTree4(),
 		}
 		}
 	}
 	}
 
 
@@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 	return nil
 	return nil
 }
 }
 
 
-func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	if fc == nil {
 	if fc == nil {
 		return false
 		return false
 	}
 	}
@@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
 		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = NewCIDRTree()
+		fr.CIDR = cidr.NewTree4()
 	} else {
 	} else {
 		if len(groups) > 0 {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
 			fr.Groups = append(fr.Groups, groups)
@@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 	return false
 	return false
 }
 }
 
 
-func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
+func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
 	if fr == nil {
 	if fr == nil {
 		return false
 		return false
 	}
 	}
@@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 
 
 func parsePort(s string) (startPort, endPort int32, err error) {
 func parsePort(s string) (startPort, endPort int32, err error) {
 	if s == "any" {
 	if s == "any" {
-		startPort = fwPortAny
-		endPort = fwPortAny
+		startPort = firewall.PortAny
+		endPort = firewall.PortAny
 
 
 	} else if s == "fragment" {
 	} else if s == "fragment" {
-		startPort = fwPortFragment
-		endPort = fwPortFragment
+		startPort = firewall.PortFragment
+		endPort = firewall.PortFragment
 
 
 	} else if strings.Contains(s, `-`) {
 	} else if strings.Contains(s, `-`) {
 		sPorts := strings.SplitN(s, `-`, 2)
 		sPorts := strings.SplitN(s, `-`, 2)
@@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
 		startPort = int32(rStartPort)
 		startPort = int32(rStartPort)
 		endPort = int32(rEndPort)
 		endPort = int32(rEndPort)
 
 
-		if startPort == fwPortAny {
-			endPort = fwPortAny
+		if startPort == firewall.PortAny {
+			endPort = firewall.PortAny
 		}
 		}
 
 
 	} else {
 	} else {
@@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
 	c.Seq = 0
 	c.Seq = 0
 	return true
 	return true
 }
 }
-
-// ConntrackCache is used as a local routine cache to know if a given flow
-// has been seen in the conntrack table.
-type ConntrackCache map[FirewallPacket]struct{}
-
-type ConntrackCacheTicker struct {
-	cacheV    uint64
-	cacheTick uint64
-
-	cache ConntrackCache
-}
-
-func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
-	if d == 0 {
-		return nil
-	}
-
-	c := &ConntrackCacheTicker{
-		cache: ConntrackCache{},
-	}
-
-	go c.tick(d)
-
-	return c
-}
-
-func (c *ConntrackCacheTicker) tick(d time.Duration) {
-	for {
-		time.Sleep(d)
-		atomic.AddUint64(&c.cacheTick, 1)
-	}
-}
-
-// Get checks if the cache ticker has moved to the next version before returning
-// the map. If it has moved, we reset the map.
-func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
-	if c == nil {
-		return nil
-	}
-	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
-		c.cacheV = tick
-		if ll := len(c.cache); ll > 0 {
-			if l.Level == logrus.DebugLevel {
-				l.WithField("len", ll).Debug("resetting conntrack cache")
-			}
-			c.cache = make(ConntrackCache, ll)
-		}
-	}
-
-	return c.cache
-}

+ 59 - 0
firewall/cache.go

@@ -0,0 +1,59 @@
+package firewall
+
+import (
+	"sync/atomic"
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+// ConntrackCache is used as a local routine cache to know if a given flow
+// has been seen in the conntrack table.
+type ConntrackCache map[Packet]struct{}
+
+type ConntrackCacheTicker struct {
+	cacheV    uint64
+	cacheTick uint64
+
+	cache ConntrackCache
+}
+
+func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
+	if d == 0 {
+		return nil
+	}
+
+	c := &ConntrackCacheTicker{
+		cache: ConntrackCache{},
+	}
+
+	go c.tick(d)
+
+	return c
+}
+
+func (c *ConntrackCacheTicker) tick(d time.Duration) {
+	for {
+		time.Sleep(d)
+		atomic.AddUint64(&c.cacheTick, 1)
+	}
+}
+
+// Get checks if the cache ticker has moved to the next version before returning
+// the map. If it has moved, we reset the map.
+func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
+	if c == nil {
+		return nil
+	}
+	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
+		c.cacheV = tick
+		if ll := len(c.cache); ll > 0 {
+			if l.Level == logrus.DebugLevel {
+				l.WithField("len", ll).Debug("resetting conntrack cache")
+			}
+			c.cache = make(ConntrackCache, ll)
+		}
+	}
+
+	return c.cache
+}
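The new firewall/cache.go above carries the conntrack cache helper out of the root package unchanged. A minimal usage sketch, using only the functions defined in that file:

package main

import (
	"time"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/firewall"
)

func main() {
	l := logrus.New()

	// One ticker per packet-processing routine; a zero duration disables the
	// cache (NewConntrackCacheTicker returns nil and Get on nil is safe).
	cc := firewall.NewConntrackCacheTicker(2 * time.Second)

	// Each call returns the per-routine map, resetting it whenever the
	// background tick has advanced since the previous call.
	cache := cc.Get(l)
	_ = cache // would be handed to Firewall.Drop as its localCache argument
}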

+ 62 - 0
firewall/packet.go

@@ -0,0 +1,62 @@
+package firewall
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/slackhq/nebula/iputil"
+)
+
+type m map[string]interface{}
+
+const (
+	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP  = 6
+	ProtoUDP  = 17
+	ProtoICMP = 1
+
+	PortAny      = 0  // Special value for matching `port: any`
+	PortFragment = -1 // Special value for matching `port: fragment`
+)
+
+type Packet struct {
+	LocalIP    iputil.VpnIp
+	RemoteIP   iputil.VpnIp
+	LocalPort  uint16
+	RemotePort uint16
+	Protocol   uint8
+	Fragment   bool
+}
+
+func (fp *Packet) Copy() *Packet {
+	return &Packet{
+		LocalIP:    fp.LocalIP,
+		RemoteIP:   fp.RemoteIP,
+		LocalPort:  fp.LocalPort,
+		RemotePort: fp.RemotePort,
+		Protocol:   fp.Protocol,
+		Fragment:   fp.Fragment,
+	}
+}
+
+func (fp Packet) MarshalJSON() ([]byte, error) {
+	var proto string
+	switch fp.Protocol {
+	case ProtoTCP:
+		proto = "tcp"
+	case ProtoICMP:
+		proto = "icmp"
+	case ProtoUDP:
+		proto = "udp"
+	default:
+		proto = fmt.Sprintf("unknown %v", fp.Protocol)
+	}
+	return json.Marshal(m{
+		"LocalIP":    fp.LocalIP.String(),
+		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalPort":  fp.LocalPort,
+		"RemotePort": fp.RemotePort,
+		"Protocol":   proto,
+		"Fragment":   fp.Fragment,
+	})
+}
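firewall/packet.go above likewise moves out of the root package, with the IP fields retyped to iputil.VpnIp. A short sketch of building one and rendering it with the MarshalJSON defined in this file (the addresses and ports are arbitrary):

package main

import (
	"encoding/json"
	"fmt"
	"net"

	"github.com/slackhq/nebula/firewall"
	"github.com/slackhq/nebula/iputil"
)

func main() {
	p := firewall.Packet{
		LocalIP:    iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)),
		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)),
		LocalPort:  80,
		RemotePort: 54321,
		Protocol:   firewall.ProtoTCP,
	}

	// MarshalJSON emits the protocol name and dotted-quad IPs instead of raw
	// integers, which keeps structured drop output readable.
	out, err := json.Marshal(p)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}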

+ 101 - 109
firewall_test.go

@@ -11,11 +11,15 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestNewFirewall(t *testing.T) {
 func TestNewFirewall(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	c := &cert.NebulaCertificate{}
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
 	conntrack := fw.Conntrack
@@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
 }
 }
 
 
 func TestFirewall_AddRule(t *testing.T) {
 func TestFirewall_AddRule(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
@@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) {
 
 
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
 
-	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
 	// An empty rule is any
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 
 	// Set any and clear fields
 	// Set any and clear fields
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 
 	// run twice just to make sure
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion
 	//TODO: these ANY rules should clear the CA firewall portion
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 
 	// Test error conditions
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
 	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
-	assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
 }
 }
 
 
 func TestFirewall_Drop(t *testing.T) {
 func TestFirewall_Drop(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		10,
 		90,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 		false,
 	}
 	}
 
 
@@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
@@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) {
 
 
 	// test remote mismatch
 	// test remote mismatch
 	oldRemote := p.RemoteIP
 	oldRemote := p.RemoteIP
-	p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
+	p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
 	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	p.RemoteIP = oldRemote
 	p.RemoteIP = oldRemote
 
 
 	// ensure signer doesn't get in the way of group checks
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caSha doesn't drop on match
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 
 	// ensure ca name doesn't get in the way of group checks
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caName doesn't drop on match
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 }
 
 
@@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
 	b.Run("fail on port", func(b *testing.B) {
 	b.Run("fail on port", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
@@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
@@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
@@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
 	b.Run("pass on ip", func(b *testing.B) {
 	b.Run("pass on ip", func(b *testing.B) {
-		ip := ip2int(net.IPv4(172, 1, 1, 1))
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
 		c := &cert.NebulaCertificate{
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
 				InvertedGroups: map[string]struct{}{"nope": {}},
@@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
 		}
 		}
 	})
 	})
 
 
 	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
 	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
 
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
 	b.Run("pass on ip with any port", func(b *testing.B) {
-		ip := ip2int(net.IPv4(172, 1, 1, 1))
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
 		c := &cert.NebulaCertificate{
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
 				InvertedGroups: map[string]struct{}{"nope": {}},
@@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
 		}
 		}
 	})
 	})
 }
 }
 
 
 func TestFirewall_Drop2(t *testing.T) {
 func TestFirewall_Drop2(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		10,
 		90,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 		false,
 	}
 	}
 
 
@@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
@@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 	h1.CreateRemoteCIDR(&c1)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// h1/c1 lacks the proper groups
 	// h1/c1 lacks the proper groups
@@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) {
 }
 }
 
 
 func TestFirewall_Drop3(t *testing.T) {
 func TestFirewall_Drop3(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		1,
 		1,
 		1,
 		1,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 		false,
 	}
 	}
 
 
@@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h1.CreateRemoteCIDR(&c1)
 	h1.CreateRemoteCIDR(&c1)
 
 
@@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 			peerCert: &c2,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h2.CreateRemoteCIDR(&c2)
 	h2.CreateRemoteCIDR(&c2)
 
 
@@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 			peerCert: &c3,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h3.CreateRemoteCIDR(&c3)
 	h3.CreateRemoteCIDR(&c3)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// c1 should pass because host match
 	// c1 should pass because host match
@@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) {
 }
 }
 
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
 func TestFirewall_DropConntrackReload(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		10,
 		90,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 		false,
 	}
 	}
 
 
@@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
@@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw := fw
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw = fw
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) {
 }
 }
 
 
 func TestNewFirewallFromConfig(t *testing.T) {
 func TestNewFirewallFromConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// Test a bad rule definition
 	// Test a bad rule definition
 	c := &cert.NebulaCertificate{}
 	c := &cert.NebulaCertificate{}
-	conf := NewConfig(l)
+	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	_, err := NewFirewallFromConfig(l, c, conf)
 	_, err := NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 
 	// Test both port and code
 	// Test both port and code
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	// Test missing host, group, cidr, ca_name and ca_sha
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
 
 
 	// Test code/port error
 	// Test code/port error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
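
NewConfig(l) becomes config.NewC(l), but the test pattern is unchanged: populate conf.Settings directly and hand the object to NewFirewallFromConfig. A short sketch reusing the test's l and c; the port and host values here are placeholders, not taken from the change:

    conf := config.NewC(l)
    conf.Settings["firewall"] = map[interface{}]interface{}{
        "outbound": []interface{}{
            map[interface{}]interface{}{"port": "80", "proto": "tcp", "host": "web01"},
        },
    }
    fw, err := NewFirewallFromConfig(l, c, conf)
    if err != nil {
        // a malformed rule surfaces here, as the assertions in this test show
    }
    _ = fw
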
@@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 
 	// Test proto error
 	// Test proto error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 
 	// Test cidr parse error
 	// Test cidr parse error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 
 
 	// Test both group and groups
 	// Test both group and groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 }
 
 
 func TestAddFirewallRulesFromConfig(t *testing.T) {
 func TestAddFirewallRulesFromConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// Test adding tcp rule
 	// Test adding tcp rule
-	conf := NewConfig(l)
+	conf := config.NewC(l)
 	mf := &mockFirewall{}
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 
 	// Test adding udp rule
 	// Test adding udp rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 
 	// Test adding icmp rule
 	// Test adding icmp rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 
 	// Test adding any rule
 	// Test adding any rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 
 	// Test adding rule with ca_sha
 	// Test adding rule with ca_sha
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
 
 
 	// Test adding rule with ca_name
 	// Test adding rule with ca_name
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
 
 
 	// Test single group
 	// Test single group
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 
 	// Test single groups
 	// Test single groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 
 	// Test multiple AND groups
 	// Test multiple AND groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
 
 
 	// Test Add error
 	// Test Add error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	mf.nextCallReturn = errors.New("test error")
 	mf.nextCallReturn = errors.New("test error")
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
@@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
 }
 }
 
 
 func TestFirewall_convertRule(t *testing.T) {
 func TestFirewall_convertRule(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
@@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 
 
 func resetConntrack(fw *Firewall) {
 func resetConntrack(fw *Firewall) {
 	fw.Conntrack.Lock()
 	fw.Conntrack.Lock()
-	fw.Conntrack.Conns = map[FirewallPacket]*conn{}
+	fw.Conntrack.Conns = map[firewall.Packet]*conn{}
 	fw.Conntrack.Unlock()
 	fw.Conntrack.Unlock()
 }
 }
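
resetConntrack changes only because the conntrack table is now keyed by the packaged firewall.Packet type. A small illustrative helper in the same spirit, assuming the Conntrack mutex and Conns map keep the shapes used above:

    // counts tracked flows while holding the conntrack lock, as resetConntrack does
    func conntrackLen(fw *Firewall) int {
        fw.Conntrack.Lock()
        defer fw.Conntrack.Unlock()
        return len(fw.Conntrack.Conns)
    }
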

+ 5 - 5
handshake.go

@@ -1,11 +1,11 @@
 package nebula
 package nebula
 
 
-const (
-	handshakeIXPSK0 = 0
-	handshakeXXPSK0 = 1
+import (
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
-func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
+func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
 	// First remote allow list check before we know the vpnIp
 	// First remote allow list check before we know the vpnIp
 	if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
 	if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
 		f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
 	}
 	}
 
 
 	switch h.Subtype {
 	switch h.Subtype {
-	case handshakeIXPSK0:
+	case header.HandshakeIXPSK0:
 		switch h.MessageCounter {
 		switch h.MessageCounter {
 		case 1:
 		case 1:
 			ixHandshakeStage1(f, addr, packet, h)
 			ixHandshakeStage1(f, addr, packet, h)
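
handshake.go loses its local handshakeIXPSK0/handshakeXXPSK0 constants; the subtype now comes from the header package and the parsed header is a *header.H. The dispatch restated with comments; the stage 2 branch is cut off by the hunk above, so its body is only a guess here:

    switch h.Subtype {
    case header.HandshakeIXPSK0:
        switch h.MessageCounter {
        case 1:
            ixHandshakeStage1(f, addr, packet, h)
        case 2:
            // not shown in this hunk; presumably the pending HostInfo is looked up
            // by h.RemoteIndex and handed to ixHandshakeStage2
        }
    }
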

+ 59 - 56
handshake_ix.go

@@ -6,13 +6,16 @@ import (
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // NOISE IX Handshakes
 // NOISE IX Handshakes
 
 
 // This function constructs a handshake packet, but does not actually send it
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
+func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	// This queries the lighthouse if we don't know a remote for the host
 	// This queries the lighthouse if we don't know a remote for the host
 	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
 	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
 	// more quickly, effect is a quicker handshake.
 	// more quickly, effect is a quicker handshake.
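
The vpnIp parameter changes from a bare uint32 to iputil.VpnIp, and every IntIp(vpnIp) wrapper disappears from the log fields below, which suggests the new type knows how to print itself. An assumption-level sketch of such a type; iputil's real definition may differ:

    type VpnIp uint32

    // String lets logrus fields render a dotted quad without an extra wrapper.
    func (ip VpnIp) String() string {
        return net.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)).String()
    }
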
@@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 
 
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return
 		return
 	}
 	}
@@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hsBytes, err = proto.Marshal(hs)
 	hsBytes, err = proto.Marshal(hs)
 
 
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return
 		return
 	}
 	}
 
 
-	header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
+	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
 	atomic.AddUint64(&ci.atomicMessageCounter, 1)
 	atomic.AddUint64(&ci.atomicMessageCounter, 1)
 
 
-	msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
+	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 		return
 	}
 	}
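
HeaderEncode and the HeaderLen/Version constants move into the header package. The stage 0 call above, annotated; the two trailing integers appear to be the remote index (not yet known at stage 0, hence zero) and the message counter:

    h := header.Encode(
        make([]byte, header.Len), // fixed-size header buffer
        header.Version,
        header.Handshake,         // message type, formerly the package-level handshake constant
        header.HandshakeIXPSK0,   // subtype
        0,                        // remote index, unknown at stage 0
        1,                        // message counter for the first packet
    )
    _ = h
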
@@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hostinfo.handshakeStart = time.Now()
 	hostinfo.handshakeStart = time.Now()
 }
 }
 
 
-func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
+func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
 	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 	ci.window.Update(f.l, 1)
 
 
-	msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
+	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
@@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			Info("Invalid certificate from host")
 			Info("Invalid certificate from host")
 		return
 		return
 	}
 	}
-	vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
+	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
 	certName := remoteCert.Details.Name
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 	issuer := remoteCert.Details.Issuer
 
 
-	if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	if vpnIp == f.myVpnIp {
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 		return
 	}
 	}
 
 
-	if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return
 		return
 	}
 	}
 
 
 	myIndex, err := generateIndex(f.l)
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		ConnectionState:   ci,
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		hostId:            vpnIP,
+		vpnIp:             vpnIp,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		lastHandshakeTime: hs.Details.Time,
 	}
 	}
@@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	hostinfo.Lock()
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 	defer hostinfo.Unlock()
 
 
-	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).
@@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 
 	hsBytes, err := proto.Marshal(hs)
 	hsBytes, err := proto.Marshal(hs)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 		return
 	}
 	}
 
 
-	header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
-	msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
+	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
+	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 		return
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 		return
 	}
 	}
 
 
-	hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
-	copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
+	hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
+	copy(hostinfo.HandshakePacket[0], packet[header.Len:])
 
 
 	// Regardless of whether you are the sender or receiver, you should arrive here
 	// Regardless of whether you are the sender or receiver, you should arrive here
 	// and complete standing up the connection.
 	// and complete standing up the connection.
@@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 	hostinfo.SetRemote(addr)
 	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 
 	// Only overwrite existing record if we should win the handshake race
 	// Only overwrite existing record if we should win the handshake race
-	overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
+	overwrite := vpnIp > f.myVpnIp
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
 	if err != nil {
 	if err != nil {
 		switch err {
 		switch err {
@@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				// the preferred remote
-				f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 			}
 			existing.Unlock()
 			existing.Unlock()
 			hostinfo.Lock()
 			hostinfo.Lock()
 
 
 			msg = existing.HandshakePacket[2]
 			msg = existing.HandshakePacket[2]
-			f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
+			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 			err := f.outside.WriteTo(msg, addr)
 			err := f.outside.WriteTo(msg, addr)
 			if err != nil {
 			if err != nil {
-				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithError(err).Error("Failed to send handshake message")
 					WithError(err).Error("Failed to send handshake message")
 			} else {
 			} else {
-				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 					Info("Handshake message sent")
 			}
 			}
 			return
 			return
 		case ErrExistingHostInfo:
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 				Info("Handshake too old")
 				Info("Handshake too old")
 
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 			return
 		case ErrLocalIndexCollision:
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
 				Error("Failed to add HostInfo due to localIndex collision")
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 			return
 		case ErrExistingHandshake:
 		case ErrExistingHandshake:
 			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
 			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		default:
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	}
 	}
 
 
 	// Do the send
 	// Do the send
-	f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
+	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 	err = f.outside.WriteTo(msg, addr)
 	err = f.outside.WriteTo(msg, addr)
 	if err != nil {
 	if err != nil {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithError(err).Error("Failed to send handshake")
 			WithError(err).Error("Failed to send handshake")
 	} else {
 	} else {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	return
 	return
 }
 }
 
 
-func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool {
 	if hostinfo == nil {
 	if hostinfo == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
 		return true
@@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hostinfo.Lock()
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 	defer hostinfo.Unlock()
 
 
-	if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return false
 		return false
 	}
 	}
 
 
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
 	if ci.ready {
 	if ci.ready {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Info("Handshake is already complete")
 			Info("Handshake is already complete")
 
 
@@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
 		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
 			// Send a test packet to ensure the other side has also switched to
 			// Send a test packet to ensure the other side has also switched to
 			// the preferred remote
 			// the preferred remote
-			f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}
 		}
 
 
 		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
 		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
 		return false
 		return false
 	}
 	}
 
 
-	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
+	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 			Error("Failed to call noise.ReadMessage")
 
 
@@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		// near future
 		// near future
 		return false
 		return false
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 			Error("Noise did not arrive at a key")
 
 
@@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hs := &NebulaHandshake{}
 	hs := &NebulaHandshake{}
 	err = proto.Unmarshal(msg, hs)
 	err = proto.Unmarshal(msg, hs)
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 
 	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
 	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Invalid certificate from host")
 			Error("Invalid certificate from host")
 
 
@@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		return true
 		return true
 	}
 	}
 
 
-	vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
+	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
 	certName := remoteCert.Details.Name
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 	issuer := remoteCert.Details.Issuer
 
 
 	// Ensure the right host responded
 	// Ensure the right host responded
-	if vpnIP != hostinfo.hostId {
-		f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
+	if vpnIp != hostinfo.vpnIp {
+		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 			Info("Incorrect host responded to handshake")
@@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 
 		// Create a new hostinfo/handshake for the intended vpn ip
 		// Create a new hostinfo/handshake for the intended vpn ip
 		//TODO: this adds it to the timer wheel in a way that aggressively retries
 		//TODO: this adds it to the timer wheel in a way that aggressively retries
-		newHostInfo := f.getOrHandshake(hostinfo.hostId)
+		newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
 		newHostInfo.Lock()
 		newHostInfo.Lock()
 
 
 		// Block the current used address
 		// Block the current used address
@@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		newHostInfo.remotes.BlockRemote(addr)
 		newHostInfo.remotes.BlockRemote(addr)
 
 
 		// Get the correct remote list for the host we did handshake with
 		// Get the correct remote list for the host we did handshake with
-		hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+		hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
 
-		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
+		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
 			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 			Info("Blocked addresses for handshakes")
 			Info("Blocked addresses for handshakes")
 
 
@@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		hostinfo.ConnectionState.queueLock.Unlock()
 		hostinfo.ConnectionState.queueLock.Unlock()
 
 
 		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
 		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-		hostinfo.hostId = vpnIP
+		hostinfo.vpnIp = vpnIp
 		f.sendCloseTunnel(hostinfo)
 		f.sendCloseTunnel(hostinfo)
 		newHostInfo.Unlock()
 		newHostInfo.Unlock()
 
 
@@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	ci.window.Update(f.l, 2)
 	ci.window.Update(f.l, 2)
 
 
 	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
 	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
-	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).

+ 36 - 33
handshake_manager.go

@@ -11,6 +11,9 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 const (
 const (
@@ -39,7 +42,7 @@ type HandshakeManager struct {
 	pendingHostMap         *HostMap
 	pendingHostMap         *HostMap
 	mainHostMap            *HostMap
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
 	lightHouse             *LightHouse
-	outside                *udpConn
+	outside                *udp.Conn
 	config                 HandshakeConfig
 	config                 HandshakeConfig
 	OutboundHandshakeTimer *SystemTimerWheel
 	OutboundHandshakeTimer *SystemTimerWheel
 	messageMetrics         *MessageMetrics
 	messageMetrics         *MessageMetrics
@@ -47,18 +50,18 @@ type HandshakeManager struct {
 	metricTimedOut         metrics.Counter
 	metricTimedOut         metrics.Counter
 	l                      *logrus.Logger
 	l                      *logrus.Logger
 
 
-	// can be used to trigger outbound handshake for the given vpnIP
-	trigger chan uint32
+	// can be used to trigger outbound handshake for the given vpnIp
+	trigger chan iputil.VpnIp
 }
 }
 
 
-func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
+func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
 	return &HandshakeManager{
 		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
 		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
 		mainHostMap:            mainHostMap,
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		lightHouse:             lightHouse,
 		outside:                outside,
 		outside:                outside,
 		config:                 config,
 		config:                 config,
-		trigger:                make(chan uint32, config.triggerBuffer),
+		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
@@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 	defer clockSource.Stop()
 
 
@@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 		case <-ctx.Done():
 		case <-ctx.Done():
 			return
 			return
 		case vpnIP := <-c.trigger:
 		case vpnIP := <-c.trigger:
-			c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
+			c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered")
 			c.handleOutbound(vpnIP, f, true)
 			c.handleOutbound(vpnIP, f, true)
 		case now := <-clockSource.C:
 		case now := <-clockSource.C:
 			c.NextOutboundHandshakeTimerTick(now, f)
 			c.NextOutboundHandshakeTimerTick(now, f)
@@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
 	c.OutboundHandshakeTimer.advance(now)
 	c.OutboundHandshakeTimer.advance(now)
 	for {
 	for {
 		ep := c.OutboundHandshakeTimer.Purge()
 		ep := c.OutboundHandshakeTimer.Purge()
 		if ep == nil {
 		if ep == nil {
 			break
 			break
 		}
 		}
-		vpnIP := ep.(uint32)
-		c.handleOutbound(vpnIP, f, false)
+		vpnIp := ep.(iputil.VpnIp)
+		c.handleOutbound(vpnIp, f, false)
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
-	hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
+func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
+	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return
 		return
 	}
 	}
@@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	if !hostinfo.HandshakeReady {
 	if !hostinfo.HandshakeReady {
 		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
 		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
 		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
 		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
-		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 		return
 		return
 	}
 	}
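
handleOutbound now takes an iputil.VpnIp and a udp.EncWriter, but the retry mechanics are untouched: every attempt that is not lighthouse-triggered re-adds the ip to the timer wheel with a delay scaled by the attempt counter. A two-line restatement of that cadence:

    // back off linearly with the number of attempts already made for this host
    delay := c.config.tryInterval * time.Duration(hostinfo.HandshakeCounter)
    c.OutboundHandshakeTimer.Add(vpnIp, delay)
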
 
 
@@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	// Get a remotes object if we don't already have one.
 	// Get a remotes object if we don't already have one.
 	// This is mainly to protect us as this should never be the case
 	// This is mainly to protect us as this should never be the case
 	if hostinfo.remotes == nil {
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
+		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
 	}
 	}
 
 
 	//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
 	//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
 	if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
 	if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
-		// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
+		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIP, f)
+		c.lightHouse.QueryServer(vpnIp, f)
 	}
 	}
 
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
-	var sentTo []*udpAddr
-	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
-		c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+	var sentTo []*udp.Addr
+	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(c.l).WithField("udpAddr", addr).
 			hostinfo.logger(c.l).WithField("udpAddr", addr).
@@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
 	if !lighthouseTriggered {
 		//TODO: feel like we dupe handshake real fast in a tight loop, why?
 		//TODO: feel like we dupe handshake real fast in a tight loop, why?
-		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
-	hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
+func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
 	// We lock here and use an array to insert items to prevent locking the
 	// We lock here and use an array to insert items to prevent locking the
 	// main receive thread for very long by waiting to add items to the pending map
 	// main receive thread for very long by waiting to add items to the pending map
 	//TODO: what lock?
 	//TODO: what lock?
-	c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
+	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
 	c.metricInitiated.Inc(1)
 	c.metricInitiated.Inc(1)
 
 
 	return hostinfo
 	return hostinfo
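
AddVpnIP becomes AddVpnIp and takes the new type. A sketch of the two ways a handshake gets started, based on the fields above; hm is an illustrative *HandshakeManager variable and 10.1.0.2 is a placeholder address:

    // normal path: register a pending host and let the timer wheel drive retries
    hostinfo := hm.AddVpnIp(iputil.Ip2VpnIp(net.IPv4(10, 1, 0, 2)))
    _ = hostinfo

    // fast path: a lighthouse reply can preempt the wheel via the buffered trigger channel
    hm.trigger <- iputil.Ip2VpnIp(net.IPv4(10, 1, 0, 2))
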
@@ -208,12 +211,12 @@ var (
 
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
 // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
 // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
-
+//
 // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
 // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
 // exact same handshake packet
 // exact same handshake packet
 //
 //
 // ErrExistingHostInfo if we already have an entry in the hostmap for this
 // ErrExistingHostInfo if we already have an entry in the hostmap for this
-// VpnIP and the new handshake was older than the one we currently have
+// VpnIp and the new handshake was older than the one we currently have
 //
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 // hostmap for the hostinfo.localIndexId.
@@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	defer c.mainHostMap.Unlock()
 	defer c.mainHostMap.Unlock()
 
 
 	// Check if we already have a tunnel with this vpn ip
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
+	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
 	if found && existingHostInfo != nil {
 		// Is it just a delayed handshake packet?
 		// Is it just a delayed handshake packet?
 		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
 		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
@@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	}
 	}
 
 
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(c.l).
 		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
 	// Check if we are also handshaking with this vpn ip
 	// Check if we are also handshaking with this vpn ip
-	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
+	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
 	if found && pendingHostInfo != nil {
 	if found && pendingHostInfo != nil {
 		if !overwrite {
 		if !overwrite {
 			// We won, let our pending handshake win
 			// We won, let our pending handshake win
@@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 
 
 	if existingHostInfo != nil {
 	if existingHostInfo != nil {
 		// We are going to overwrite this entry, so remove the old references
 		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
+		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 	}
 	}
@@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	c.mainHostMap.Lock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 	defer c.mainHostMap.Unlock()
 
 
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
+	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
 	if found && existingHostInfo != nil {
 		// We are going to overwrite this entry, so remove the old references
 		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
+		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 	}
 	}
@@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(c.l).
 		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
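
The two timer calls above describe the retry schedule for outbound handshakes: AddVpnIp queues the first attempt after tryInterval, and each later attempt is re-added to the timer wheel at tryInterval * HandshakeCounter, a linear backoff. The sketch below is illustrative only — nextRetryDelay and the literal interval are not part of the commit — it just shows the resulting schedule.

package main

import (
	"fmt"
	"time"
)

// nextRetryDelay is an illustrative helper (not part of the commit) showing
// the schedule implied above: the first attempt is queued after tryInterval,
// and every retry is re-added to the timer wheel at
// tryInterval * HandshakeCounter.
func nextRetryDelay(tryInterval time.Duration, handshakeCounter int) time.Duration {
	if handshakeCounter <= 0 {
		return tryInterval
	}
	return tryInterval * time.Duration(handshakeCounter)
}

func main() {
	tryInterval := 100 * time.Millisecond // hypothetical config value
	for counter := 0; counter <= 4; counter++ {
		fmt.Printf("attempt %d re-queued after %v\n", counter, nextRetryDelay(tryInterval, counter))
	}
}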
 
 

+ 16 - 12
handshake_manager_test.go

@@ -5,25 +5,29 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-func Test_NewHandshakeManagerVpnIP(t *testing.T) {
-	l := NewTestLogger()
+func Test_NewHandshakeManagerVpnIp(t *testing.T) {
+	l := util.NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := ip2int(net.ParseIP("172.1.1.2"))
+	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
 
 
 	now := time.Now()
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 
-	i := blah.AddVpnIP(ip)
+	i := blah.AddVpnIp(ip)
 	i.remotes = NewRemoteList()
 	i.remotes = NewRemoteList()
 	i.HandshakeReady = true
 	i.HandshakeReady = true
 
 
@@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 }
 }
 
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := ip2int(net.ParseIP("172.1.1.2"))
+	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-	lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
+	lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
 
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
 
 
 	now := time.Now()
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
 
-	hi := blah.AddVpnIP(ip)
+	hi := blah.AddVpnIp(ip)
 	hi.HandshakeReady = true
 	hi.HandshakeReady = true
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
 	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
@@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	// Make sure the trigger doesn't double schedule the timer entry
 	// Make sure the trigger doesn't double schedule the timer entry
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
 
-	uaddr := NewUDPAddrFromString("10.1.1.1:4242")
+	uaddr := udp.NewAddrFromString("10.1.1.1:4242")
 	hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
 	hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
 
 
 	// We now have remotes but only the first trigger should have pushed things forward
 	// We now have remotes but only the first trigger should have pushed things forward
@@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
 type mockEncWriter struct {
 type mockEncWriter struct {
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
 	return
 	return
 }
 }
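
The tests above build VPN addresses with iputil.Ip2VpnIp, the packed-IPv4 type this commit threads through the code. Below is a minimal standalone sketch of that packing, assuming iputil.VpnIp is an IPv4 address stored big-endian in a uint32; the local VpnIp type and helpers are stand-ins for illustration, not the real package.

package main

import (
	"encoding/binary"
	"fmt"
	"net"
)

// VpnIp is a local stand-in for iputil.VpnIp: an IPv4 address packed
// big-endian into a uint32 (an assumption based on the bit math used
// elsewhere in this commit).
type VpnIp uint32

// ip2VpnIp mirrors the iputil.Ip2VpnIp call used in the tests above
// (IPv4 only in this sketch).
func ip2VpnIp(ip net.IP) VpnIp {
	return VpnIp(binary.BigEndian.Uint32(ip.To4()))
}

// toIP converts back, mirroring the VpnIp.ToIP call seen later in inside.go.
func (v VpnIp) toIP() net.IP {
	b := make([]byte, 4)
	binary.BigEndian.PutUint32(b, uint32(v))
	return net.IP(b)
}

func main() {
	v := ip2VpnIp(net.ParseIP("172.1.1.2"))
	fmt.Printf("172.1.1.2 -> 0x%08x -> %s\n", uint32(v), v.toIP()) // 0xac010102 -> 172.1.1.2
}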

+ 62 - 66
header.go → header/header.go

@@ -1,4 +1,4 @@
-package nebula
+package header
 
 
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
@@ -19,82 +19,78 @@ import (
 // |-----------------------------------------------------------------------|
 // |-----------------------------------------------------------------------|
 // |                               payload...                              |
 // |                               payload...                              |
 
 
+type m map[string]interface{}
+
 const (
 const (
-	Version   uint8 = 1
-	HeaderLen       = 16
+	Version uint8 = 1
+	Len           = 16
 )
 )
 
 
-type NebulaMessageType uint8
-type NebulaMessageSubType uint8
+type MessageType uint8
+type MessageSubType uint8
 
 
 const (
 const (
-	handshake   NebulaMessageType = 0
-	message     NebulaMessageType = 1
-	recvError   NebulaMessageType = 2
-	lightHouse  NebulaMessageType = 3
-	test        NebulaMessageType = 4
-	closeTunnel NebulaMessageType = 5
-
-	//TODO These are deprecated as of 06/12/2018 - NB
-	testRemote      NebulaMessageType = 6
-	testRemoteReply NebulaMessageType = 7
+	Handshake   MessageType = 0
+	Message     MessageType = 1
+	RecvError   MessageType = 2
+	LightHouse  MessageType = 3
+	Test        MessageType = 4
+	CloseTunnel MessageType = 5
 )
 )
 
 
-var typeMap = map[NebulaMessageType]string{
-	handshake:   "handshake",
-	message:     "message",
-	recvError:   "recvError",
-	lightHouse:  "lightHouse",
-	test:        "test",
-	closeTunnel: "closeTunnel",
-
-	//TODO These are deprecated as of 06/12/2018 - NB
-	testRemote:      "testRemote",
-	testRemoteReply: "testRemoteReply",
+var typeMap = map[MessageType]string{
+	Handshake:   "handshake",
+	Message:     "message",
+	RecvError:   "recvError",
+	LightHouse:  "lightHouse",
+	Test:        "test",
+	CloseTunnel: "closeTunnel",
 }
 }
 
 
 const (
 const (
-	testRequest NebulaMessageSubType = 0
-	testReply   NebulaMessageSubType = 1
+	TestRequest MessageSubType = 0
+	TestReply   MessageSubType = 1
+)
+
+const (
+	HandshakeIXPSK0 MessageSubType = 0
+	HandshakeXXPSK0 MessageSubType = 1
 )
 )
 
 
-var eHeaderTooShort = errors.New("header is too short")
+var ErrHeaderTooShort = errors.New("header is too short")
 
 
-var subTypeTestMap = map[NebulaMessageSubType]string{
-	testRequest: "testRequest",
-	testReply:   "testReply",
+var subTypeTestMap = map[MessageSubType]string{
+	TestRequest: "testRequest",
+	TestReply:   "testReply",
 }
 }
 
 
-var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"}
+var subTypeNoneMap = map[MessageSubType]string{0: "none"}
 
 
-var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{
-	message:     &subTypeNoneMap,
-	recvError:   &subTypeNoneMap,
-	lightHouse:  &subTypeNoneMap,
-	test:        &subTypeTestMap,
-	closeTunnel: &subTypeNoneMap,
-	handshake: {
-		handshakeIXPSK0: "ix_psk0",
+var subTypeMap = map[MessageType]*map[MessageSubType]string{
+	Message:     &subTypeNoneMap,
+	RecvError:   &subTypeNoneMap,
+	LightHouse:  &subTypeNoneMap,
+	Test:        &subTypeTestMap,
+	CloseTunnel: &subTypeNoneMap,
+	Handshake: {
+		HandshakeIXPSK0: "ix_psk0",
 	},
 	},
-	//TODO: these are deprecated
-	testRemote:      &subTypeNoneMap,
-	testRemoteReply: &subTypeNoneMap,
 }
 }
 
 
-type Header struct {
+type H struct {
 	Version        uint8
 	Version        uint8
-	Type           NebulaMessageType
-	Subtype        NebulaMessageSubType
+	Type           MessageType
+	Subtype        MessageSubType
 	Reserved       uint16
 	Reserved       uint16
 	RemoteIndex    uint32
 	RemoteIndex    uint32
 	MessageCounter uint64
 	MessageCounter uint64
 }
 }
 
 
-// HeaderEncode uses the provided byte array to encode the provided header values into.
+// Encode uses the provided byte array to encode the provided header values into.
 // Byte array must be capped higher than HeaderLen or this will panic
 // Byte array must be capped higher than HeaderLen or this will panic
-func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte {
-	b = b[:HeaderLen]
-	b[0] = byte(v<<4 | (t & 0x0f))
+func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte {
+	b = b[:Len]
+	b[0] = v<<4 | byte(t&0x0f)
 	b[1] = byte(st)
 	b[1] = byte(st)
 	binary.BigEndian.PutUint16(b[2:4], 0)
 	binary.BigEndian.PutUint16(b[2:4], 0)
 	binary.BigEndian.PutUint32(b[4:8], ri)
 	binary.BigEndian.PutUint32(b[4:8], ri)
@@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b
 }
 }
 
 
 // String creates a readable string representation of a header
 // String creates a readable string representation of a header
-func (h *Header) String() string {
+func (h *H) String() string {
 	if h == nil {
 	if h == nil {
 		return "<nil>"
 		return "<nil>"
 	}
 	}
@@ -112,7 +108,7 @@ func (h *Header) String() string {
 }
 }
 
 
 // MarshalJSON creates a json string representation of a header
 // MarshalJSON creates a json string representation of a header
-func (h *Header) MarshalJSON() ([]byte, error) {
+func (h *H) MarshalJSON() ([]byte, error) {
 	return json.Marshal(m{
 	return json.Marshal(m{
 		"version":        h.Version,
 		"version":        h.Version,
 		"type":           h.TypeName(),
 		"type":           h.TypeName(),
@@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) {
 }
 }
 
 
 // Encode turns header into bytes
 // Encode turns header into bytes
-func (h *Header) Encode(b []byte) ([]byte, error) {
+func (h *H) Encode(b []byte) ([]byte, error) {
 	if h == nil {
 	if h == nil {
 		return nil, errors.New("nil header")
 		return nil, errors.New("nil header")
 	}
 	}
 
 
-	return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil
+	return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil
 }
 }
 
 
 	// Parse is a helper function that parses the given bytes into a new Header struct
 	// Parse is a helper function that parses the given bytes into a new Header struct
-func (h *Header) Parse(b []byte) error {
-	if len(b) < HeaderLen {
-		return eHeaderTooShort
+func (h *H) Parse(b []byte) error {
+	if len(b) < Len {
+		return ErrHeaderTooShort
 	}
 	}
 	// get upper 4 bits (version)
 	// get upper 4 bits (version)
 	h.Version = uint8((b[0] >> 4) & 0x0f)
 	h.Version = uint8((b[0] >> 4) & 0x0f)
 	// get lower 4 bits (message type)
 	// get lower 4 bits (message type)
-	h.Type = NebulaMessageType(b[0] & 0x0f)
-	h.Subtype = NebulaMessageSubType(b[1])
+	h.Type = MessageType(b[0] & 0x0f)
+	h.Subtype = MessageSubType(b[1])
 	h.Reserved = binary.BigEndian.Uint16(b[2:4])
 	h.Reserved = binary.BigEndian.Uint16(b[2:4])
 	h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
 	h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
 	h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
 	h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
@@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error {
 }
 }
 
 
 // TypeName will transform the headers message type into a human string
 // TypeName will transform the headers message type into a human string
-func (h *Header) TypeName() string {
+func (h *H) TypeName() string {
 	return TypeName(h.Type)
 	return TypeName(h.Type)
 }
 }
 
 
 // TypeName will transform a nebula message type into a human string
 // TypeName will transform a nebula message type into a human string
-func TypeName(t NebulaMessageType) string {
+func TypeName(t MessageType) string {
 	if n, ok := typeMap[t]; ok {
 	if n, ok := typeMap[t]; ok {
 		return n
 		return n
 	}
 	}
@@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string {
 }
 }
 
 
 // SubTypeName will transform the headers message sub type into a human string
 // SubTypeName will transform the headers message sub type into a human string
-func (h *Header) SubTypeName() string {
+func (h *H) SubTypeName() string {
 	return SubTypeName(h.Type, h.Subtype)
 	return SubTypeName(h.Type, h.Subtype)
 }
 }
 
 
 // SubTypeName will transform a nebula message sub type into a human string
 // SubTypeName will transform a nebula message sub type into a human string
-func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
+func SubTypeName(t MessageType, s MessageSubType) string {
 	if n, ok := subTypeMap[t]; ok {
 	if n, ok := subTypeMap[t]; ok {
 		if x, ok := (*n)[s]; ok {
 		if x, ok := (*n)[s]; ok {
 			return x
 			return x
@@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
 }
 }
 
 
 // NewHeader turns bytes into a header
 // NewHeader turns bytes into a header
-func NewHeader(b []byte) (*Header, error) {
-	h := new(Header)
+func NewHeader(b []byte) (*H, error) {
+	h := new(H)
 	if err := h.Parse(b); err != nil {
 	if err := h.Parse(b); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
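
The Encode/Parse pair above fixes the 16-byte wire header: version and message type share byte 0, the subtype is byte 1, bytes 2-3 are reserved, bytes 4-7 carry the remote index and bytes 8-15 the message counter, all big-endian. The standalone sketch below redoes the same packing and is checked against the test vector in header_test.go further down; encodeHeader here is illustrative, the real entry point is header.Encode.

package main

import (
	"encoding/binary"
	"fmt"
)

// encodeHeader re-creates the packing shown above as a standalone sketch.
func encodeHeader(version, msgType, subType uint8, remoteIndex uint32, counter uint64) []byte {
	b := make([]byte, 16)
	b[0] = version<<4 | (msgType & 0x0f) // 4-bit version, 4-bit message type
	b[1] = subType
	binary.BigEndian.PutUint16(b[2:4], 0) // reserved
	binary.BigEndian.PutUint32(b[4:8], remoteIndex)
	binary.BigEndian.PutUint64(b[8:16], counter)
	return b
}

func main() {
	// Same values as the test vector in header_test.go below:
	// version=5, type=4 (test), subtype=0, remoteIndex=10, counter=9.
	fmt.Printf("% x\n", encodeHeader(5, 4, 0, 10, 9))
	// 54 00 00 00 00 00 00 0a 00 00 00 00 00 00 00 09
}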

+ 115 - 0
header/header_test.go

@@ -0,0 +1,115 @@
+package header
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type headerTest struct {
+	expectedBytes []byte
+	*H
+}
+
+// 0001 0010 00010010
+var headerBigEndianTests = []headerTest{{
+	expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
+	// 1010 0000
+	H: &H{
+		// 1111 1+2+4+8 = 15
+		Version:        5,
+		Type:           4,
+		Subtype:        0,
+		Reserved:       0,
+		RemoteIndex:    10,
+		MessageCounter: 9,
+	},
+},
+}
+
+func TestEncode(t *testing.T) {
+	for _, tt := range headerBigEndianTests {
+		b, err := tt.Encode(make([]byte, Len))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		assert.Equal(t, tt.expectedBytes, b)
+	}
+}
+
+func TestParse(t *testing.T) {
+	for _, tt := range headerBigEndianTests {
+		b := tt.expectedBytes
+		parsedHeader := &H{}
+		parsedHeader.Parse(b)
+
+		if !reflect.DeepEqual(tt.H, parsedHeader) {
+			t.Fatalf("got %#v; want %#v", parsedHeader, tt.H)
+		}
+	}
+}
+
+func TestTypeName(t *testing.T) {
+	assert.Equal(t, "test", TypeName(Test))
+	assert.Equal(t, "test", (&H{Type: Test}).TypeName())
+
+	assert.Equal(t, "unknown", TypeName(99))
+	assert.Equal(t, "unknown", (&H{Type: 99}).TypeName())
+}
+
+func TestSubTypeName(t *testing.T) {
+	assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest))
+	assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName())
+
+	assert.Equal(t, "unknown", SubTypeName(99, TestRequest))
+	assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName())
+
+	assert.Equal(t, "unknown", SubTypeName(Test, 99))
+	assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName())
+
+	assert.Equal(t, "none", SubTypeName(Message, 0))
+	assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName())
+}
+
+func TestTypeMap(t *testing.T) {
+	// Force people to document this stuff
+	assert.Equal(t, map[MessageType]string{
+		Handshake:   "handshake",
+		Message:     "message",
+		RecvError:   "recvError",
+		LightHouse:  "lightHouse",
+		Test:        "test",
+		CloseTunnel: "closeTunnel",
+	}, typeMap)
+
+	assert.Equal(t, map[MessageType]*map[MessageSubType]string{
+		Message:     &subTypeNoneMap,
+		RecvError:   &subTypeNoneMap,
+		LightHouse:  &subTypeNoneMap,
+		Test:        &subTypeTestMap,
+		CloseTunnel: &subTypeNoneMap,
+		Handshake: {
+			HandshakeIXPSK0: "ix_psk0",
+		},
+	}, subTypeMap)
+}
+
+func TestHeader_String(t *testing.T) {
+	assert.Equal(
+		t,
+		"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
+		(&H{100, Test, TestRequest, 99, 98, 97}).String(),
+	)
+}
+
+func TestHeader_MarshalJSON(t *testing.T) {
+	b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
+		string(b),
+	)
+}

+ 0 - 119
header_test.go

@@ -1,119 +0,0 @@
-package nebula
-
-import (
-	"reflect"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-type headerTest struct {
-	expectedBytes []byte
-	*Header
-}
-
-// 0001 0010 00010010
-var headerBigEndianTests = []headerTest{{
-	expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
-	// 1010 0000
-	Header: &Header{
-		// 1111 1+2+4+8 = 15
-		Version:        5,
-		Type:           4,
-		Subtype:        0,
-		Reserved:       0,
-		RemoteIndex:    10,
-		MessageCounter: 9,
-	},
-},
-}
-
-func TestEncode(t *testing.T) {
-	for _, tt := range headerBigEndianTests {
-		b, err := tt.Encode(make([]byte, HeaderLen))
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		assert.Equal(t, tt.expectedBytes, b)
-	}
-}
-
-func TestParse(t *testing.T) {
-	for _, tt := range headerBigEndianTests {
-		b := tt.expectedBytes
-		parsedHeader := &Header{}
-		parsedHeader.Parse(b)
-
-		if !reflect.DeepEqual(tt.Header, parsedHeader) {
-			t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
-		}
-	}
-}
-
-func TestTypeName(t *testing.T) {
-	assert.Equal(t, "test", TypeName(test))
-	assert.Equal(t, "test", (&Header{Type: test}).TypeName())
-
-	assert.Equal(t, "unknown", TypeName(99))
-	assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
-}
-
-func TestSubTypeName(t *testing.T) {
-	assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
-	assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
-
-	assert.Equal(t, "unknown", SubTypeName(99, testRequest))
-	assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
-
-	assert.Equal(t, "unknown", SubTypeName(test, 99))
-	assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
-
-	assert.Equal(t, "none", SubTypeName(message, 0))
-	assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
-}
-
-func TestTypeMap(t *testing.T) {
-	// Force people to document this stuff
-	assert.Equal(t, map[NebulaMessageType]string{
-		handshake:       "handshake",
-		message:         "message",
-		recvError:       "recvError",
-		lightHouse:      "lightHouse",
-		test:            "test",
-		closeTunnel:     "closeTunnel",
-		testRemote:      "testRemote",
-		testRemoteReply: "testRemoteReply",
-	}, typeMap)
-
-	assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
-		message:     &subTypeNoneMap,
-		recvError:   &subTypeNoneMap,
-		lightHouse:  &subTypeNoneMap,
-		test:        &subTypeTestMap,
-		closeTunnel: &subTypeNoneMap,
-		handshake: {
-			handshakeIXPSK0: "ix_psk0",
-		},
-		testRemote:      &subTypeNoneMap,
-		testRemoteReply: &subTypeNoneMap,
-	}, subTypeMap)
-}
-
-func TestHeader_String(t *testing.T) {
-	assert.Equal(
-		t,
-		"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
-		(&Header{100, test, testRequest, 99, 98, 97}).String(),
-	)
-}
-
-func TestHeader_MarshalJSON(t *testing.T) {
-	b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
-		string(b),
-	)
-}

+ 64 - 93
hostmap.go

@@ -12,6 +12,10 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 //const ProbeLen = 100
 //const ProbeLen = 100
@@ -28,10 +32,10 @@ type HostMap struct {
 	name            string
 	name            string
 	Indexes         map[uint32]*HostInfo
 	Indexes         map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
-	Hosts           map[uint32]*HostInfo
+	Hosts           map[iputil.VpnIp]*HostInfo
 	preferredRanges []*net.IPNet
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
 	vpnCIDR         *net.IPNet
-	unsafeRoutes    *CIDRTree
+	unsafeRoutes    *cidr.Tree4
 	metricsEnabled  bool
 	metricsEnabled  bool
 	l               *logrus.Logger
 	l               *logrus.Logger
 }
 }
@@ -39,7 +43,7 @@ type HostMap struct {
 type HostInfo struct {
 type HostInfo struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
-	remote            *udpAddr
+	remote            *udp.Addr
 	remotes           *RemoteList
 	remotes           *RemoteList
 	promoteCounter    uint32
 	promoteCounter    uint32
 	ConnectionState   *ConnectionState
 	ConnectionState   *ConnectionState
@@ -51,9 +55,9 @@ type HostInfo struct {
 	packetStore       []*cachedPacket  //todo: this is other handshake manager entry
 	packetStore       []*cachedPacket  //todo: this is other handshake manager entry
 	remoteIndexId     uint32
 	remoteIndexId     uint32
 	localIndexId      uint32
 	localIndexId      uint32
-	hostId            uint32
+	vpnIp             iputil.VpnIp
 	recvError         int
 	recvError         int
-	remoteCidr        *CIDRTree
+	remoteCidr        *cidr.Tree4
 
 
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
@@ -66,17 +70,17 @@ type HostInfo struct {
 	lastHandshakeTime uint64
 	lastHandshakeTime uint64
 
 
 	lastRoam       time.Time
 	lastRoam       time.Time
-	lastRoamRemote *udpAddr
+	lastRoamRemote *udp.Addr
 }
 }
 
 
 type cachedPacket struct {
 type cachedPacket struct {
-	messageType    NebulaMessageType
-	messageSubType NebulaMessageSubType
+	messageType    header.MessageType
+	messageSubType header.MessageSubType
 	callback       packetCallback
 	callback       packetCallback
 	packet         []byte
 	packet         []byte
 }
 }
 
 
-type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
+type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
 
 
 type cachedPacketMetrics struct {
 type cachedPacketMetrics struct {
 	sent    metrics.Counter
 	sent    metrics.Counter
@@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
 }
 }
 
 
 func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
 func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
-	h := map[uint32]*HostInfo{}
+	h := map[iputil.VpnIp]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
 	m := HostMap{
 	m := HostMap{
@@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 		Hosts:           h,
 		Hosts:           h,
 		preferredRanges: preferredRanges,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
 		vpnCIDR:         vpnCIDR,
-		unsafeRoutes:    NewCIDRTree(),
+		unsafeRoutes:    cidr.NewTree4(),
 		l:               l,
 		l:               l,
 	}
 	}
 	return &m
 	return &m
@@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
 	metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
 	metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
 }
 }
 
 
-func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
+func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
 	hm.RLock()
 	hm.RLock()
-	if i, ok := hm.Hosts[vpnIP]; ok {
+	if i, ok := hm.Hosts[vpnIp]; ok {
 		index := i.localIndexId
 		index := i.localIndexId
 		hm.RUnlock()
 		hm.RUnlock()
 		return index, nil
 		return index, nil
@@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
 	return 0, errors.New("vpn IP not found")
 	return 0, errors.New("vpn IP not found")
 }
 }
 
 
-func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
+func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
 	hm.Lock()
 	hm.Lock()
 	hm.Hosts[ip] = hostinfo
 	hm.Hosts[ip] = hostinfo
 	hm.Unlock()
 	hm.Unlock()
 }
 }
 
 
-func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
+func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 	h := &HostInfo{}
 	h := &HostInfo{}
 	hm.RLock()
 	hm.RLock()
-	if _, ok := hm.Hosts[vpnIP]; !ok {
+	if _, ok := hm.Hosts[vpnIp]; !ok {
 		hm.RUnlock()
 		hm.RUnlock()
 		h = &HostInfo{
 		h = &HostInfo{
 			promoteCounter:  0,
 			promoteCounter:  0,
-			hostId:          vpnIP,
+			vpnIp:           vpnIp,
 			HandshakePacket: make(map[uint8][]byte, 0),
 			HandshakePacket: make(map[uint8][]byte, 0),
 		}
 		}
 		hm.Lock()
 		hm.Lock()
-		hm.Hosts[vpnIP] = h
+		hm.Hosts[vpnIp] = h
 		hm.Unlock()
 		hm.Unlock()
 		return h
 		return h
 	} else {
 	} else {
-		h = hm.Hosts[vpnIP]
+		h = hm.Hosts[vpnIp]
 		hm.RUnlock()
 		hm.RUnlock()
 		return h
 		return h
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
+func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	hm.Lock()
 	hm.Lock()
-	delete(hm.Hosts, vpnIP)
+	delete(hm.Hosts, vpnIp)
 	if len(hm.Hosts) == 0 {
 	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[uint32]*HostInfo{}
+		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
 	}
 	}
 	hm.Unlock()
 	hm.Unlock()
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
 			Debug("Hostmap vpnIp deleted")
 			Debug("Hostmap vpnIp deleted")
 	}
 	}
 }
 }
@@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
 
 
 	if hm.l.Level > logrus.DebugLevel {
 	if hm.l.Level > logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
+			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
 			Debug("Hostmap remoteIndex added")
 			Debug("Hostmap remoteIndex added")
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
+func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
 	hm.Lock()
 	hm.Lock()
-	h.hostId = vpnIP
-	hm.Hosts[vpnIP] = h
+	h.vpnIp = vpnIp
+	hm.Hosts[vpnIp] = h
 	hm.Indexes[h.localIndexId] = h
 	hm.Indexes[h.localIndexId] = h
 	hm.RemoteIndexes[h.remoteIndexId] = h
 	hm.RemoteIndexes[h.remoteIndexId] = h
 	hm.Unlock()
 	hm.Unlock()
 
 
 	if hm.l.Level > logrus.DebugLevel {
 	if hm.l.Level > logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 			Debug("Hostmap vpnIp added")
 	}
 	}
 }
 }
@@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
 
 
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// instance. Clean it up as well if we do.
 		// instance. Clean it up as well if we do.
-		hostinfo2, ok := hm.Hosts[hostinfo.hostId]
+		hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
 		if ok && hostinfo2 == hostinfo {
 		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.hostId)
+			delete(hm.Hosts, hostinfo.vpnIp)
 		}
 		}
 	}
 	}
 	hm.Unlock()
 	hm.Unlock()
@@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// instance. Clean it up as well if we do (they might not match in pendingHostmap)
 		// instance. Clean it up as well if we do (they might not match in pendingHostmap)
 		var hostinfo2 *HostInfo
 		var hostinfo2 *HostInfo
-		hostinfo2, ok = hm.Hosts[hostinfo.hostId]
+		hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
 		if ok && hostinfo2 == hostinfo {
 		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.hostId)
+			delete(hm.Hosts, hostinfo.vpnIp)
 		}
 		}
 	}
 	}
 	hm.Unlock()
 	hm.Unlock()
@@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	// Check if this same hostId is in the hostmap with a different instance.
 	// Check if this same hostId is in the hostmap with a different instance.
 	// This could happen if we have an entry in the pending hostmap with different
 	// This could happen if we have an entry in the pending hostmap with different
 	// index values than the one in the main hostmap.
 	// index values than the one in the main hostmap.
-	hostinfo2, ok := hm.Hosts[hostinfo.hostId]
+	hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
 	if ok && hostinfo2 != hostinfo {
 	if ok && hostinfo2 != hostinfo {
-		delete(hm.Hosts, hostinfo2.hostId)
+		delete(hm.Hosts, hostinfo2.vpnIp)
 		delete(hm.Indexes, hostinfo2.localIndexId)
 		delete(hm.Indexes, hostinfo2.localIndexId)
 		delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
 		delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
 	}
 	}
 
 
-	delete(hm.Hosts, hostinfo.hostId)
+	delete(hm.Hosts, hostinfo.vpnIp)
 	if len(hm.Hosts) == 0 {
 	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[uint32]*HostInfo{}
+		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
 	}
 	}
 	delete(hm.Indexes, hostinfo.localIndexId)
 	delete(hm.Indexes, hostinfo.localIndexId)
 	if len(hm.Indexes) == 0 {
 	if len(hm.Indexes) == 0 {
@@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
-			"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 			Debug("Hostmap hostInfo deleted")
 	}
 	}
 }
 }
@@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
-	return hm.queryVpnIP(vpnIp, nil)
+func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
+	return hm.queryVpnIp(vpnIp, nil)
 }
 }
 
 
-// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
+// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
 // `PromoteEvery` calls to this function for a given host.
 // `PromoteEvery` calls to this function for a given host.
-func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
-	return hm.queryVpnIP(vpnIp, ifce)
+func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
+	return hm.queryVpnIp(vpnIp, ifce)
 }
 }
 
 
-func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
+func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
 		hm.RUnlock()
@@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
 	return nil, errors.New("unable to find host")
 	return nil, errors.New("unable to find host")
 }
 }
 
 
-func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
+func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
 	r := hm.unsafeRoutes.MostSpecificContains(ip)
 	r := hm.unsafeRoutes.MostSpecificContains(ip)
 	if r != nil {
 	if r != nil {
-		return r.(uint32)
+		return r.(iputil.VpnIp)
 	} else {
 	} else {
 		return 0
 		return 0
 	}
 	}
@@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 	}
 
 
-	hm.Hosts[hostinfo.hostId] = hostinfo
+	hm.Hosts[hostinfo.vpnIp] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 			Debug("Hostmap vpnIp added")
 	}
 	}
 }
 }
@@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
 }
 }
 
 
 // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
 // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
-func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
+func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
 	var metricsTxPunchy metrics.Counter
 	var metricsTxPunchy metrics.Counter
 	if hm.metricsEnabled {
 	if hm.metricsEnabled {
 		metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
 		metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
 func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 	for _, r := range *routes {
 	for _, r := range *routes {
 		hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
 		hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
-		hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
+		hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
 	}
 	}
 }
 }
 
 
@@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 			}
 		}
 		}
 
 
-		i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
+		i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
 			if addr == nil || !preferred {
 			if addr == nil || !preferred {
 				return
 				return
 			}
 			}
 
 
 			// Try to send a test packet to that host, this should
 			// Try to send a test packet to that host, this should
 			// cause it to detect a roaming event and switch remotes
 			// cause it to detect a roaming event and switch remotes
-			ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		})
 		})
 	}
 	}
 
 
 	// Re query our lighthouses for new remotes occasionally
 	// Re query our lighthouses for new remotes occasionally
 	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
 	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
-		ifce.lightHouse.QueryServer(i.hostId, ifce)
+		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
 	}
 	}
 }
 }
 
 
-func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
 	//TODO: return the error so we can log with more context
 	//TODO: return the error so we can log with more context
 	if len(i.packetStore) < 100 {
 	if len(i.packetStore) < 100 {
 		tempPacket := make([]byte, len(packet))
 		tempPacket := make([]byte, len(packet))
@@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 	return nil
 }
 }
 
 
-func (i *HostInfo) SetRemote(remote *udpAddr) {
+func (i *HostInfo) SetRemote(remote *udp.Addr) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	// We copy here because we likely got this remote from a source that reuses the object
 	if !i.remote.Equals(remote) {
 	if !i.remote.Equals(remote) {
 		i.remote = remote.Copy()
 		i.remote = remote.Copy()
-		i.remotes.LearnRemote(i.hostId, remote.Copy())
+		i.remotes.LearnRemote(i.vpnIp, remote.Copy())
 	}
 	}
 }
 }
 
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	currentRemote := i.remote
 	currentRemote := i.remote
 	if currentRemote == nil {
 	if currentRemote == nil {
 		i.SetRemote(newRemote)
 		i.SetRemote(newRemote)
@@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 		return
 	}
 	}
 
 
-	remoteCidr := NewCIDRTree()
+	remoteCidr := cidr.NewTree4()
 	for _, ip := range c.Details.Ips {
 	for _, ip := range c.Details.Ips {
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
 	}
@@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 		return logrus.NewEntry(l)
 	}
 	}
 
 
-	li := l.WithField("vpnIp", IntIp(i.hostId))
-
+	li := l.WithField("vpnIp", i.vpnIp)
 	if connState := i.ConnectionState; connState != nil {
 	if connState := i.ConnectionState; connState != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
 			li = li.WithField("certName", peerCert.Details.Name)
 			li = li.WithField("certName", peerCert.Details.Name)
@@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 	return li
 	return li
 }
 }
 
 
-//########################
-
-/*
-
-func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
-	s := "\n"
-	for _, h := range hm.Hosts {
-		for _, r := range h.Remotes {
-			s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
-		}
-	}
-	return s
-}
-
-func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
-	for _, r := range i.Remotes {
-		if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
-			r.ProbeReceived(counter)
-		}
-	}
-}
-
-func (i *HostInfo) Probes() []*Probe {
-	p := []*Probe{}
-	for _, d := range i.Remotes {
-		p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
-	}
-	return p
-}
-
-*/
-
 // Utility functions
 // Utility functions
 
 
 func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
 func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
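
The HostMap above keeps three views of the same HostInfo set: Hosts keyed by VPN IP for outbound lookups, Indexes keyed by our local index for inbound traffic, and RemoteIndexes keyed by the peer's index for handshake replies, which is why every add and delete touches all three maps. Below is a much-simplified sketch of that shape; the types are illustrative stand-ins, not the real structs.

package main

import "fmt"

// Simplified model of the three lookup paths HostMap maintains; the real
// type also carries locks, preferred ranges, unsafe routes and metrics.
type vpnIp uint32

type host struct {
	vpnIp         vpnIp
	localIndexId  uint32
	remoteIndexId uint32
}

type hostMap struct {
	hosts         map[vpnIp]*host  // outbound: lookup by peer VPN IP
	indexes       map[uint32]*host // inbound: lookup by our local index
	remoteIndexes map[uint32]*host // handshakes: lookup by the peer's index
}

func (hm *hostMap) add(h *host) {
	// Every add (and delete) must touch all three maps to keep them in sync.
	hm.hosts[h.vpnIp] = h
	hm.indexes[h.localIndexId] = h
	hm.remoteIndexes[h.remoteIndexId] = h
}

func main() {
	hm := &hostMap{
		hosts:         map[vpnIp]*host{},
		indexes:       map[uint32]*host{},
		remoteIndexes: map[uint32]*host{},
	}
	hm.add(&host{vpnIp: 0x0a010102, localIndexId: 7, remoteIndexId: 99})

	fmt.Println(hm.hosts[0x0a010102].localIndexId)        // 7
	fmt.Println(hm.indexes[7].remoteIndexId)              // 99
	fmt.Println(hm.remoteIndexes[99].vpnIp == 0x0a010102) // true
}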

+ 28 - 23
inside.go

@@ -5,9 +5,13 @@ import (
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 	if err != nil {
 		f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
 		f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
 	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
+			f.l.WithField("vpnIp", fwPacket.RemoteIP).
 				WithField("fwPacket", fwPacket).
 				WithField("fwPacket", fwPacket).
 				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
 				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
 		}
 		}
@@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		// the packet queue.
 		// the packet queue.
 		ci.queueLock.Lock()
 		ci.queueLock.Lock()
 		if !ci.ready {
 		if !ci.ready {
-			hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+			hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 			ci.queueLock.Unlock()
 			ci.queueLock.Unlock()
 			return
 			return
 		}
 		}
@@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 
 
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	if dropReason == nil {
 	if dropReason == nil {
-		f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 
 
 	} else if f.l.Level >= logrus.DebugLevel {
 	} else if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).
 		hostinfo.logger(f.l).
@@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 }
 }
 
 
 // getOrHandshake returns nil if the vpnIp is not routable
 // getOrHandshake returns nil if the vpnIp is not routable
-func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
-	if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
+	//TODO: we can find contains without converting back to bytes
+	if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
 		vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
 		vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
 		if vpnIp == 0 {
 		if vpnIp == 0 {
 			return nil
 			return nil
 		}
 		}
 	}
 	}
-	hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
+	hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 
 
 	//if err != nil || hostinfo.ConnectionState == nil {
 	//if err != nil || hostinfo.ConnectionState == nil {
 	if err != nil {
 	if err != nil {
-		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp)
+		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 		if err != nil {
-			hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
+			hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
 		}
 		}
 	}
 	}
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
@@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 	return hostinfo
 	return hostinfo
 }
 }
 
 
-func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
-	fp := &FirewallPacket{}
+func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)
 	err := newPacket(p, false, fp)
 	if err != nil {
 	if err != nil {
 		f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
 		f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
@@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 		return
 		return
 	}
 	}
 
 
-	f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
 }
 }
 
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
 	hostInfo := f.getOrHandshake(vpnIp)
 	hostInfo := f.getOrHandshake(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(vpnIp)).
+			f.l.WithField("vpnIp", vpnIp).
 				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
 				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
 		}
 		}
 		return
 		return
@@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 	return
 	return
 }
 }
 
 
-func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 }
 }
 
 
-func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
+func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 }
 
 
-func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 	if ci.eKey == nil {
 		//TODO: log warning
 		//TODO: log warning
 		return
 		return
@@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 	c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
 	c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
 
 
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
-	out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
-	f.connectionManager.Out(hostinfo.hostId)
+	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
+	f.connectionManager.Out(hostinfo.vpnIp)
 
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// all our IPs and enable a faster roaming.
 	// all our IPs and enable a faster roaming.
-	if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount {
+	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.hostId, f)
+		f.lightHouse.QueryServer(hostinfo.vpnIp, f)
 		hostinfo.lastRebindCount = f.rebindCount
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 		}
 	}
 	}
 
 
@@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 	return
 	return
 }
 }
 
 
-func isMulticast(ip uint32) bool {
+func isMulticast(ip iputil.VpnIp) bool {
 	// Class D multicast
 	// Class D multicast
 	if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
 	if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
 		return true
 		return true
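
isMulticast above relies on the VpnIp packing: the top octet is (ip >> 24) & 0xff, and an IPv4 address is Class D multicast (224.0.0.0/4) when its top four bits are 1110, i.e. that octet masked with 0xf0 equals 0xe0. A standalone check of that arithmetic follows; pack is an illustrative helper, not part of the commit.

package main

import (
	"fmt"
	"net"
)

// isMulticast mirrors the check above on a plain uint32.
func isMulticast(ip uint32) bool {
	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
}

// pack is an illustrative helper, not part of the commit.
func pack(s string) uint32 {
	v4 := net.ParseIP(s).To4()
	return uint32(v4[0])<<24 | uint32(v4[1])<<16 | uint32(v4[2])<<8 | uint32(v4[3])
}

func main() {
	fmt.Println(isMulticast(pack("224.0.0.251"))) // true, start of Class D (mDNS)
	fmt.Println(isMulticast(pack("239.1.2.3")))   // true, end of Class D
	fmt.Println(isMulticast(pack("10.1.1.1")))    // false
}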

+ 26 - 21
interface.go

@@ -12,6 +12,10 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 const mtu = 9001
 const mtu = 9001
@@ -27,7 +31,7 @@ type Inside interface {
 
 
 type InterfaceConfig struct {
 type InterfaceConfig struct {
 	HostMap                 *HostMap
 	HostMap                 *HostMap
-	Outside                 *udpConn
+	Outside                 *udp.Conn
 	Inside                  Inside
 	Inside                  Inside
 	certState               *CertState
 	certState               *CertState
 	Cipher                  string
 	Cipher                  string
@@ -39,7 +43,6 @@ type InterfaceConfig struct {
 	pendingDeletionInterval int
 	pendingDeletionInterval int
 	DropLocalBroadcast      bool
 	DropLocalBroadcast      bool
 	DropMulticast           bool
 	DropMulticast           bool
-	UDPBatchSize            int
 	routines                int
 	routines                int
 	MessageMetrics          *MessageMetrics
 	MessageMetrics          *MessageMetrics
 	version                 string
 	version                 string
@@ -52,7 +55,7 @@ type InterfaceConfig struct {
 
 
 type Interface struct {
 type Interface struct {
 	hostMap            *HostMap
 	hostMap            *HostMap
-	outside            *udpConn
+	outside            *udp.Conn
 	inside             Inside
 	inside             Inside
 	certState          *CertState
 	certState          *CertState
 	cipher             string
 	cipher             string
@@ -62,11 +65,10 @@ type Interface struct {
 	serveDns           bool
 	serveDns           bool
 	createTime         time.Time
 	createTime         time.Time
 	lightHouse         *LightHouse
 	lightHouse         *LightHouse
-	localBroadcast     uint32
-	myVpnIp            uint32
+	localBroadcast     iputil.VpnIp
+	myVpnIp            iputil.VpnIp
 	dropLocalBroadcast bool
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	dropMulticast      bool
-	udpBatchSize       int
 	routines           int
 	routines           int
 	caPool             *cert.NebulaCAPool
 	caPool             *cert.NebulaCAPool
 	disconnectInvalid  bool
 	disconnectInvalid  bool
@@ -77,7 +79,7 @@ type Interface struct {
 
 
 	conntrackCacheTimeout time.Duration
 	conntrackCacheTimeout time.Duration
 
 
-	writers []*udpConn
+	writers []*udp.Conn
 	readers []io.ReadWriteCloser
 	readers []io.ReadWriteCloser
 
 
 	metricHandshakes    metrics.Histogram
 	metricHandshakes    metrics.Histogram
@@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 		return nil, errors.New("no firewall rules")
 	}
 	}
 
 
+	myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
 	ifce := &Interface{
 	ifce := &Interface{
 		hostMap:            c.HostMap,
 		hostMap:            c.HostMap,
 		outside:            c.Outside,
 		outside:            c.Outside,
@@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
 		lightHouse:         c.lightHouse,
-		localBroadcast:     ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask),
+		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		dropMulticast:      c.DropMulticast,
-		udpBatchSize:       c.UDPBatchSize,
 		routines:           c.routines,
 		routines:           c.routines,
 		version:            c.version,
 		version:            c.version,
-		writers:            make([]*udpConn, c.routines),
+		writers:            make([]*udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		caPool:             c.caPool,
 		caPool:             c.caPool,
 		disconnectInvalid:  c.disconnectInvalid,
 		disconnectInvalid:  c.disconnectInvalid,
-		myVpnIp:            ip2int(c.certState.certificate.Details.Ips[0].IP),
+		myVpnIp:            myVpnIp,
 
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 
@@ -190,14 +192,17 @@ func (f *Interface) run() {
 func (f *Interface) listenOut(i int) {
 func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 	runtime.LockOSThread()
 
 
-	var li *udpConn
+	var li *udp.Conn
 	// TODO clean this up with a coherent interface for each outside connection
 	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 	if i > 0 {
 		li = f.writers[i]
 		li = f.writers[i]
 	} else {
 	} else {
 		li = f.outside
 		li = f.outside
 	}
 	}
-	li.ListenOut(f, i)
+
+	lhh := f.lightHouse.NewRequestHandler()
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
 }
 }
 
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 
 
 	packet := make([]byte, mtu)
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	out := make([]byte, mtu)
-	fwPacket := &FirewallPacket{}
+	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
 
 	for {
 	for {
 		n, err := reader.Read(packet)
 		n, err := reader.Read(packet)
@@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	}
 	}
 }
 }
 
 
-func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
+func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadCA)
 	c.RegisterReloadCallback(f.reloadCA)
 	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	for _, udpConn := range f.writers {
 	for _, udpConn := range f.writers {
-		c.RegisterReloadCallback(udpConn.reloadConfig)
+		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
 	}
 }
 }
 
 
-func (f *Interface) reloadCA(c *Config) {
+func (f *Interface) reloadCA(c *config.C) {
 	// reload and check regardless
 	// reload and check regardless
 	// todo: need mutex?
 	// todo: need mutex?
 	newCAs, err := loadCAFromConfig(f.l, c)
 	newCAs, err := loadCAFromConfig(f.l, c)
@@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) {
 	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
 	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
 }
 }
 
 
-func (f *Interface) reloadCertKey(c *Config) {
+func (f *Interface) reloadCertKey(c *config.C) {
 	// reload and check in all cases
 	// reload and check in all cases
 	cs, err := NewCertStateFromConfig(c)
 	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
 	if err != nil {
@@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) {
 	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
 	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
 }
 }
 
 
-func (f *Interface) reloadFirewall(c *Config) {
+func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
 	if c.HasChanged("firewall") == false {
 		f.l.Debug("No firewall config change detected")
 		f.l.Debug("No firewall config change detected")
@@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	ticker := time.NewTicker(i)
 	ticker := time.NewTicker(i)
 	defer ticker.Stop()
 	defer ticker.Stop()
 
 
-	udpStats := NewUDPStatsEmitter(f.writers)
+	udpStats := udp.NewUDPStatsEmitter(f.writers)
 
 
 	for {
 	for {
 		select {
 		select {
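
Note (not part of the diff): listenOut now hands the udp layer a packet callback, the lighthouse request handler, and a per-routine conntrack cache instead of the whole Interface. A minimal sketch of a writer satisfying udp.EncWriter, the interface those handlers receive; the method signature is taken from HandleRequest and testEncWriter in this diff, while printWriter is illustrative only:

package main

import (
	"fmt"

	"github.com/slackhq/nebula/header"
	"github.com/slackhq/nebula/iputil"
	"github.com/slackhq/nebula/udp"
)

// printWriter stands in for the Interface when driving a lighthouse handler
// by hand; it logs instead of encrypting and sending.
type printWriter struct{}

func (printWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
	fmt.Printf("would send type=%v sub=%v to %s (%d byte payload)\n", t, st, vpnIp, len(p))
}

func main() {
	var w udp.EncWriter = printWriter{}
	w.SendMessageToVpnIp(header.LightHouse, 0, iputil.Ip2VpnIp([]byte{10, 128, 0, 2}), []byte("payload"), make([]byte, 12), make([]byte, 9001))
}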

+ 66 - 0
iputil/util.go

@@ -0,0 +1,66 @@
+package iputil
+
+import (
+	"encoding/binary"
+	"fmt"
+	"net"
+)
+
+type VpnIp uint32
+
+const maxIPv4StringLen = len("255.255.255.255")
+
+func (ip VpnIp) String() string {
+	b := make([]byte, maxIPv4StringLen)
+
+	n := ubtoa(b, 0, byte(ip>>24))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip>>16&255))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip>>8&255))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip&255))
+	return string(b[:n])
+}
+
+func (ip VpnIp) MarshalJSON() ([]byte, error) {
+	return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
+}
+
+func (ip VpnIp) ToIP() net.IP {
+	nip := make(net.IP, 4)
+	binary.BigEndian.PutUint32(nip, uint32(ip))
+	return nip
+}
+
+func Ip2VpnIp(ip []byte) VpnIp {
+	if len(ip) == 16 {
+		return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
+	}
+	return VpnIp(binary.BigEndian.Uint32(ip))
+}
+
+// ubtoa encodes the string form of the integer v to dst[start:] and
+// returns the number of bytes written to dst. The caller must ensure
+// that dst has sufficient length.
+func ubtoa(dst []byte, start int, v byte) int {
+	if v < 10 {
+		dst[start] = v + '0'
+		return 1
+	} else if v < 100 {
+		dst[start+1] = v%10 + '0'
+		dst[start] = v/10 + '0'
+		return 2
+	}
+
+	dst[start+2] = v%10 + '0'
+	dst[start+1] = (v/10)%10 + '0'
+	dst[start] = v/100 + '0'
+	return 3
+}
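
Note (not part of the diff): a quick sketch exercising the new iputil helpers; everything referenced here is defined in util.go above.

package main

import (
	"encoding/json"
	"fmt"
	"net"

	"github.com/slackhq/nebula/iputil"
)

func main() {
	// net.ParseIP returns a 16-byte slice even for dotted-quad input, which is
	// why Ip2VpnIp reads the last four bytes when len(ip) == 16.
	ip := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))

	fmt.Println(ip)        // "10.128.0.2" via the allocation-light String above
	fmt.Println(ip.ToIP()) // converted back to a net.IP

	// MarshalJSON emits the quoted dotted form rather than the raw uint32.
	b, _ := json.Marshal(ip)
	fmt.Println(string(b)) // "10.128.0.2"
}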

+ 17 - 0
iputil/util_test.go

@@ -0,0 +1,17 @@
+package iputil
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestVpnIp_String(t *testing.T) {
+	assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
+	assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
+	assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
+	assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
+	assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
+	assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
+}
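
Note (not part of the diff): one property implied by Ip2VpnIp taking the trailing four bytes of a 16-byte slice is that the IPv4-mapped IPv6 form of an address collapses to the same VpnIp. A small additional test in the same style, assuming nothing beyond the files above:

package iputil

import (
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
)

// Illustrative only: the IPv4-mapped form and the 4-byte form of an address
// produce the same VpnIp, because Ip2VpnIp reads the last four bytes.
func TestVpnIp_Ip2VpnIpForms(t *testing.T) {
	mapped := net.ParseIP("::ffff:10.1.1.1") // 16-byte, IPv4-mapped form
	four := net.IP{10, 1, 1, 1}              // 4-byte form

	assert.Equal(t, Ip2VpnIp(four), Ip2VpnIp(mapped))
	assert.Equal(t, "10.1.1.1", Ip2VpnIp(mapped).String())
}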

+ 82 - 81
lighthouse.go

@@ -12,6 +12,9 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
 //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
@@ -23,13 +26,13 @@ type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	sync.RWMutex //Because we concurrently read and write to our maps
 	amLighthouse bool
 	amLighthouse bool
-	myVpnIp      uint32
-	myVpnZeros   uint32
-	punchConn    *udpConn
+	myVpnIp      iputil.VpnIp
+	myVpnZeros   iputil.VpnIp
+	punchConn    *udp.Conn
 
 
 	// Local cache of answers from light houses
 	// Local cache of answers from light houses
 	// map of vpn Ip to answers
 	// map of vpn Ip to answers
-	addrMap map[uint32]*RemoteList
+	addrMap map[iputil.VpnIp]*RemoteList
 
 
 	// filters remote addresses allowed for each host
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -42,12 +45,12 @@ type LightHouse struct {
 	localAllowList *LocalAllowList
 	localAllowList *LocalAllowList
 
 
 	// used to trigger the HandshakeManager when we receive HostQueryReply
 	// used to trigger the HandshakeManager when we receive HostQueryReply
-	handshakeTrigger chan<- uint32
+	handshakeTrigger chan<- iputil.VpnIp
 
 
 	// staticList exists to avoid having a bool in each addrMap entry
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
 	// since static should be rare
-	staticList  map[uint32]struct{}
-	lighthouses map[uint32]struct{}
+	staticList  map[iputil.VpnIp]struct{}
+	lighthouses map[iputil.VpnIp]struct{}
 	interval    int
 	interval    int
 	nebulaPort  uint32 // 32 bits because protobuf does not have a uint16
 	nebulaPort  uint32 // 32 bits because protobuf does not have a uint16
 	punchBack   bool
 	punchBack   bool
@@ -58,20 +61,16 @@ type LightHouse struct {
 	l                 *logrus.Logger
 	l                 *logrus.Logger
 }
 }
 
 
-type EncWriter interface {
-	SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
-}
-
-func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
+func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
 	ones, _ := myVpnIpNet.Mask.Size()
 	ones, _ := myVpnIpNet.Mask.Size()
 	h := LightHouse{
 	h := LightHouse{
 		amLighthouse: amLighthouse,
 		amLighthouse: amLighthouse,
-		myVpnIp:      ip2int(myVpnIpNet.IP),
-		myVpnZeros:   uint32(32 - ones),
-		addrMap:      make(map[uint32]*RemoteList),
+		myVpnIp:      iputil.Ip2VpnIp(myVpnIpNet.IP),
+		myVpnZeros:   iputil.VpnIp(32 - ones),
+		addrMap:      make(map[iputil.VpnIp]*RemoteList),
 		nebulaPort:   nebulaPort,
 		nebulaPort:   nebulaPort,
-		lighthouses:  make(map[uint32]struct{}),
-		staticList:   make(map[uint32]struct{}),
+		lighthouses:  make(map[iputil.VpnIp]struct{}),
+		staticList:   make(map[iputil.VpnIp]struct{}),
 		interval:     interval,
 		interval:     interval,
 		punchConn:    pc,
 		punchConn:    pc,
 		punchBack:    punchBack,
 		punchBack:    punchBack,
@@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
 func (lh *LightHouse) ValidateLHStaticEntries() error {
 func (lh *LightHouse) ValidateLHStaticEntries() error {
 	for lhIP, _ := range lh.lighthouses {
 	for lhIP, _ := range lh.lighthouses {
 		if _, ok := lh.staticList[lhIP]; !ok {
 		if _, ok := lh.staticList[lhIP]; !ok {
-			return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP))
+			return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
 		}
 		}
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
-func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 		lh.QueryServer(ip, f)
 	}
 	}
@@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
 }
 }
 
 
 // This is asynchronous so no reply should be expected
 // This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
 	if lh.amLighthouse {
 	if lh.amLighthouse {
 		return
 		return
 	}
 	}
@@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 	// Send a query to the lighthouses and hope for the best next time
 	// Send a query to the lighthouses and hope for the best next time
 	query, err := proto.Marshal(NewLhQueryByInt(ip))
 	query, err := proto.Marshal(NewLhQueryByInt(ip))
 	if err != nil {
 	if err != nil {
-		lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
+		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
 		return
 		return
 	}
 	}
 
 
@@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
 	out := make([]byte, mtu)
 	for n := range lh.lighthouses {
 	for n := range lh.lighthouses {
-		f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
+		f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
 	}
 	}
 }
 }
 
 
-func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
+func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
 	lh.RLock()
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
 		lh.RUnlock()
@@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
 	lh.RLock()
 	// Do we have an entry in the main cache?
 	// Do we have an entry in the main cache?
 	if v, ok := lh.addrMap[vpnIp]; ok {
 	if v, ok := lh.addrMap[vpnIp]; ok {
@@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err
 	return false, 0, nil
 	return false, 0, nil
 }
 }
 
 
-func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	// First we check the static mapping
 	// First we check the static mapping
 	// and do nothing if it is there
 	// and do nothing if it is there
-	if _, ok := lh.staticList[vpnIP]; ok {
+	if _, ok := lh.staticList[vpnIp]; ok {
 		return
 		return
 	}
 	}
 	lh.Lock()
 	lh.Lock()
 	//l.Debugln(lh.addrMap)
 	//l.Debugln(lh.addrMap)
-	delete(lh.addrMap, vpnIP)
+	delete(lh.addrMap, vpnIp)
 
 
 	if lh.l.Level >= logrus.DebugLevel {
 	if lh.l.Level >= logrus.DebugLevel {
-		lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
+		lh.l.Debugf("deleting %s from lighthouse.", vpnIp)
 	}
 	}
 
 
 	lh.Unlock()
 	lh.Unlock()
@@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
 // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
 // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
-func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
+func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
 	lh.Lock()
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
 	am.Lock()
@@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
 }
 }
 
 
 // unlockedGetRemoteList assumes you have the lh lock
 // unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
-	am, ok := lh.addrMap[vpnIP]
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+	am, ok := lh.addrMap[vpnIp]
 	if !ok {
 	if !ok {
 		am = NewRemoteList()
 		am = NewRemoteList()
-		lh.addrMap[vpnIP] = am
+		lh.addrMap[vpnIp] = am
 	}
 	}
 	return am
 	return am
 }
 }
 
 
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 // unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
-	allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip)
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
+	allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
 	if lh.l.Level >= logrus.TraceLevel {
 	if lh.l.Level >= logrus.TraceLevel {
-		lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 	}
 
 
-	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) {
+	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
 		return false
 		return false
 	}
 	}
 
 
@@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
 }
 }
 
 
 // unlockedShouldAddV6 checks if to is allowed by our allow list
 // unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool {
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
 	allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
 	if lh.l.Level >= logrus.TraceLevel {
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 	return ip
 }
 }
 
 
-func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
-	if _, ok := lh.lighthouses[vpnIP]; ok {
+func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+	if _, ok := lh.lighthouses[vpnIp]; ok {
 		return true
 		return true
 	}
 	}
 	return false
 	return false
 }
 }
 
 
-func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
+func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
 	return &NebulaMeta{
 	return &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp: VpnIp,
+			VpnIp: uint32(VpnIp),
 		},
 		},
 	}
 	}
 }
 }
 
 
 func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 	ipp := Ip4AndPort{Port: port}
 	ipp := Ip4AndPort{Port: port}
-	ipp.Ip = ip2int(ip)
+	ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
 	return &ipp
 	return &ipp
 }
 }
 
 
@@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	}
 	}
 }
 }
 
 
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr {
+func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 	ip := ipp.Ip
 	ip := ipp.Ip
-	return NewUDPAddr(
+	return udp.NewAddr(
 		net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
 		net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
 		uint16(ipp.Port),
 		uint16(ipp.Port),
 	)
 	)
 }
 }
 
 
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
-	return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
+func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
+	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 }
 
 
-func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
+func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
 	if lh.amLighthouse || lh.interval == 0 {
 	if lh.amLighthouse || lh.interval == 0 {
 		return
 		return
 	}
 	}
@@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
 	}
 	}
 }
 }
 
 
-func (lh *LightHouse) SendUpdate(f EncWriter) {
+func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
 	var v4 []*Ip4AndPort
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 	var v6 []*Ip6AndPort
 
 
 	for _, e := range *localIps(lh.l, lh.localAllowList) {
 	for _, e := range *localIps(lh.l, lh.localAllowList) {
-		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) {
+		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
 			continue
 			continue
 		}
 		}
 
 
@@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	m := &NebulaMeta{
 	m := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp:       lh.myVpnIp,
+			VpnIp:       uint32(lh.myVpnIp),
 			Ip4AndPorts: v4,
 			Ip4AndPorts: v4,
 			Ip6AndPorts: v6,
 			Ip6AndPorts: v6,
 		},
 		},
@@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	}
 	}
 
 
 	for vpnIp := range lh.lighthouses {
 	for vpnIp := range lh.lighthouses {
-		f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
+		f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
 	}
 	}
 }
 }
 
 
@@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
 }
 }
 
 
 func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
 func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
-	lh.metrics.Rx(NebulaMessageType(t), 0, i)
+	lh.metrics.Rx(header.MessageType(t), 0, i)
 }
 }
 
 
 func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
 func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
-	lh.metrics.Tx(NebulaMessageType(t), 0, i)
+	lh.metrics.Tx(header.MessageType(t), 0, i)
 }
 }
 
 
 // This method is similar to Reset(), but it re-uses the pointer structs
 // This method is similar to Reset(), but it re-uses the pointer structs
@@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 	return lhh.meta
 }
 }
 
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
 	n := lhh.resetMeta()
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	err := n.Unmarshal(p)
 	if err != nil {
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
 			Error("Failed to unmarshal lighthouse packet")
 			Error("Failed to unmarshal lighthouse packet")
 		//TODO: send recv_error?
 		//TODO: send recv_error?
 		return
 		return
 	}
 	}
 
 
 	if n.Details == nil {
 	if n.Details == nil {
-		lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
 			Error("Invalid lighthouse update")
 			Error("Invalid lighthouse update")
 		//TODO: send recv_error?
 		//TODO: send recv_error?
 		return
 		return
@@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
 	// Exit if we don't answer queries
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 	}
 
 
 	//TODO: we can DRY this further
 	//TODO: we can DRY this further
-	reqVpnIP := n.Details.VpnIp
+	reqVpnIp := n.Details.VpnIp
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		n.Type = NebulaMeta_HostQueryReply
-		n.Details.VpnIp = reqVpnIP
+		n.Details.VpnIp = reqVpnIp
 
 
 		lhh.coalesceAnswers(c, n)
 		lhh.coalesceAnswers(c, n)
 
 
@@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply")
 		return
 		return
 	}
 	}
 
 
 	lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
 	lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
-	w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
+	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 
 
 	// This signals the other side to punch some zero byte udp packets
 	// This signals the other side to punch some zero byte udp packets
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
 		n.Type = NebulaMeta_HostPunchNotification
-		n.Details.VpnIp = vpnIp
+		n.Details.VpnIp = uint32(vpnIp)
 
 
 		lhh.coalesceAnswers(c, n)
 		lhh.coalesceAnswers(c, n)
 
 
@@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for")
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for")
 		return
 		return
 	}
 	}
 
 
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
-	w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
+	w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 }
 
 
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 		return
 	}
 	}
 
 
 	lhh.lh.Lock()
 	lhh.lh.Lock()
-	am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
+	am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
 	am.Lock()
 	am.Lock()
 	lhh.lh.Unlock()
 	lhh.lh.Unlock()
 
 
-	am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.Unlock()
 	am.Unlock()
 
 
 	// Non-blocking attempt to trigger, skip if it would block
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
 	select {
-	case lhh.lh.handshakeTrigger <- n.Details.VpnIp:
+	case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
 	default:
 	default:
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
 	if !lhh.lh.amLighthouse {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 	}
 
 
 	//Simple check that the host sent this not someone else
 	//Simple check that the host sent this not someone else
-	if n.Details.VpnIp != vpnIp {
+	if n.Details.VpnIp != uint32(vpnIp) {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Lock()
 	am.Lock()
 	lhh.lh.Unlock()
 	lhh.lh.Unlock()
 
 
-	am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.Unlock()
 	am.Unlock()
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 		return
 	}
 	}
 
 
 	empty := []byte{0}
 	empty := []byte{0}
-	punch := func(vpnPeer *udpAddr) {
+	punch := func(vpnPeer *udp.Addr) {
 		if vpnPeer == nil {
 		if vpnPeer == nil {
 			return
 			return
 		}
 		}
@@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 
 
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp))
+			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
 		}
 		}
 	}
 	}
 
 
@@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 		go func() {
 		go func() {
 			time.Sleep(time.Second * 5)
 			time.Sleep(time.Second * 5)
 			if lhh.l.Level >= logrus.DebugLevel {
 			if lhh.l.Level >= logrus.DebugLevel {
-				lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
+				lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp))
 			}
 			}
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// managed by a channel.
 			// managed by a channel.
-			w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}()
 		}()
 	}
 	}
 }
 }
 
 
 // ipMaskContains checks if testIp is contained by ip after applying a cidr
 // ipMaskContains checks if testIp is contained by ip after applying a cidr
 // zeros is 32 - bits from net.IPMask.Size()
 // zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {
+func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
 	return (testIp^ip)>>zeros == 0
 	return (testIp^ip)>>zeros == 0
 }
 }
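
Note (not part of the diff): ipMaskContains keeps its shift-and-xor semantics under the new type; zeros is the host-bit count (32 minus the prefix length), exactly how NewLightHouse derives myVpnZeros above. A standalone sketch with an illustrative contains helper, since ipMaskContains itself is unexported:

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/iputil"
)

// contains reproduces the unexported ipMaskContains check above.
func contains(ip, zeros, testIp iputil.VpnIp) bool {
	return (testIp^ip)>>zeros == 0
}

func main() {
	network := iputil.Ip2VpnIp(net.ParseIP("10.128.0.1"))
	ones, bits := net.CIDRMask(24, 32).Size()
	zeros := iputil.VpnIp(bits - ones) // 32 - 24 = 8 host bits, as in NewLightHouse

	fmt.Println(contains(network, zeros, iputil.Ip2VpnIp(net.ParseIP("10.128.0.255")))) // true
	fmt.Println(contains(network, zeros, iputil.Ip2VpnIp(net.ParseIP("10.128.1.1"))))   // false
}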

+ 72 - 68
lighthouse_test.go

@@ -6,6 +6,10 @@ import (
 	"testing"
 	"testing"
 
 
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) {
 	var m Ip4AndPort
 	var m Ip4AndPort
 	err := proto.Unmarshal(b, &m)
 	err := proto.Unmarshal(b, &m)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
-	assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String())
+	assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
 }
 }
 
 
 func TestNewLhQuery(t *testing.T) {
 func TestNewLhQuery(t *testing.T) {
 	myIp := net.ParseIP("192.1.1.1")
 	myIp := net.ParseIP("192.1.1.1")
-	myIpint := ip2int(myIp)
+	myIpint := iputil.Ip2VpnIp(myIp)
 
 
 	// Generating a new lh query should work
 	// Generating a new lh query should work
 	a := NewLhQueryByInt(myIpint)
 	a := NewLhQueryByInt(myIpint)
@@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) {
 }
 }
 
 
 func Test_lhStaticMapping(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 	lh1IP := net.ParseIP(lh1)
 
 
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
 
 
-	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
+	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
 	err := meh.ValidateLHStaticEntries()
 	err := meh.ValidateLHStaticEntries()
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	lh2 := "10.128.0.3"
 	lh2 := "10.128.0.3"
 	lh2IP := net.ParseIP(lh2)
 	lh2IP := net.ParseIP(lh2)
 
 
-	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
+	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
 	err = meh.ValidateLHStaticEntries()
 	err = meh.ValidateLHStaticEntries()
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 }
 
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 	lh1IP := net.ParseIP(lh1)
 
 
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
 
 
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
 
 
-	hAddr := NewUDPAddrFromString("4.5.6.7:12345")
-	hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
+	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
+	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
 	lh.addrMap[3] = NewRemoteList()
 	lh.addrMap[3] = NewRemoteList()
 	lh.addrMap[3].unlockedSetV4(
 	lh.addrMap[3].unlockedSetV4(
 		3,
 		3,
@@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
 			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
 			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
 			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
 		},
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
-	rAddr := NewUDPAddrFromString("1.2.2.3:12345")
-	rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
+	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
+	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
 	lh.addrMap[2] = NewRemoteList()
 	lh.addrMap[2] = NewRemoteList()
 	lh.addrMap[2].unlockedSetV4(
 	lh.addrMap[2].unlockedSetV4(
 		3,
 		3,
@@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
 			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
 			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
 			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
 		},
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
@@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 }
 }
 
 
 func TestLighthouse_Memory(t *testing.T) {
 func TestLighthouse_Memory(t *testing.T) {
-	l := NewTestLogger()
-
-	myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
-	myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
-	myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
-	myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
-	myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
-	myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
-	myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
-	myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
-	myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
-	myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
-	myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
-	myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
-	myVpnIp := ip2int(net.ParseIP("10.128.0.2"))
-
-	theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
-	theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
-	theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
-	theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
-	theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
-	theirVpnIp := ip2int(net.ParseIP("10.128.0.3"))
-
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false)
+	l := util.NewTestLogger()
+
+	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
+	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
+	myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
+	myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
+	myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
+	myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
+	myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
+	myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
+	myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
+	myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
+	myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
+	myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
+	myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
+
+	theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
+	theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
+	theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
+	theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
+	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
+	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
+	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
 	lhh := lh.NewRequestHandler()
 	lhh := lh.NewRequestHandler()
 
 
 	// Test that my first update responds with just that
 	// Test that my first update responds with just that
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 
 
 	// Ensure we don't accumulate addresses
 	// Ensure we don't accumulate addresses
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 
 
 	// Grow it back to 2
 	// Grow it back to 2
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Update a different host
 	// Update a different host
-	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
@@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	newLHHostUpdate(
 	newLHHostUpdate(
 		myUdpAddr0,
 		myUdpAddr0,
 		myVpnIp,
 		myVpnIp,
-		[]*udpAddr{
+		[]*udp.Addr{
 			myUdpAddr1,
 			myUdpAddr1,
 			myUdpAddr2,
 			myUdpAddr2,
 			myUdpAddr3,
 			myUdpAddr3,
@@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) {
 	)
 	)
 
 
 	// Make sure we won't add ips in our vpn network
 	// Make sure we won't add ips in our vpn network
-	bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
-	bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
-	good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh)
+	bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
+	bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
+	good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 }
 }
 
 
-func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp: queryVpnIp,
+			VpnIp: uint32(queryVpnIp),
 		},
 		},
 	}
 	}
 
 
@@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH
 	return w.lastReply
 	return w.lastReply
 }
 }
 
 
-func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp:       vpnIp,
+			VpnIp:       uint32(vpnIp),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 		},
 		},
 	}
 	}
 
 
 	for k, v := range addrs {
 	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)}
+		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
 	}
 	}
 
 
 	b, err := req.Marshal()
 	b, err := req.Marshal()
@@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
 //}
 //}
 
 
 func Test_ipMaskContains(t *testing.T) {
 func Test_ipMaskContains(t *testing.T) {
-	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
-	assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1"))))
-	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1"))))
+	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
+	assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
+	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
 }
 }
 
 
 type testLhReply struct {
 type testLhReply struct {
-	nebType    NebulaMessageType
-	nebSubType NebulaMessageSubType
-	vpnIp      uint32
+	nebType    header.MessageType
+	nebSubType header.MessageSubType
+	vpnIp      iputil.VpnIp
 	msg        *NebulaMeta
 	msg        *NebulaMeta
 }
 }
 
 
@@ -343,7 +347,7 @@ type testEncWriter struct {
 	lastReply testLhReply
 	lastReply testLhReply
 }
 }
 
 
-func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
 	tw.lastReply = testLhReply{
 	tw.lastReply = testLhReply{
 		nebType:    t,
 		nebType:    t,
 		nebSubType: st,
 		nebSubType: st,
@@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag
 }
 }
 
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
 	assert.Len(t, have, len(want))
 	assert.Len(t, have, len(want))
 	for k, w := range want {
 	for k, w := range want {
-		if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) {
+		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
 		}
 		}
 	}
 	}
 }
 }
 
 
 // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
 // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
+func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
 	assert.Len(t, have, len(want))
 	assert.Len(t, have, len(want))
 	for k, w := range want {
 	for k, w := range want {
 		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
 		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
@@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
 	}
 	}
 }
 }
 
 
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
-	addrs := make([]*udpAddr, len(ips))
+func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
+	addrs := make([]*udp.Addr, len(ips))
 	for k, v := range ips {
 	for k, v := range ips {
 		addrs[k] = NewUDPAddrFromLH4(v)
 		addrs[k] = NewUDPAddrFromLH4(v)
 	}
 	}

+ 39 - 0
logger.go

@@ -2,8 +2,12 @@ package nebula
 
 
 import (
 import (
 	"errors"
 	"errors"
+	"fmt"
+	"strings"
+	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 )
 )
 
 
 type ContextualError struct {
 type ContextualError struct {
@@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) {
 		lr.WithFields(ce.Fields).Error(ce.Context)
 		lr.WithFields(ce.Fields).Error(ce.Context)
 	}
 	}
 }
 }
+
+func configLogger(l *logrus.Logger, c *config.C) error {
+	// set up our logging level
+	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
+	if err != nil {
+		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
+	}
+	l.SetLevel(logLevel)
+
+	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
+	timestampFormat := c.GetString("logging.timestamp_format", "")
+	fullTimestamp := (timestampFormat != "")
+	if timestampFormat == "" {
+		timestampFormat = time.RFC3339
+	}
+
+	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
+	switch logFormat {
+	case "text":
+		l.Formatter = &logrus.TextFormatter{
+			TimestampFormat:  timestampFormat,
+			FullTimestamp:    fullTimestamp,
+			DisableTimestamp: disableTimestamp,
+		}
+	case "json":
+		l.Formatter = &logrus.JSONFormatter{
+			TimestampFormat:  timestampFormat,
+			DisableTimestamp: disableTimestamp,
+		}
+	default:
+		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
+	}
+
+	return nil
+}
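
Note (not part of the diff): configLogger only drives logrus. A tiny standalone demo of the same level and formatter handling, with literals standing in for the logging.level / logging.format / logging.timestamp_format / logging.disable_timestamp keys it reads (illustrative, not the config API):

package main

import (
	"time"

	"github.com/sirupsen/logrus"
)

func main() {
	l := logrus.New()

	// equivalent of logging.level: "debug"
	level, err := logrus.ParseLevel("debug")
	if err != nil {
		panic(err)
	}
	l.SetLevel(level)

	// equivalent of logging.format: "json" with the default RFC3339 timestamps
	l.Formatter = &logrus.JSONFormatter{
		TimestampFormat:  time.RFC3339,
		DisableTimestamp: false,
	}

	l.WithField("vpnIp", "10.128.0.2").Debug("formatter demo")
}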

+ 71 - 68
main.go

@@ -8,14 +8,16 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/sshd"
+	"github.com/slackhq/nebula/udp"
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 )
 )
 
 
 type m map[string]interface{}
 type m map[string]interface{}
 
 
-func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
-
+func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	ctx, cancel := context.WithCancel(context.Background())
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	defer func() {
 	defer func() {
@@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 
 	// Print the config if in test, the exit comes later
 	// Print the config if in test, the exit comes later
 	if configTest {
 	if configTest {
-		b, err := yaml.Marshal(config.Settings)
+		b, err := yaml.Marshal(c.Settings)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		l.Println(string(b))
 		l.Println(string(b))
 	}
 	}
 
 
-	err := configLogger(config)
+	err := configLogger(l, c)
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Failed to configure the logger", nil, err)
 		return nil, NewContextualError("Failed to configure the logger", nil, err)
 	}
 	}
 
 
-	config.RegisterReloadCallback(func(c *Config) {
-		err := configLogger(c)
+	c.RegisterReloadCallback(func(c *config.C) {
+		err := configLogger(l, c)
 		if err != nil {
 		if err != nil {
 			l.WithError(err).Error("Failed to configure the logger")
 			l.WithError(err).Error("Failed to configure the logger")
 		}
 		}
 	})
 	})
 
 
-	caPool, err := loadCAFromConfig(l, config)
+	caPool, err := loadCAFromConfig(l, c)
 	if err != nil {
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
 		//The errors coming out of loadCA are already nicely formatted
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
 	}
 	}
 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
 
-	cs, err := NewCertStateFromConfig(config)
+	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
 	if err != nil {
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
 		return nil, NewContextualError("Failed to load certificate from config", nil, err)
 		return nil, NewContextualError("Failed to load certificate from config", nil, err)
 	}
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
 
-	fw, err := NewFirewallFromConfig(l, cs.certificate, config)
+	fw, err := NewFirewallFromConfig(l, cs.certificate, c)
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Error while loading firewall rules", nil, err)
 		return nil, NewContextualError("Error while loading firewall rules", nil, err)
 	}
 	}
@@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 
 	// TODO: make sure mask is 4 bytes
 	// TODO: make sure mask is 4 bytes
 	tunCidr := cs.certificate.Details.Ips[0]
 	tunCidr := cs.certificate.Details.Ips[0]
-	routes, err := parseRoutes(config, tunCidr)
+	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Could not parse tun.routes", nil, err)
 		return nil, NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	}
-	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 	}
 
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
-	wireSSHReload(l, ssh, config)
+	wireSSHReload(l, ssh, c)
 	var sshStart func()
 	var sshStart func()
-	if config.GetBool("sshd.enabled", false) {
-		sshStart, err = configSSH(l, ssh, config)
+	if c.GetBool("sshd.enabled", false) {
+		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
 		if err != nil {
 			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 		}
 		}
@@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	var routines int
 	var routines int
 
 
 	// If `routines` is set, use that and ignore the specific values
 	// If `routines` is set, use that and ignore the specific values
-	if routines = config.GetInt("routines", 0); routines != 0 {
+	if routines = c.GetInt("routines", 0); routines != 0 {
 		if routines < 1 {
 		if routines < 1 {
 			routines = 1
 			routines = 1
 		}
 		}
@@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		}
 		}
 	} else {
 	} else {
 		// deprecated and undocumented
 		// deprecated and undocumented
-		tunQueues := config.GetInt("tun.routines", 1)
-		udpQueues := config.GetInt("listen.routines", 1)
+		tunQueues := c.GetInt("tun.routines", 1)
+		udpQueues := c.GetInt("listen.routines", 1)
 		if tunQueues > udpQueues {
 		if tunQueues > udpQueues {
 			routines = tunQueues
 			routines = tunQueues
 		} else {
 		} else {
@@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	// EXPERIMENTAL
 	// EXPERIMENTAL
 	// Intentionally not documented yet while we do more testing and determine
 	// Intentionally not documented yet while we do more testing and determine
 	// a good default value.
 	// a good default value.
-	conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
-	if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
+	conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
+	if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
 		// Use a different default if we are running with multiple routines
 		// Use a different default if we are running with multiple routines
 		conntrackCacheTimeout = 1 * time.Second
 		conntrackCacheTimeout = 1 * time.Second
 	}
 	}
@@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 
 	var tun Inside
 	var tun Inside
 	if !configTest {
 	if !configTest {
-		config.CatchHUP(ctx)
+		c.CatchHUP(ctx)
 
 
 		switch {
 		switch {
-		case config.GetBool("tun.disabled", false):
-			tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
+		case c.GetBool("tun.disabled", false):
+			tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		case tunFd != nil:
 		case tunFd != nil:
 			tun, err = newTunFromFd(
 			tun, err = newTunFromFd(
 				l,
 				l,
 				*tunFd,
 				*tunFd,
 				tunCidr,
 				tunCidr,
-				config.GetInt("tun.mtu", DEFAULT_MTU),
+				c.GetInt("tun.mtu", DEFAULT_MTU),
 				routes,
 				routes,
 				unsafeRoutes,
 				unsafeRoutes,
-				config.GetInt("tun.tx_queue", 500),
+				c.GetInt("tun.tx_queue", 500),
 			)
 			)
 		default:
 		default:
 			tun, err = newTun(
 			tun, err = newTun(
 				l,
 				l,
-				config.GetString("tun.dev", ""),
+				c.GetString("tun.dev", ""),
 				tunCidr,
 				tunCidr,
-				config.GetInt("tun.mtu", DEFAULT_MTU),
+				c.GetInt("tun.mtu", DEFAULT_MTU),
 				routes,
 				routes,
 				unsafeRoutes,
 				unsafeRoutes,
-				config.GetInt("tun.tx_queue", 500),
+				c.GetInt("tun.tx_queue", 500),
 				routines > 1,
 				routines > 1,
 			)
 			)
 		}
 		}
@@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}()
 	}()
 
 
 	// set up our UDP listener
 	// set up our UDP listener
-	udpConns := make([]*udpConn, routines)
-	port := config.GetInt("listen.port", 0)
+	udpConns := make([]*udp.Conn, routines)
+	port := c.GetInt("listen.port", 0)
 
 
 	if !configTest {
 	if !configTest {
 		for i := 0; i < routines; i++ {
 		for i := 0; i < routines; i++ {
-			udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
+			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 			if err != nil {
 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			}
-			udpServer.reloadConfig(config)
+			udpServer.ReloadConfig(c)
 			udpConns[i] = udpServer
 			udpConns[i] = udpServer
 
 
 			// If port is dynamic, discover it
 			// If port is dynamic, discover it
@@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 
 	// Set up my internal host map
 	// Set up my internal host map
 	var preferredRanges []*net.IPNet
 	var preferredRanges []*net.IPNet
-	rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
+	rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 	// First, check if 'preferred_ranges' is set and fallback to 'local_range'
 	// First, check if 'preferred_ranges' is set and fallback to 'local_range'
 	if len(rawPreferredRanges) > 0 {
 	if len(rawPreferredRanges) > 0 {
 		for _, rawPreferredRange := range rawPreferredRanges {
 		for _, rawPreferredRange := range rawPreferredRanges {
@@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	// local_range was superseded by preferred_ranges. If it is still present,
 	// local_range was superseded by preferred_ranges. If it is still present,
 	// merge the local_range setting into preferred_ranges. We will probably
 	// merge the local_range setting into preferred_ranges. We will probably
 	// deprecate local_range and remove in the future.
 	// deprecate local_range and remove in the future.
-	rawLocalRange := config.GetString("local_range", "")
+	rawLocalRange := c.GetString("local_range", "")
 	if rawLocalRange != "" {
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
 		if err != nil {
@@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
 
 
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
-	hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
+	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 
 
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 
 
@@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		go hostMap.Promoter(config.GetInt("promoter.interval"))
 		go hostMap.Promoter(config.GetInt("promoter.interval"))
 	*/
 	*/
 
 
-	punchy := NewPunchyFromConfig(config)
+	punchy := NewPunchyFromConfig(c)
 	if punchy.Punch && !configTest {
 	if punchy.Punch && !configTest {
 		l.Info("UDP hole punching enabled")
 		l.Info("UDP hole punching enabled")
 		go hostMap.Punchy(ctx, udpConns[0])
 		go hostMap.Punchy(ctx, udpConns[0])
 	}
 	}
 
 
-	amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
+	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 
 
 	// fatal if am_lighthouse is enabled but we are using an ephemeral port
 	// fatal if am_lighthouse is enabled but we are using an ephemeral port
-	if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
+	if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
 		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
 		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
 	}
 	}
 
 
 	// warn if am_lighthouse is enabled but upstream lighthouses exists
 	// warn if am_lighthouse is enabled but upstream lighthouses exists
-	rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
+	rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
 	if amLighthouse && len(rawLighthouseHosts) != 0 {
 	if amLighthouse && len(rawLighthouseHosts) != 0 {
 		l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 		l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
 	}
 
 
-	lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
+	lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
 	for i, host := range rawLighthouseHosts {
 	for i, host := range rawLighthouseHosts {
 		ip := net.ParseIP(host)
 		ip := net.ParseIP(host)
 		if ip == nil {
 		if ip == nil {
@@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		if !tunCidr.Contains(ip) {
 		if !tunCidr.Contains(ip) {
 			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 		}
 		}
-		lighthouseHosts[i] = ip2int(ip)
+		lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
 	}
 	}
 
 
 	lightHouse := NewLightHouse(
 	lightHouse := NewLightHouse(
@@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		tunCidr,
 		tunCidr,
 		lighthouseHosts,
 		lighthouseHosts,
 		//TODO: change to a duration
 		//TODO: change to a duration
-		config.GetInt("lighthouse.interval", 10),
+		c.GetInt("lighthouse.interval", 10),
 		uint32(port),
 		uint32(port),
 		udpConns[0],
 		udpConns[0],
 		punchy.Respond,
 		punchy.Respond,
 		punchy.Delay,
 		punchy.Delay,
-		config.GetBool("stats.lighthouse_metrics", false),
+		c.GetBool("stats.lighthouse_metrics", false),
 	)
 	)
 
 
-	remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
+	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 	}
 	}
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 
 
-	localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list")
+	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 	}
 	}
 	lightHouse.SetLocalAllowList(localAllowList)
 	lightHouse.SetLocalAllowList(localAllowList)
 
 
 	//TODO: Move all of this inside functions in lighthouse.go
 	//TODO: Move all of this inside functions in lighthouse.go
-	for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
-		vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
-		if !tunCidr.Contains(vpnIp) {
+	for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
+		ip := net.ParseIP(fmt.Sprintf("%v", k))
+		vpnIp := iputil.Ip2VpnIp(ip)
+		if !tunCidr.Contains(ip) {
 			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 		}
 		}
 		vals, ok := v.([]interface{})
 		vals, ok := v.([]interface{})
 		if ok {
 		if ok {
 			for _, v := range vals {
 			for _, v := range vals {
-				ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
+				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 				if err != nil {
 				if err != nil {
 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
 				}
-				lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
+				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 			}
 			}
 		} else {
 		} else {
-			ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
+			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 			if err != nil {
 			if err != nil {
 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 			}
 			}
-			lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
+			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 		}
 		}
 	}
 	}
 
 
@@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 	}
 
 
 	var messageMetrics *MessageMetrics
 	var messageMetrics *MessageMetrics
-	if config.GetBool("stats.message_metrics", false) {
+	if c.GetBool("stats.message_metrics", false) {
 		messageMetrics = newMessageMetrics()
 		messageMetrics = newMessageMetrics()
 	} else {
 	} else {
 		messageMetrics = newMessageMetricsOnlyRecvError()
 		messageMetrics = newMessageMetricsOnlyRecvError()
 	}
 	}
 
 
 	handshakeConfig := HandshakeConfig{
 	handshakeConfig := HandshakeConfig{
-		tryInterval:   config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
-		retries:       config.GetInt("handshakes.retries", DefaultHandshakeRetries),
-		triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
+		tryInterval:   c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
+		retries:       c.GetInt("handshakes.retries", DefaultHandshakeRetries),
+		triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
 
 
 		messageMetrics: messageMetrics,
 		messageMetrics: messageMetrics,
 	}
 	}
@@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
 	//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
 
 
 	serveDns := false
 	serveDns := false
-	if config.GetBool("lighthouse.serve_dns", false) {
-		if config.GetBool("lighthouse.am_lighthouse", false) {
+	if c.GetBool("lighthouse.serve_dns", false) {
+		if c.GetBool("lighthouse.am_lighthouse", false) {
 			serveDns = true
 			serveDns = true
 		} else {
 		} else {
 			l.Warn("DNS server refusing to run because this host is not a lighthouse.")
 			l.Warn("DNS server refusing to run because this host is not a lighthouse.")
 		}
 		}
 	}
 	}
 
 
-	checkInterval := config.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
+	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
+	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
 	ifConfig := &InterfaceConfig{
 	ifConfig := &InterfaceConfig{
 		HostMap:                 hostMap,
 		HostMap:                 hostMap,
 		Inside:                  tun,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		Outside:                 udpConns[0],
 		certState:               cs,
 		certState:               cs,
-		Cipher:                  config.GetString("cipher", "aes"),
+		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		HandshakeManager:        handshakeManager,
 		lightHouse:              lightHouse,
 		lightHouse:              lightHouse,
 		checkInterval:           checkInterval,
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
-		DropLocalBroadcast:      config.GetBool("tun.drop_local_broadcast", false),
-		DropMulticast:           config.GetBool("tun.drop_multicast", false),
-		UDPBatchSize:            config.GetInt("listen.batch", 64),
+		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
+		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
 		version:                 buildVersion,
 		caPool:                  caPool,
 		caPool:                  caPool,
-		disconnectInvalid:       config.GetBool("pki.disconnect_invalid", false),
+		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
 		l:                     l,
@@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		// I don't want to make this initial commit too far-reaching though
 		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 		ifce.writers = udpConns
 
 
-		ifce.RegisterConfigChangeCallbacks(config)
+		ifce.RegisterConfigChangeCallbacks(c)
 
 
 		go handshakeManager.Run(ctx, ifce)
 		go handshakeManager.Run(ctx, ifce)
 		go lightHouse.LhUpdateWorker(ctx, ifce)
 		go lightHouse.LhUpdateWorker(ctx, ifce)
@@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
 	// a context so that they can exit when the context is Done.
 	// a context so that they can exit when the context is Done.
-	statsStart, err := startStats(l, config, buildVersion, configTest)
+	statsStart, err := startStats(l, c, buildVersion, configTest)
+
 	if err != nil {
 	if err != nil {
 		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 	}
 	}
@@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 	}
 
 
 	//TODO: check if we _should_ be emitting stats
 	//TODO: check if we _should_ be emitting stats
-	go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
+	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 
 	attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
 	attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
 
 
@@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	var dnsStart func()
 	var dnsStart func()
 	if amLighthouse && serveDns {
 	if amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, config)
+		dnsStart = dnsMain(l, hostMap, c)
 	}
 	}
 
 
 	return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
 	return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil
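
As an aside, a small self-contained sketch of the lighthouse.hosts validation performed above: each entry must parse as an IP and sit inside the certificate's tunCidr before it is converted with iputil.Ip2VpnIp. The CIDR and addresses below are made up for illustration.

package main

import (
	"fmt"
	"net"
)

func main() {
	// Hypothetical cert CIDR and lighthouse.hosts entries.
	_, tunCidr, _ := net.ParseCIDR("192.168.100.1/24")
	rawLighthouseHosts := []string{"192.168.100.1", "10.1.1.1", "not-an-ip"}

	for _, host := range rawLighthouseHosts {
		ip := net.ParseIP(host)
		if ip == nil {
			fmt.Printf("%q: unable to parse, would be rejected\n", host)
			continue
		}
		if !tunCidr.Contains(ip) {
			fmt.Printf("%s: not in %s, would be rejected\n", ip, tunCidr)
			continue
		}
		// Main above would now store iputil.Ip2VpnIp(ip) in lighthouseHosts.
		fmt.Printf("%s: accepted as a lighthouse host\n", ip)
	}
}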

+ 5 - 2
message_metrics.go

@@ -4,8 +4,11 @@ import (
 	"fmt"
 	"fmt"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
+	"github.com/slackhq/nebula/header"
 )
 )
 
 
+//TODO: this can probably move into the header package
+
 type MessageMetrics struct {
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter
 	tx [][]metrics.Counter
@@ -14,7 +17,7 @@ type MessageMetrics struct {
 	txUnknown metrics.Counter
 	txUnknown metrics.Counter
 }
 }
 
 
-func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
+func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
 	if m != nil {
 	if m != nil {
 		if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
 		if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
 			m.rx[t][s].Inc(i)
 			m.rx[t][s].Inc(i)
@@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64
 		}
 		}
 	}
 	}
 }
 }
-func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
+func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
 	if m != nil {
 	if m != nil {
 		if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
 		if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
 			m.tx[t][s].Inc(i)
 			m.tx[t][s].Inc(i)
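
The nil check at the top of Rx and Tx is what lets callers keep a nil *MessageMetrics when stats.message_metrics is off and still call these methods unconditionally. A tiny standalone illustration of that pattern; Counters below is a stand-in type, not the real struct:

package main

import "fmt"

type Counters struct{ rx int64 }

// Rx is safe to call on a nil receiver, mirroring MessageMetrics.Rx above.
func (m *Counters) Rx(i int64) {
	if m != nil {
		m.rx += i
	}
}

func main() {
	var disabled *Counters // metrics turned off
	disabled.Rx(1)         // no-op thanks to the nil check

	enabled := &Counters{}
	enabled.Rx(1)
	fmt.Println(enabled.rx) // 1
}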

+ 61 - 57
outside.go

@@ -10,6 +10,10 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
@@ -17,8 +21,8 @@ const (
 	minFwPacketLen = 4
 	minFwPacketLen = 4
 )
 )
 
 
-func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
-	err := header.Parse(packet)
+func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+	err := h.Parse(packet)
 	if err != nil {
 	if err != nil {
 		// TODO: best if we return this and let caller log
 		// TODO: best if we return this and let caller log
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
@@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	//l.Error("in packet ", header, packet[HeaderLen:])
 
 
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
-	hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex)
+	hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex)
 
 
 	var ci *ConnectionState
 	var ci *ConnectionState
 	if err == nil {
 	if err == nil {
 		ci = hostinfo.ConnectionState
 		ci = hostinfo.ConnectionState
 	}
 	}
 
 
-	switch header.Type {
-	case message:
-		if !f.handleEncrypted(ci, addr, header) {
+	switch h.Type {
+	case header.Message:
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 			return
 		}
 		}
 
 
-		f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
+		f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
 
 
-	case lightHouse:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.LightHouse:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 			return
 		}
 		}
 
 
-		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
 				WithField("packet", packet).
@@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 			return
 		}
 		}
 
 
-		lhh.HandleRequest(addr, hostinfo.hostId, d, f)
+		lhf(addr, hostinfo.vpnIp, d, f)
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
 
 
-	case test:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.Test:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 			return
 		}
 		}
 
 
-		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
 				WithField("packet", packet).
@@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 			return
 		}
 		}
 
 
-		if header.Subtype == testRequest {
+		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
 			// to the new IP address before responding
 			f.handleHostRoaming(hostinfo, addr)
 			f.handleHostRoaming(hostinfo, addr)
-			f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
+			f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out)
 		}
 		}
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
@@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
 		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
 		// are unauthenticated
 		// are unauthenticated
 
 
-	case handshake:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		HandleIncomingHandshake(f, addr, packet, header, hostinfo)
+	case header.Handshake:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		HandleIncomingHandshake(f, addr, packet, h, hostinfo)
 		return
 		return
 
 
-	case recvError:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		f.handleRecvError(addr, header)
+	case header.RecvError:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		f.handleRecvError(addr, h)
 		return
 		return
 
 
-	case closeTunnel:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.CloseTunnel:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 			return
 		}
 		}
 
 
@@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 		return
 		return
 
 
 	default:
 	default:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
 		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
 		return
 		return
 	}
 	}
 
 
 	f.handleHostRoaming(hostinfo, addr)
 	f.handleHostRoaming(hostinfo, addr)
 
 
-	f.connectionManager.In(hostinfo.hostId)
+	f.connectionManager.In(hostinfo.vpnIp)
 }
 }
 
 
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
 func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
-	f.connectionManager.ClearIP(hostInfo.hostId)
-	f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
-	f.lightHouse.DeleteVpnIP(hostInfo.hostId)
+	f.connectionManager.ClearIP(hostInfo.vpnIp)
+	f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp)
+	f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
 
 
 	if hasHostMapLock {
 	if hasHostMapLock {
 		f.hostMap.unlockedDeleteHostInfo(hostInfo)
 		f.hostMap.unlockedDeleteHostInfo(hostInfo)
@@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
 
 
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
 func (f *Interface) sendCloseTunnel(h *HostInfo) {
 func (f *Interface) sendCloseTunnel(h *HostInfo) {
-	f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+	f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 }
 
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
-	if hostDidRoam(hostinfo.remote, addr) {
-		if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
+	if !hostinfo.remote.Equals(addr) {
+		if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
 			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
 			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 			return
 		}
 		}
@@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 
 
 }
 }
 
 
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
 	// If connectionstate exists and the replay protector allows, process packet
 	// If connectionstate exists and the replay protector allows, process packet
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
-		f.sendRecvError(addr, header.RemoteIndex)
+	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
+		f.sendRecvError(addr, h.RemoteIndex)
 		return false
 		return false
 	}
 	}
 
 
@@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *
 }
 }
 
 
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
-func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
+func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 	// Do we at least have an ipv4 header worth of data?
 	// Do we at least have an ipv4 header worth of data?
 	if len(data) < ipv4.HeaderLen {
 	if len(data) < ipv4.HeaderLen {
 		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
 		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
@@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 
 
 	// Accounting for a variable header length, do we have enough data for our src/dst tuples?
 	// Accounting for a variable header length, do we have enough data for our src/dst tuples?
 	minLen := ihl
 	minLen := ihl
-	if !fp.Fragment && fp.Protocol != fwProtoICMP {
+	if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
 		minLen += minFwPacketLen
 		minLen += minFwPacketLen
 	}
 	}
 	if len(data) < minLen {
 	if len(data) < minLen {
@@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 
 
 	// Firewall packets are locally oriented
 	// Firewall packets are locally oriented
 	if incoming {
 	if incoming {
-		fp.RemoteIP = binary.BigEndian.Uint32(data[12:16])
-		fp.LocalIP = binary.BigEndian.Uint32(data[16:20])
-		if fp.Fragment || fp.Protocol == fwProtoICMP {
+		fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
+		fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
 		} else {
 		} else {
@@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 		}
 	} else {
 	} else {
-		fp.LocalIP = binary.BigEndian.Uint32(data[12:16])
-		fp.RemoteIP = binary.BigEndian.Uint32(data[16:20])
-		if fp.Fragment || fp.Protocol == fwProtoICMP {
+		fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
+		fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
 		} else {
 		} else {
@@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 	return nil
 	return nil
 }
 }
 
 
-func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) {
+func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
 	var err error
 	var err error
-	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb)
+	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	if !hostinfo.ConnectionState.window.Update(f.l, mc) {
 	if !hostinfo.ConnectionState.window.Update(f.l, mc) {
-		hostinfo.logger(f.l).WithField("header", header).
+		hostinfo.logger(f.l).WithField("header", h).
 			Debugln("dropping out of window packet")
 			Debugln("dropping out of window packet")
 		return nil, errors.New("out of window packet")
 		return nil, errors.New("out of window packet")
 	}
 	}
@@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 	return out, nil
 }
 }
 
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
 	var err error
 	var err error
 
 
-	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
+	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
@@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return
 		return
 	}
 	}
 
 
-	f.connectionManager.In(hostinfo.hostId)
+	f.connectionManager.In(hostinfo.vpnIp)
 	_, err = f.readers[q].Write(out)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
 	}
 }
 }
 
 
-func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
-	f.messageMetrics.Tx(recvError, 0, 1)
+func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 
 	//TODO: this should be a signed message so we can trust that we should drop the index
 	//TODO: this should be a signed message so we can trust that we should drop the index
-	b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
+	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
 	f.outside.WriteTo(b, endpoint)
 	f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
 		f.l.WithField("index", index).
@@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
 	}
 	}
 }
 }
 
 
-func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
+func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 	if f.l.Level >= logrus.DebugLevel {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", h.RemoteIndex).
 		f.l.WithField("index", h.RemoteIndex).
 			WithField("udpAddr", addr).
 			WithField("udpAddr", addr).

+ 9 - 7
outside_test.go

@@ -4,12 +4,14 @@ import (
 	"net"
 	"net"
 	"testing"
 	"testing"
 
 
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
 func Test_newPacket(t *testing.T) {
 func Test_newPacket(t *testing.T) {
-	p := &FirewallPacket{}
+	p := &firewall.Packet{}
 
 
 	// length fail
 	// length fail
 	err := newPacket([]byte{0, 1}, true, p)
 	err := newPacket([]byte{0, 1}, true, p)
@@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) {
 		Src:      net.IPv4(10, 0, 0, 1),
 		Src:      net.IPv4(10, 0, 0, 1),
 		Dst:      net.IPv4(10, 0, 0, 2),
 		Dst:      net.IPv4(10, 0, 0, 2),
 		Options:  []byte{0, 1, 0, 2},
 		Options:  []byte{0, 1, 0, 2},
-		Protocol: fwProtoTCP,
+		Protocol: firewall.ProtoTCP,
 	}
 	}
 
 
 	b, _ = h.Marshal()
 	b, _ = h.Marshal()
@@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(fwProtoTCP))
-	assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2)))
-	assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
+	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.LocalPort, uint16(4))
 	assert.Equal(t, p.LocalPort, uint16(4))
 
 
@@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) {
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(2))
 	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1)))
-	assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.LocalPort, uint16(5))
 	assert.Equal(t, p.LocalPort, uint16(5))
 }
 }
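
iputil.Ip2VpnIp itself lives in the new iputil/util.go and is not shown in this diff; given that it replaces binary.BigEndian.Uint32 over the 4-byte address slices in newPacket and that tests cast its result back to uint32, a plausible minimal form is:

// Sketch only: inferred from the call sites in this diff, not copied from
// iputil/util.go (which may handle IPv6 and error cases differently).
package iputil

import (
	"encoding/binary"
	"net"
)

type VpnIp uint32

// The real type also has a String() method (remote_list.go calls
// owner.String() above); omitted here.

func Ip2VpnIp(ip net.IP) VpnIp {
	if ip4 := ip.To4(); ip4 != nil {
		return VpnIp(binary.BigEndian.Uint32(ip4))
	}
	return 0 // non-IPv4 input; the real behavior is not visible in this excerpt
}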

+ 6 - 2
punchy.go

@@ -1,6 +1,10 @@
 package nebula
 package nebula
 
 
-import "time"
+import (
+	"time"
+
+	"github.com/slackhq/nebula/config"
+)
 
 
 type Punchy struct {
 type Punchy struct {
 	Punch   bool
 	Punch   bool
@@ -8,7 +12,7 @@ type Punchy struct {
 	Delay   time.Duration
 	Delay   time.Duration
 }
 }
 
 
-func NewPunchyFromConfig(c *Config) *Punchy {
+func NewPunchyFromConfig(c *config.C) *Punchy {
 	p := &Punchy{}
 	p := &Punchy{}
 
 
 	if c.IsSet("punchy.punch") {
 	if c.IsSet("punchy.punch") {

+ 4 - 2
punchy_test.go

@@ -4,12 +4,14 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestNewPunchyFromConfig(t *testing.T) {
 func TestNewPunchyFromConfig(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 
 
 	// Test defaults
 	// Test defaults
 	p := NewPunchyFromConfig(c)
 	p := NewPunchyFromConfig(c)

+ 31 - 28
remote_list.go

@@ -5,14 +5,17 @@ import (
 	"net"
 	"net"
 	"sort"
 	"sort"
 	"sync"
 	"sync"
+
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // forEachFunc is used to benefit folks that want to do work inside the lock
 // forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udpAddr, preferred bool)
+type forEachFunc func(addr *udp.Addr, preferred bool)
 
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
 
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
 // The string key is the owners vpnIp
@@ -21,8 +24,8 @@ type CacheMap map[string]*Cache
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // We don't reason about ipv4 vs ipv6 here
 // We don't reason about ipv4 vs ipv6 here
 type Cache struct {
 type Cache struct {
-	Learned  []*udpAddr `json:"learned,omitempty"`
-	Reported []*udpAddr `json:"reported,omitempty"`
+	Learned  []*udp.Addr `json:"learned,omitempty"`
+	Reported []*udp.Addr `json:"reported,omitempty"`
 }
 }
 
 
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -53,16 +56,16 @@ type RemoteList struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
-	addrs []*udpAddr
+	addrs []*udp.Addr
 
 
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// For learned addresses, this is the vpnIp that sent the packet
 	// For learned addresses, this is the vpnIp that sent the packet
-	cache map[uint32]*cache
+	cache map[iputil.VpnIp]*cache
 
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
 	// They should not be tried again during a handshake
-	badRemotes []*udpAddr
+	badRemotes []*udp.Addr
 
 
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	shouldRebuild bool
 	shouldRebuild bool
@@ -71,8 +74,8 @@ type RemoteList struct {
 // NewRemoteList creates a new empty RemoteList
 // NewRemoteList creates a new empty RemoteList
 func NewRemoteList() *RemoteList {
 func NewRemoteList() *RemoteList {
 	return &RemoteList{
 	return &RemoteList{
-		addrs: make([]*udpAddr, 0),
-		cache: make(map[uint32]*cache),
+		addrs: make([]*udp.Addr, 0),
+		cache: make(map[iputil.VpnIp]*cache),
 	}
 	}
 }
 }
 
 
@@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc)
 
 
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
+func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 	if r == nil {
 	if r == nil {
 		return nil
 		return nil
 	}
 	}
@@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
 
 
 	r.RLock()
 	r.RLock()
 	defer r.RUnlock()
 	defer r.RUnlock()
-	c := make([]*udpAddr, len(r.addrs))
+	c := make([]*udp.Addr, len(r.addrs))
 	for i, v := range r.addrs {
 	for i, v := range r.addrs {
 		c[i] = v.Copy()
 		c[i] = v.Copy()
 	}
 	}
@@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 //TODO: this needs to support the allow list list
 //TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
 	if v4 := addr.IP.To4(); v4 != nil {
 	if v4 := addr.IP.To4(); v4 != nil {
@@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
 		c := cm[vpnIp]
 		c := cm[vpnIp]
 		if c == nil {
 		if c == nil {
 			c = &Cache{
 			c = &Cache{
-				Learned:  make([]*udpAddr, 0),
-				Reported: make([]*udpAddr, 0),
+				Learned:  make([]*udp.Addr, 0),
+				Reported: make([]*udp.Addr, 0),
 			}
 			}
 			cm[vpnIp] = c
 			cm[vpnIp] = c
 		}
 		}
@@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
 	}
 	}
 
 
 	for owner, mc := range r.cache {
 	for owner, mc := range r.cache {
-		c := getOrMake(IntIp(owner).String())
+		c := getOrMake(owner.String())
 
 
 		if mc.v4 != nil {
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
 			if mc.v4.learned != nil {
@@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 }
 
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udpAddr) {
+func (r *RemoteList) BlockRemote(bad *udp.Addr) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
 
 
@@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) {
 }
 }
 
 
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
+func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
 	r.RLock()
 	r.RLock()
 	defer r.RUnlock()
 	defer r.RUnlock()
 
 
-	c := make([]*udpAddr, len(r.badRemotes))
+	c := make([]*udp.Addr, len(r.badRemotes))
 	for i, v := range r.badRemotes {
 	for i, v := range r.badRemotes {
 		c[i] = v.Copy()
 		c[i] = v.Copy()
 	}
 	}
@@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
 }
 }
 
 
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
+func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
 	for _, v := range r.badRemotes {
 	for _, v := range r.badRemotes {
 		if v.Equals(remote) {
 		if v.Equals(remote) {
 			return true
 			return true
@@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
 
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
@@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And
 
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
@@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
 
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
@@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And
 
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
@@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
 
 
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // The caller must dirty the learned address cache if required
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
 	am := r.cache[ownerVpnIp]
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 	if am == nil {
 		am = &cache{}
 		am = &cache{}
@@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
 
 
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // The caller must dirty the learned address cache if required
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
 	am := r.cache[ownerVpnIp]
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 	if am == nil {
 		am = &cache{}
 		am = &cache{}
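
Note: the owner keys in remote_list.go are now typed as iputil.VpnIp rather than a bare uint32. A minimal sketch of the conversion round trip this file relies on; the address and the main package are illustrative only, the calls themselves appear in the hunks above.

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/iputil"
)

func main() {
	// Ip2VpnIp packs a parsed IP into the iputil.VpnIp key type used above
	vpnIp := iputil.Ip2VpnIp(net.ParseIP("10.1.2.3"))

	// the protobuf Ip4AndPort fields still want the raw uint32 form
	fmt.Println(uint32(vpnIp), vpnIp.String())
}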

+ 33 - 32
remote_list_test.go

@@ -4,6 +4,7 @@ import (
 	"net"
 	"net"
 	"testing"
 	"testing"
 
 
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) {
 		0,
 		0,
 		0,
 		0,
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
 		},
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
@@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
 		},
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
 	rl.Rebuild([]*net.IPNet{})
 	rl.Rebuild([]*net.IPNet{})
@@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) {
 		0,
 		0,
 		0,
 		0,
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
 		},
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
@@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) {
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 		},
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {
@@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) {
 		0,
 		0,
 		0,
 		0,
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
 		},
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
@@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) {
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 		},
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {

+ 30 - 26
ssh.go

@@ -15,7 +15,11 @@ import (
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/sshd"
+	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 type sshListHostMapFlags struct {
 type sshListHostMapFlags struct {
@@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct {
 	Address string
 	Address string
 }
 }
 
 
-func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
-	c.RegisterReloadCallback(func(c *Config) {
+func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
+	c.RegisterReloadCallback(func(c *config.C) {
 		if c.GetBool("sshd.enabled", false) {
 		if c.GetBool("sshd.enabled", false) {
 			sshRun, err := configSSH(l, ssh, c)
 			sshRun, err := configSSH(l, ssh, c)
 			if err != nil {
 			if err != nil {
@@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
 // updates the passed-in SSHServer. On success, it returns a function
 // updates the passed-in SSHServer. On success, it returns a function
 // that callers may invoke to run the configured ssh server. On
 // that callers may invoke to run the configured ssh server. On
 // failure, it returns nil, error.
 // failure, it returns nil, error.
-func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
+func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
 	//TODO conntrack list
 	//TODO conntrack list
 	//TODO print firewall rules or hash?
 	//TODO print firewall rules or hash?
 
 
@@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 
 	hm := listHostMap(hostMap)
 	hm := listHostMap(hostMap)
 	sort.Slice(hm, func(i, j int) bool {
 	sort.Slice(hm, func(i, j int) bool {
-		return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
+		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
 	})
 	})
 
 
 	if fs.Json || fs.Pretty {
 	if fs.Json || fs.Pretty {
@@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 
 	} else {
 	} else {
 		for _, v := range hm {
 		for _, v := range hm {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	}
 	}
 
 
 	type lighthouseInfo struct {
 	type lighthouseInfo struct {
-		VpnIP net.IP    `json:"vpnIp"`
+		VpnIp string    `json:"vpnIp"`
 		Addrs *CacheMap `json:"addrs"`
 		Addrs *CacheMap `json:"addrs"`
 	}
 	}
 
 
@@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	x := 0
 	x := 0
 	for k, v := range lightHouse.addrMap {
 	for k, v := range lightHouse.addrMap {
 		addrMap[x] = lighthouseInfo{
 		addrMap[x] = lighthouseInfo{
-			VpnIP: int2ip(k),
+			VpnIp: k.String(),
 			Addrs: v.CopyCache(),
 			Addrs: v.CopyCache(),
 		}
 		}
 		x++
 		x++
@@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	lightHouse.RUnlock()
 	lightHouse.RUnlock()
 
 
 	sort.Slice(addrMap, func(i, j int) bool {
 	sort.Slice(addrMap, func(i, j int) bool {
-		return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
+		return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
 	})
 	})
 
 
 	if fs.Json || fs.Pretty {
 	if fs.Json || fs.Pretty {
@@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
-			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
@@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
 
 
 	if !flags.LocalOnly {
 	if !flags.LocalOnly {
 		ifce.send(
 		ifce.send(
-			closeTunnel,
+			header.CloseTunnel,
 			0,
 			0,
 			hostInfo.ConnectionState,
 			hostInfo.ConnectionState,
 			hostInfo,
 			hostInfo,
@@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 	}
 
 
-	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
 	}
 
 
-	var addr *udpAddr
+	var addr *udp.Addr
 	if flags.Address != "" {
 	if flags.Address != "" {
-		addr = NewUDPAddrFromString(flags.Address)
+		addr = udp.NewAddrFromString(flags.Address)
 		if addr == nil {
 		if addr == nil {
 			return w.WriteLine("Address could not be parsed")
 			return w.WriteLine("Address could not be parsed")
 		}
 		}
 	}
 	}
 
 
-	hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
+	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
 	if addr != nil {
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 		hostInfo.SetRemote(addr)
 	}
 	}
@@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No address was provided")
 		return w.WriteLine("No address was provided")
 	}
 	}
 
 
-	addr := NewUDPAddrFromString(flags.Address)
+	addr := udp.NewAddrFromString(flags.Address)
 	if addr == nil {
 	if addr == nil {
 		return w.WriteLine("Address could not be parsed")
 		return w.WriteLine("Address could not be parsed")
 	}
 	}
@@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
@@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 		}
 
 
-		vpnIp := ip2int(parsedIp)
+		vpnIp := iputil.Ip2VpnIp(parsedIp)
 		if vpnIp == 0 {
 		if vpnIp == 0 {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 		}
 
 
-		hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+		hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 		if err != nil {
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		}
 		}
@@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
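
Since the ssh wiring now hangs off the relocated *config.C, here is a condensed sketch of the reload pattern wireSSHReload implements above. watchSSHD is an illustrative name; the server and config are assumed to be constructed elsewhere.

package nebula

import (
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/sshd"
)

func watchSSHD(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
	c.RegisterReloadCallback(func(c *config.C) {
		if !c.GetBool("sshd.enabled", false) {
			return
		}
		// configSSH re-validates the sshd settings and hands back a runner on success
		if runner, err := configSSH(l, ssh, c); err == nil && runner != nil {
			go runner()
		}
	})
}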

+ 4 - 3
stats.go

@@ -15,12 +15,13 @@ import (
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 )
 )
 
 
 // startStats initializes stats from config. On success, if any further work
 // startStats initializes stats from config. On success, if any further work
 // is needed to serve stats, it returns a func to handle that work. If no
 // is needed to serve stats, it returns a func to handle that work. If no
 // work is needed, it'll return nil. On failure, it returns nil, error.
 // work is needed, it'll return nil. On failure, it returns nil, error.
-func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
+func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
 	mType := c.GetString("stats.type", "")
 	mType := c.GetString("stats.type", "")
 	if mType == "" || mType == "none" {
 	if mType == "" || mType == "none" {
 		return nil, nil
 		return nil, nil
@@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
 	return startFn, nil
 	return startFn, nil
 }
 }
 
 
-func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
+func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
 	proto := c.GetString("stats.protocol", "tcp")
 	proto := c.GetString("stats.protocol", "tcp")
 	host := c.GetString("stats.host", "")
 	host := c.GetString("stats.host", "")
 	if host == "" {
 	if host == "" {
@@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
 	return nil
 	return nil
 }
 }
 
 
-func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
+func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
 	namespace := c.GetString("stats.namespace", "")
 	namespace := c.GetString("stats.namespace", "")
 	subsystem := c.GetString("stats.subsystem", "")
 	subsystem := c.GetString("stats.subsystem", "")
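
For reference, a minimal sketch of how a caller consumes startStats with the relocated config type; serveStats is an illustrative name, only startStats and its signature come from the hunk above.

package nebula

import (
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
)

func serveStats(l *logrus.Logger, c *config.C, buildVersion string) error {
	// startStats returns nil when stats.type is empty or "none"; otherwise the
	// returned func serves the configured sink and is started by the caller
	f, err := startStats(l, c, buildVersion, false)
	if err != nil {
		return err
	}
	if f != nil {
		go f()
	}
	return nil
}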
 
 

+ 7 - 5
timeout.go

@@ -2,12 +2,14 @@ package nebula
 
 
 import (
 import (
 	"time"
 	"time"
+
+	"github.com/slackhq/nebula/firewall"
 )
 )
 
 
 // How many timer objects should be cached
 // How many timer objects should be cached
 const timerCacheMax = 50000
 const timerCacheMax = 50000
 
 
-var emptyFWPacket = FirewallPacket{}
+var emptyFWPacket = firewall.Packet{}
 
 
 type TimerWheel struct {
 type TimerWheel struct {
 	// Current tick
 	// Current tick
@@ -42,7 +44,7 @@ type TimeoutList struct {
 
 
 // Represents an item within a tick
 // Represents an item within a tick
 type TimeoutItem struct {
 type TimeoutItem struct {
-	Packet FirewallPacket
+	Packet firewall.Packet
 	Next   *TimeoutItem
 	Next   *TimeoutItem
 }
 }
 
 
@@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
 	return &tw
 	return &tw
 }
 }
 
 
-// Add will add a FirewallPacket to the wheel in it's proper timeout
-func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem {
+// Add will add a firewall.Packet to the wheel in its proper timeout
+func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
 	// Check and see if we should progress the tick
 	// Check and see if we should progress the tick
 	tw.advance(time.Now())
 	tw.advance(time.Now())
 
 
@@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem
 	return ti
 	return ti
 }
 }
 
 
-func (tw *TimerWheel) Purge() (FirewallPacket, bool) {
+func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
 	if tw.expired.Head == nil {
 	if tw.expired.Head == nil {
 		return emptyFWPacket, false
 		return emptyFWPacket, false
 	}
 	}
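
A minimal sketch of the wheel after the firewall.Packet move; only the LocalIP field is visible in this diff, so nothing beyond it is assumed, and timerWheelSketch is illustrative.

package nebula

import (
	"time"

	"github.com/slackhq/nebula/firewall"
)

func timerWheelSketch() {
	tw := NewTimerWheel(time.Second, 10*time.Second)
	tw.Add(firewall.Packet{LocalIP: 1}, 2*time.Second)

	// expired entries are drained until Purge reports false
	for {
		if _, ok := tw.Purge(); !ok {
			break
		}
	}
}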

+ 4 - 2
timeout_system.go

@@ -3,6 +3,8 @@ package nebula
 import (
 import (
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 // How many timer objects should be cached
 // How many timer objects should be cached
@@ -43,7 +45,7 @@ type SystemTimeoutList struct {
 
 
 // Represents an item within a tick
 // Represents an item within a tick
 type SystemTimeoutItem struct {
 type SystemTimeoutItem struct {
-	Item uint32
+	Item iputil.VpnIp
 	Next *SystemTimeoutItem
 	Next *SystemTimeoutItem
 }
 }
 
 
@@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
 	return &tw
 	return &tw
 }
 }
 
 
-func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem {
+func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
 	tw.lock.Lock()
 	tw.lock.Lock()
 	defer tw.lock.Unlock()
 	defer tw.lock.Unlock()
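
And the system wheel, now keyed by iputil.VpnIp rather than uint32. A minimal sketch of adding an entry; systemWheelSketch and the address are illustrative only.

package nebula

import (
	"net"
	"time"

	"github.com/slackhq/nebula/iputil"
)

func systemWheelSketch() {
	tw := NewSystemTimerWheel(time.Second, 10*time.Second)
	// entries are vpn ips, e.g. a host whose handshake should be retried later
	tw.Add(iputil.Ip2VpnIp(net.ParseIP("10.0.0.5")), 3*time.Second)
}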
 
 

+ 4 - 3
timeout_system_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
 func TestSystemTimerWheel_Add(t *testing.T) {
 func TestSystemTimerWheel_Add(t *testing.T) {
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
 
 
-	fp1 := ip2int(net.ParseIP("1.2.3.4"))
+	fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
 	tw.Add(fp1, time.Second*1)
 	tw.Add(fp1, time.Second*1)
 
 
 	// Make sure we set head and tail properly
 	// Make sure we set head and tail properly
@@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) {
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 
 	// Make sure we only modify head
 	// Make sure we only modify head
-	fp2 := ip2int(net.ParseIP("1.2.3.4"))
+	fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
 	tw.Add(fp2, time.Second*1)
 	tw.Add(fp2, time.Second*1)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
@@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
 	assert.NotNil(t, tw.lastTick)
 	assert.NotNil(t, tw.lastTick)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 
 
-	fps := []uint32{9, 10, 11, 12}
+	fps := []iputil.VpnIp{9, 10, 11, 12}
 
 
 	//fp1 := ip2int(net.ParseIP("1.2.3.4"))
 	//fp1 := ip2int(net.ParseIP("1.2.3.4"))
 
 

+ 4 - 3
timeout_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) {
 func TestTimerWheel_Add(t *testing.T) {
 func TestTimerWheel_Add(t *testing.T) {
 	tw := NewTimerWheel(time.Second, time.Second*10)
 	tw := NewTimerWheel(time.Second, time.Second*10)
 
 
-	fp1 := FirewallPacket{}
+	fp1 := firewall.Packet{}
 	tw.Add(fp1, time.Second*1)
 	tw.Add(fp1, time.Second*1)
 
 
 	// Make sure we set head and tail properly
 	// Make sure we set head and tail properly
@@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) {
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 
 	// Make sure we only modify head
 	// Make sure we only modify head
-	fp2 := FirewallPacket{}
+	fp2 := firewall.Packet{}
 	tw.Add(fp2, time.Second*1)
 	tw.Add(fp2, time.Second*1)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
@@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.NotNil(t, tw.lastTick)
 	assert.NotNil(t, tw.lastTick)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 
 
-	fps := []FirewallPacket{
+	fps := []firewall.Packet{
 		{LocalIP: 1},
 		{LocalIP: 1},
 		{LocalIP: 2},
 		{LocalIP: 2},
 		{LocalIP: 3},
 		{LocalIP: 3},

+ 7 - 5
tun_common.go

@@ -4,6 +4,8 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
+
+	"github.com/slackhq/nebula/config"
 )
 )
 
 
 const DEFAULT_MTU = 1300
 const DEFAULT_MTU = 1300
@@ -14,10 +16,10 @@ type route struct {
 	via   *net.IP
 	via   *net.IP
 }
 }
 
 
-func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
+func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
 	var err error
 	var err error
 
 
-	r := config.Get("tun.routes")
+	r := c.Get("tun.routes")
 	if r == nil {
 	if r == nil {
 		return []route{}, nil
 		return []route{}, nil
 	}
 	}
@@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
-func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
+func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
 	var err error
 	var err error
 
 
-	r := config.Get("tun.unsafe_routes")
+	r := c.Get("tun.unsafe_routes")
 	if r == nil {
 	if r == nil {
 		return []route{}, nil
 		return []route{}, nil
 	}
 	}
@@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
 
 
 		rMtu, ok := m["mtu"]
 		rMtu, ok := m["mtu"]
 		if !ok {
 		if !ok {
-			rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
+			rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
 		}
 		}
 
 
 		mtu, ok := rMtu.(int)
 		mtu, ok := rMtu.(int)
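
The route parsers now take the relocated config type directly. A minimal call-shape sketch; loadRoutes is illustrative, and c is assumed to already have the tun.routes / tun.unsafe_routes keys loaded elsewhere.

package nebula

import (
	"net"

	"github.com/slackhq/nebula/config"
)

func loadRoutes(c *config.C, network *net.IPNet) ([]route, error) {
	routes, err := parseRoutes(c, network)
	if err != nil {
		return nil, err
	}
	unsafeRoutes, err := parseUnsafeRoutes(c, network)
	if err != nil {
		return nil, err
	}
	return append(routes, unsafeRoutes...), nil
}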

+ 6 - 4
tun_test.go

@@ -5,12 +5,14 @@ import (
 	"net"
 	"net"
 	"testing"
 	"testing"
 
 
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func Test_parseRoutes(t *testing.T) {
 func Test_parseRoutes(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 
 	// test no routes config
 	// test no routes config
@@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) {
 }
 }
 
 
 func Test_parseUnsafeRoutes(t *testing.T) {
 func Test_parseUnsafeRoutes(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 
 	// test no routes config
 	// test no routes config

+ 20 - 0
udp/conn.go

@@ -0,0 +1,20 @@
+package udp
+
+import (
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+)
+
+const MTU = 9001
+
+type EncReader func(
+	addr *Addr,
+	out []byte,
+	packet []byte,
+	header *header.H,
+	fwPacket *firewall.Packet,
+	lhh LightHouseHandlerFunc,
+	nb []byte,
+	q int,
+	localCache firewall.ConntrackCache,
+)
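
EncReader is the callback the udp package now hands packets to instead of reaching back into the Interface. A stand-in matching the shape; noopReader is illustrative, the real reader is nebula's outside-packet handler.

package udp

import (
	"github.com/slackhq/nebula/firewall"
	"github.com/slackhq/nebula/header"
)

// noopReader drops everything; a real reader decrypts packet into out,
// fills h and fwPacket, and dispatches lighthouse traffic through lhh
func noopReader(addr *Addr, out []byte, packet []byte, h *header.H,
	fwPacket *firewall.Packet, lhh LightHouseHandlerFunc, nb []byte, q int,
	localCache firewall.ConntrackCache) {
}

var _ EncReader = noopReader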

+ 14 - 0
udp/temp.go

@@ -0,0 +1,14 @@
+package udp
+
+import (
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+)
+
+//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
+
+type EncWriter interface {
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+}
+
+type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
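
The same pattern applies to lighthouse traffic: the handler and the writer it replies through are passed in as a func type and an interface rather than concrete nebula structs. A shape-only sketch; dropLighthouse is illustrative.

package udp

import "github.com/slackhq/nebula/iputil"

// dropLighthouse ignores lighthouse packets; a real handler parses p and can
// answer the querier through w.SendMessageToVpnIp(...)
func dropLighthouse(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
}

var _ LightHouseHandlerFunc = dropLighthouse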

+ 15 - 13
udp_all.go → udp/udp_all.go

@@ -1,4 +1,4 @@
-package nebula
+package udp
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
@@ -7,32 +7,34 @@ import (
 	"strconv"
 	"strconv"
 )
 )
 
 
-type udpAddr struct {
+type m map[string]interface{}
+
+type Addr struct {
 	IP   net.IP
 	IP   net.IP
 	Port uint16
 	Port uint16
 }
 }
 
 
-func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
-	addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
+func NewAddr(ip net.IP, port uint16) *Addr {
+	addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
 	copy(addr.IP, ip.To16())
 	copy(addr.IP, ip.To16())
 	return &addr
 	return &addr
 }
 }
 
 
-func NewUDPAddrFromString(s string) *udpAddr {
-	ip, port, err := parseIPAndPort(s)
+func NewAddrFromString(s string) *Addr {
+	ip, port, err := ParseIPAndPort(s)
 	//TODO: handle err
 	//TODO: handle err
 	_ = err
 	_ = err
-	return &udpAddr{IP: ip.To16(), Port: port}
+	return &Addr{IP: ip.To16(), Port: port}
 }
 }
 
 
-func (ua *udpAddr) Equals(t *udpAddr) bool {
+func (ua *Addr) Equals(t *Addr) bool {
 	if t == nil || ua == nil {
 	if t == nil || ua == nil {
 		return t == nil && ua == nil
 		return t == nil && ua == nil
 	}
 	}
 	return ua.IP.Equal(t.IP) && ua.Port == t.Port
 	return ua.IP.Equal(t.IP) && ua.Port == t.Port
 }
 }
 
 
-func (ua *udpAddr) String() string {
+func (ua *Addr) String() string {
 	if ua == nil {
 	if ua == nil {
 		return "<nil>"
 		return "<nil>"
 	}
 	}
@@ -40,7 +42,7 @@ func (ua *udpAddr) String() string {
 	return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
 	return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
 }
 }
 
 
-func (ua *udpAddr) MarshalJSON() ([]byte, error) {
+func (ua *Addr) MarshalJSON() ([]byte, error) {
 	if ua == nil {
 	if ua == nil {
 		return nil, nil
 		return nil, nil
 	}
 	}
@@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) {
 	return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
 	return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
 }
 }
 
 
-func (ua *udpAddr) Copy() *udpAddr {
+func (ua *Addr) Copy() *Addr {
 	if ua == nil {
 	if ua == nil {
 		return nil
 		return nil
 	}
 	}
 
 
-	nu := udpAddr{
+	nu := Addr{
 		Port: ua.Port,
 		Port: ua.Port,
 		IP:   make(net.IP, len(ua.IP)),
 		IP:   make(net.IP, len(ua.IP)),
 	}
 	}
@@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr {
 	return &nu
 	return &nu
 }
 }
 
 
-func parseIPAndPort(s string) (net.IP, uint16, error) {
+func ParseIPAndPort(s string) (net.IP, uint16, error) {
 	rIp, sPort, err := net.SplitHostPort(s)
 	rIp, sPort, err := net.SplitHostPort(s)
 	if err != nil {
 	if err != nil {
 		return nil, 0, err
 		return nil, 0, err
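
udpAddr is now the exported udp.Addr with exported constructors. A small usage sketch of the API the rest of the tree switches to; the address values and main package are illustrative.

package main

import (
	"fmt"
	"net"

	"github.com/slackhq/nebula/udp"
)

func main() {
	a := udp.NewAddr(net.ParseIP("192.0.2.1"), 4242)
	b := udp.NewAddrFromString("192.0.2.1:4242")

	// both are normalized to 16-byte IPs internally, so they compare equal
	fmt.Println(a.Equals(b), a.String(), b.Copy().String())
}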

+ 2 - 2
udp_android.go → udp/udp_android.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 //go:build !e2e_testing
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"fmt"
 	"fmt"
@@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 	return nil
 }
 }

+ 2 - 2
udp_darwin.go → udp/udp_darwin.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 //go:build !e2e_testing
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
 // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
 
 
@@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	file, err := u.File()
 	file, err := u.File()
 	if err != nil {
 	if err != nil {
 		return err
 		return err

+ 2 - 2
udp_freebsd.go → udp/udp_freebsd.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 //go:build !e2e_testing
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
 // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
 
 
@@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 	return nil
 }
 }

+ 20 - 25
udp_generic.go → udp/udp_generic.go

@@ -5,7 +5,7 @@
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.
 // means it can be used on platforms like Darwin and Windows.
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"context"
 	"context"
@@ -13,36 +13,39 @@ import (
 	"net"
 	"net"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 )
 )
 
 
-type udpConn struct {
+type Conn struct {
 	*net.UDPConn
 	*net.UDPConn
 	l *logrus.Logger
 	l *logrus.Logger
 }
 }
 
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
 	lc := NewListenConfig(multi)
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
 	pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	if uc, ok := pc.(*net.UDPConn); ok {
 	if uc, ok := pc.(*net.UDPConn); ok {
-		return &udpConn{UDPConn: uc, l: l}, nil
+		return &Conn{UDPConn: uc, l: l}, nil
 	}
 	}
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
 }
 
 
-func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
+func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
 	_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
 	_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
 	return err
 	return err
 }
 }
 
 
-func (uc *udpConn) LocalAddr() (*udpAddr, error) {
+func (uc *Conn) LocalAddr() (*Addr, error) {
 	a := uc.UDPConn.LocalAddr()
 	a := uc.UDPConn.LocalAddr()
 
 
 	switch v := a.(type) {
 	switch v := a.(type) {
 	case *net.UDPAddr:
 	case *net.UDPAddr:
-		addr := &udpAddr{IP: make([]byte, len(v.IP))}
+		addr := &Addr{IP: make([]byte, len(v.IP))}
 		copy(addr.IP, v.IP)
 		copy(addr.IP, v.IP)
 		addr.Port = uint16(v.Port)
 		addr.Port = uint16(v.Port)
 		return addr, nil
 		return addr, nil
@@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) reloadConfig(c *Config) {
+func (u *Conn) ReloadConfig(c *config.C) {
 	// TODO
 	// TODO
 }
 }
 
 
-func NewUDPStatsEmitter(udpConns []*udpConn) func() {
+func NewUDPStatsEmitter(udpConns []*Conn) func() {
 	// No UDP stats for non-linux
 	// No UDP stats for non-linux
 	return func() {}
 	return func() {}
 }
 }
@@ -65,32 +68,24 @@ type rawMessage struct {
 	Len uint32
 	Len uint32
 }
 }
 
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	buffer := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	udpAddr := &udpAddr{IP: make([]byte, 16)}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
-	lhh := f.lightHouse.NewRequestHandler()
-
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
 	for {
 		// Just read one packet at a time
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
 		n, rua, err := u.ReadFromUDP(buffer)
 		if err != nil {
 		if err != nil {
-			f.l.WithError(err).Error("Failed to read packets")
+			u.l.WithError(err).Error("Failed to read packets")
 			continue
 			continue
 		}
 		}
 
 
 		udpAddr.IP = rua.IP
 		udpAddr.IP = rua.IP
 		udpAddr.Port = uint16(rua.Port)
 		udpAddr.Port = uint16(rua.Port)
-		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 	}
 }
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}
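
With the read loop decoupled from Interface, the caller now supplies the reader, the lighthouse handler and the conntrack cache. A minimal wiring sketch using only the signatures above; listenSketch, the port and the batch value are illustrative.

package udp

import (
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/firewall"
)

func listenSketch(l *logrus.Logger, r EncReader, lhf LightHouseHandlerFunc,
	cache *firewall.ConntrackCacheTicker) (*Conn, error) {
	// batch is 1 on the generic path, which reads a single packet per syscall
	c, err := NewListener(l, "0.0.0.0", 4242, false, 1)
	if err != nil {
		return nil, err
	}
	go c.ListenOut(r, lhf, cache, 0)
	return c, nil
}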

+ 29 - 33
udp_linux.go → udp/udp_linux.go

@@ -1,7 +1,7 @@
 //go:build !android && !e2e_testing
 //go:build !android && !e2e_testing
 // +build !android,!e2e_testing
 // +build !android,!e2e_testing
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
@@ -12,14 +12,18 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
 //TODO: make it support reload as best you can!
 //TODO: make it support reload as best you can!
 
 
-type udpConn struct {
+type Conn struct {
 	sysFd int
 	sysFd int
 	l     *logrus.Logger
 	l     *logrus.Logger
+	batch int
 }
 }
 
 
 var x int
 var x int
@@ -41,7 +45,7 @@ const (
 
 
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
 	syscall.ForkLock.RLock()
 	syscall.ForkLock.RLock()
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	if err == nil {
 	if err == nil {
@@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 	//l.Println(v, err)
 
 
-	return &udpConn{sysFd: fd, l: l}, err
+	return &Conn{sysFd: fd, l: l, batch: batch}, err
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 	return nil
 }
 }
 
 
-func (u *udpConn) SetRecvBuffer(n int) error {
+func (u *Conn) SetRecvBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
 }
 }
 
 
-func (u *udpConn) SetSendBuffer(n int) error {
+func (u *Conn) SetSendBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
 }
 }
 
 
-func (u *udpConn) GetRecvBuffer() (int, error) {
+func (u *Conn) GetRecvBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
 }
 }
 
 
-func (u *udpConn) GetSendBuffer() (int, error) {
+func (u *Conn) GetSendBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 }
 }
 
 
-func (u *udpConn) LocalAddr() (*udpAddr, error) {
+func (u *Conn) LocalAddr() (*Addr, error) {
 	sa, err := unix.Getsockname(u.sysFd)
 	sa, err := unix.Getsockname(u.sysFd)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	addr := &udpAddr{}
+	addr := &Addr{}
 	switch sa := sa.(type) {
 	switch sa := sa.(type) {
 	case *unix.SockaddrInet4:
 	case *unix.SockaddrInet4:
 		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
 		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
@@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
 	return addr, nil
 	return addr, nil
 }
 }
 
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	udpAddr := &udpAddr{}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
-	lhh := f.lightHouse.NewRequestHandler()
-
 	//TODO: should we track this?
 	//TODO: should we track this?
 	//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
 	//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
-	msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
+	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
 	read := u.ReadMulti
-	if f.udpBatchSize == 1 {
+	if u.batch == 1 {
 		read = u.ReadSingle
 		read = u.ReadSingle
 	}
 	}
 
 
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
 	for {
 		n, err := read(msgs)
 		n, err := read(msgs)
 		if err != nil {
 		if err != nil {
@@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
 			udpAddr.IP = names[i][8:24]
 			udpAddr.IP = names[i][8:24]
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
+			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 		}
 		}
 	}
 	}
 }
 }
 
 
-func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
+func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
 	for {
 	for {
 		n, _, err := unix.Syscall6(
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMSG,
 			unix.SYS_RECVMSG,
@@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
+func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
 	for {
 	for {
 		n, _, err := unix.Syscall6(
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMMSG,
 			unix.SYS_RECVMMSG,
@@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
+func (u *Conn) WriteTo(b []byte, addr *Addr) error {
 
 
 	var rsa unix.RawSockaddrInet6
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
 	rsa.Family = unix.AF_INET6
@@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) reloadConfig(c *Config) {
+func (u *Conn) ReloadConfig(c *config.C) {
 	b := c.GetInt("listen.read_buffer", 0)
 	b := c.GetInt("listen.read_buffer", 0)
 	if b > 0 {
 	if b > 0 {
 		err := u.SetRecvBuffer(b)
 		err := u.SetRecvBuffer(b)
@@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
+func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	var vallen uint32 = 4 * _SK_MEMINFO_VARS
 	var vallen uint32 = 4 * _SK_MEMINFO_VARS
 	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
 	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
 	if err != 0 {
 	if err != 0 {
@@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	return nil
 	return nil
 }
 }
 
 
-func NewUDPStatsEmitter(udpConns []*udpConn) func() {
+func NewUDPStatsEmitter(udpConns []*Conn) func() {
 	// Check if our kernel supports SO_MEMINFO before registering the gauges
 	// Check if our kernel supports SO_MEMINFO before registering the gauges
 	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
 	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
 	var meminfo _SK_MEMINFO
 	var meminfo _SK_MEMINFO
@@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() {
 		}
 		}
 	}
 	}
 }
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}
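
On linux the listener also owns its socket buffers, which ReloadConfig drives from listen.read_buffer. A minimal sketch of that plumbing; bumpReadBuffer is illustrative.

package udp

// bumpReadBuffer mirrors what ReloadConfig does with listen.read_buffer:
// force the receive buffer, then read back what the kernel actually applied
func bumpReadBuffer(u *Conn, n int) (int, error) {
	if err := u.SetRecvBuffer(n); err != nil {
		return 0, err
	}
	return u.GetRecvBuffer()
}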

+ 3 - 3
udp_linux_32.go → udp/udp_linux_32.go

@@ -4,7 +4,7 @@
 // +build !android
 // +build !android
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
@@ -30,13 +30,13 @@ type rawMessage struct {
 	Len uint32
 	Len uint32
 }
 }
 
 
-func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 	names := make([][]byte, n)
 
 
 	for i := range msgs {
 	for i := range msgs {
-		buffers[i] = make([]byte, mtu)
+		buffers[i] = make([]byte, MTU)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 
 		//TODO: this is still silly, no need for an array
 		//TODO: this is still silly, no need for an array

+ 3 - 3
udp_linux_64.go → udp/udp_linux_64.go

@@ -4,7 +4,7 @@
 // +build !android
 // +build !android
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
@@ -33,13 +33,13 @@ type rawMessage struct {
 	Pad0 [4]byte
 	Pad0 [4]byte
 }
 }
 
 
-func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 	names := make([][]byte, n)
 
 
 	for i := range msgs {
 	for i := range msgs {
-		buffers[i] = make([]byte, mtu)
+		buffers[i] = make([]byte, MTU)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 
 		//TODO: this is still silly, no need for an array
 		//TODO: this is still silly, no need for an array

+ 39 - 43
udp_tester.go → udp/udp_tester.go

@@ -1,16 +1,19 @@
 //go:build e2e_testing
 //go:build e2e_testing
 // +build e2e_testing
 // +build e2e_testing
 
 
-package nebula
+package udp
 
 
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 )
 )
 
 
-type UdpPacket struct {
+type Packet struct {
 	ToIp     net.IP
 	ToIp     net.IP
 	ToPort   uint16
 	ToPort   uint16
 	FromIp   net.IP
 	FromIp   net.IP
@@ -18,8 +21,8 @@ type UdpPacket struct {
 	Data     []byte
 	Data     []byte
 }
 }
 
 
-func (u *UdpPacket) Copy() *UdpPacket {
-	n := &UdpPacket{
+func (u *Packet) Copy() *Packet {
+	n := &Packet{
 		ToIp:     make(net.IP, len(u.ToIp)),
 		ToIp:     make(net.IP, len(u.ToIp)),
 		ToPort:   u.ToPort,
 		ToPort:   u.ToPort,
 		FromIp:   make(net.IP, len(u.FromIp)),
 		FromIp:   make(net.IP, len(u.FromIp)),
@@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket {
 	return n
 	return n
 }
 }
 
 
-type udpConn struct {
-	addr *udpAddr
+type Conn struct {
+	Addr *Addr
 
 
-	rxPackets chan *UdpPacket // Packets to receive into nebula
-	txPackets chan *UdpPacket // Packets transmitted outside by nebula
+	RxPackets chan *Packet // Packets to receive into nebula
+	TxPackets chan *Packet // Packets transmitted outside by nebula
 
 
 	l *logrus.Logger
 	l *logrus.Logger
 }
 }
 
 
-func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) {
-	return &udpConn{
-		addr:      &udpAddr{net.ParseIP(ip), uint16(port)},
-		rxPackets: make(chan *UdpPacket, 1),
-		txPackets: make(chan *UdpPacket, 1),
+func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
+	return &Conn{
+		Addr:      &Addr{net.ParseIP(ip), uint16(port)},
+		RxPackets: make(chan *Packet, 1),
+		TxPackets: make(chan *Packet, 1),
 		l:         l,
 		l:         l,
 	}, nil
 	}, nil
 }
 }
@@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
 // Send will place a UdpPacket onto the receive queue for nebula to consume
 // Send will place a UdpPacket onto the receive queue for nebula to consume
 // this is an encrypted packet or a handshake message in most cases
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
 // packets were transmitted from another nebula node, you can send them with Tun.Send
-func (u *udpConn) Send(packet *UdpPacket) {
-	h := &Header{}
+func (u *Conn) Send(packet *Packet) {
+	h := &header.H{}
 	if err := h.Parse(packet.Data); err != nil {
 	if err := h.Parse(packet.Data); err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) {
 		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 		WithField("dataLen", len(packet.Data)).
 		WithField("dataLen", len(packet.Data)).
 		Info("UDP receiving injected packet")
 		Info("UDP receiving injected packet")
-	u.rxPackets <- packet
+	u.RxPackets <- packet
 }
 }
 
 
 // Get will pull a UdpPacket from the transmit queue
 // Get will pull a UdpPacket from the transmit queue
 // nebula meant to send this message on the network, it will be encrypted
 // nebula meant to send this message on the network, it will be encrypted
 // packets were ingested from the tun side (in most cases), you can send them with Tun.Send
 // packets were ingested from the tun side (in most cases), you can send them with Tun.Send
-func (u *udpConn) Get(block bool) *UdpPacket {
+func (u *Conn) Get(block bool) *Packet {
 	if block {
 	if block {
-		return <-u.txPackets
+		return <-u.TxPackets
 	}
 	}
 
 
 	select {
 	select {
-	case p := <-u.txPackets:
+	case p := <-u.TxPackets:
 		return p
 		return p
 	default:
 	default:
 		return nil
 		return nil
@@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket {
 // Below this is boilerplate implementation to make nebula actually work
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 //********************************************************************************************************************//
 
 
-func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
-	p := &UdpPacket{
+func (u *Conn) WriteTo(b []byte, addr *Addr) error {
+	p := &Packet{
 		Data:     make([]byte, len(b), len(b)),
 		Data:     make([]byte, len(b), len(b)),
 		FromIp:   make([]byte, 16),
 		FromIp:   make([]byte, 16),
-		FromPort: u.addr.Port,
+		FromPort: u.Addr.Port,
 		ToIp:     make([]byte, 16),
 		ToIp:     make([]byte, 16),
 		ToPort:   addr.Port,
 		ToPort:   addr.Port,
 	}
 	}
 
 
 	copy(p.Data, b)
 	copy(p.Data, b)
 	copy(p.ToIp, addr.IP.To16())
 	copy(p.ToIp, addr.IP.To16())
-	copy(p.FromIp, u.addr.IP.To16())
+	copy(p.FromIp, u.Addr.IP.To16())
 
 
-	u.txPackets <- p
+	u.TxPackets <- p
 	return nil
 	return nil
 }
 }
 
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	ua := &udpAddr{IP: make([]byte, 16)}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	ua := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
-	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
 	for {
-		p := <-u.rxPackets
+		p := <-u.RxPackets
 		ua.Port = p.FromPort
 		ua.Port = p.FromPort
 		copy(ua.IP, p.FromIp.To16())
 		copy(ua.IP, p.FromIp.To16())
-		f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
+		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 	}
 }
 }
 
 
-func (u *udpConn) reloadConfig(*Config) {}
+func (u *Conn) ReloadConfig(*config.C) {}
 
 
-func NewUDPStatsEmitter(_ []*udpConn) func() {
+func NewUDPStatsEmitter(_ []*Conn) func() {
 	// No UDP stats for non-linux
 	// No UDP stats for non-linux
 	return func() {}
 	return func() {}
 }
 }
 
 
-func (u *udpConn) LocalAddr() (*udpAddr, error) {
-	return u.addr, nil
+func (u *Conn) LocalAddr() (*Addr, error) {
+	return u.Addr, nil
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 	return nil
 }
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}
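
The e2e tester now exposes its queues and address directly, so tests can inject and drain packets without reaching into package internals. A minimal sketch; roundTrip and the source address are illustrative, and the wire payload is assumed to already be a parseable nebula packet since Send parses the header before queueing it.

//go:build e2e_testing

package udp

import "net"

func roundTrip(u *Conn, wire []byte) *Packet {
	// push a packet toward nebula as if it arrived from another node
	u.Send(&Packet{
		FromIp:   net.ParseIP("192.0.2.10").To16(),
		FromPort: 4242,
		ToIp:     u.Addr.IP,
		ToPort:   u.Addr.Port,
		Data:     wire,
	})

	// block until nebula transmits something in response
	return u.Get(true)
}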

+ 2 - 2
udp_windows.go → udp/udp_windows.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 //go:build !e2e_testing
 // +build !e2e_testing
 // +build !e2e_testing
 
 
-package nebula
+package udp
 
 
 // Windows support is primarily implemented in udp_generic, besides NewListenConfig
 // Windows support is primarily implemented in udp_generic, besides NewListenConfig
 
 
@@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 	}
 }
 }
 
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 	return nil
 }
 }

+ 3 - 4
main_test.go → util/main.go

@@ -1,4 +1,4 @@
-package nebula
+package util
 
 
 import (
 import (
 	"io/ioutil"
 	"io/ioutil"
@@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger {
 	}
 	}
 
 
 	switch v {
 	switch v {
-	case "1":
-		// This is the default level but we are being explicit
-		l.SetLevel(logrus.InfoLevel)
 	case "2":
 	case "2":
 		l.SetLevel(logrus.DebugLevel)
 		l.SetLevel(logrus.DebugLevel)
 	case "3":
 	case "3":
 		l.SetLevel(logrus.TraceLevel)
 		l.SetLevel(logrus.TraceLevel)
+	default:
+		l.SetLevel(logrus.InfoLevel)
 	}
 	}
 
 
 	return l
 	return l
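
Tests across the tree now pull the logger from the new util package and the config from the config package. The usual pairing looks like this; TestSketch is illustrative.

package nebula

import (
	"testing"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/util"
)

func TestSketch(t *testing.T) {
	// info level by default; the level switch above bumps it to debug or trace
	l := util.NewTestLogger()
	c := config.NewC(l)
	_ = c
}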