Przeglądaj źródła

Rework some things into packages (#489)

Nate Brown 4 lat temu
rodzic
commit
bcabcfdaca
73 zmienionych plików z 2527 dodań i 2375 usunięć
  1. 228 7
      allow_list.go
  2. 98 9
      allow_list_test.go
  3. 5 4
      bits_test.go
  4. 3 2
      cert.go
  5. 10 0
      cidr/parse.go
  6. 20 44
      cidr/tree4.go
  7. 153 0
      cidr/tree4_test.go
  8. 20 19
      cidr/tree6.go
  9. 25 21
      cidr/tree6_test.go
  10. 0 157
      cidr_radix_test.go
  11. 6 5
      cmd/nebula-service/main.go
  12. 4 3
      cmd/nebula-service/service.go
  13. 6 5
      cmd/nebula/main.go
  14. 0 611
      config.go
  15. 358 0
      config/config.go
  16. 18 103
      config/config_test.go
  17. 51 49
      connection_manager.go
  18. 39 36
      connection_manager_test.go
  19. 18 15
      control.go
  20. 16 14
      control_test.go
  21. 15 12
      control_tester.go
  22. 9 7
      dns_server.go
  23. 8 5
      e2e/handshakes_test.go
  24. 11 23
      e2e/helpers_test.go
  25. 12 10
      e2e/router/router.go
  26. 45 147
      firewall.go
  27. 59 0
      firewall/cache.go
  28. 62 0
      firewall/packet.go
  29. 101 109
      firewall_test.go
  30. 5 5
      handshake.go
  31. 59 56
      handshake_ix.go
  32. 36 33
      handshake_manager.go
  33. 16 12
      handshake_manager_test.go
  34. 62 66
      header/header.go
  35. 115 0
      header/header_test.go
  36. 0 119
      header_test.go
  37. 64 93
      hostmap.go
  38. 28 23
      inside.go
  39. 26 21
      interface.go
  40. 66 0
      iputil/util.go
  41. 17 0
      iputil/util_test.go
  42. 82 81
      lighthouse.go
  43. 72 68
      lighthouse_test.go
  44. 39 0
      logger.go
  45. 71 68
      main.go
  46. 5 2
      message_metrics.go
  47. 61 57
      outside.go
  48. 9 7
      outside_test.go
  49. 6 2
      punchy.go
  50. 4 2
      punchy_test.go
  51. 31 28
      remote_list.go
  52. 33 32
      remote_list_test.go
  53. 30 26
      ssh.go
  54. 4 3
      stats.go
  55. 7 5
      timeout.go
  56. 4 2
      timeout_system.go
  57. 4 3
      timeout_system_test.go
  58. 4 3
      timeout_test.go
  59. 7 5
      tun_common.go
  60. 6 4
      tun_test.go
  61. 20 0
      udp/conn.go
  62. 14 0
      udp/temp.go
  63. 15 13
      udp/udp_all.go
  64. 2 2
      udp/udp_android.go
  65. 2 2
      udp/udp_darwin.go
  66. 2 2
      udp/udp_freebsd.go
  67. 20 25
      udp/udp_generic.go
  68. 29 33
      udp/udp_linux.go
  69. 3 3
      udp/udp_linux_32.go
  70. 3 3
      udp/udp_linux_64.go
  71. 39 43
      udp/udp_tester.go
  72. 2 2
      udp/udp_windows.go
  73. 3 4
      util/main.go

+ 228 - 7
allow_list.go

@@ -4,11 +4,15 @@ import (
 	"fmt"
 	"net"
 	"regexp"
+
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 )
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *CIDR6Tree
+	cidrTree *cidr.Tree6
 }
 
 type RemoteAllowList struct {
@@ -16,7 +20,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *CIDR6Tree
+	insideAllowLists *cidr.Tree6
 }
 
 type LocalAllowList struct {
@@ -31,6 +35,223 @@ type AllowListNameRule struct {
 	Allow bool
 }
 
+func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
+	var nameRules []AllowListNameRule
+	handleKey := func(key string, value interface{}) (bool, error) {
+		if key == "interfaces" {
+			var err error
+			nameRules, err = getAllowListInterfaces(k, value)
+			if err != nil {
+				return false, err
+			}
+
+			return true, nil
+		}
+		return false, nil
+	}
+
+	al, err := newAllowListFromConfig(c, k, handleKey)
+	if err != nil {
+		return nil, err
+	}
+	return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
+}
+
+func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) {
+	al, err := newAllowListFromConfig(c, k, nil)
+	if err != nil {
+		return nil, err
+	}
+	remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey)
+	if err != nil {
+		return nil, err
+	}
+	return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
+}
+
+// If the handleKey func returns true, the rest of the parsing is skipped
+// for this key. This allows parsing of special values like `interfaces`.
+func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+	r := c.Get(k)
+	if r == nil {
+		return nil, nil
+	}
+
+	return newAllowList(k, r, handleKey)
+}
+
+// If the handleKey func returns true, the rest of the parsing is skipped
+// for this key. This allows parsing of special values like `interfaces`.
+func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+	rawMap, ok := raw.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
+	}
+
+	tree := cidr.NewTree6()
+
+	// Keep track of the rules we have added for both ipv4 and ipv6
+	type allowListRules struct {
+		firstValue     bool
+		allValuesMatch bool
+		defaultSet     bool
+		allValues      bool
+	}
+
+	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
+	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
+
+	for rawKey, rawValue := range rawMap {
+		rawCIDR, ok := rawKey.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+		}
+
+		if handleKey != nil {
+			handled, err := handleKey(rawCIDR, rawValue)
+			if err != nil {
+				return nil, err
+			}
+			if handled {
+				continue
+			}
+		}
+
+		value, ok := rawValue.(bool)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
+		}
+
+		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		if err != nil {
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+		}
+
+		// TODO: should we error on duplicate CIDRs in the config?
+		tree.AddCIDR(ipNet, value)
+
+		maskBits, maskSize := ipNet.Mask.Size()
+
+		var rules *allowListRules
+		if maskSize == 32 {
+			rules = &rules4
+		} else {
+			rules = &rules6
+		}
+
+		if rules.firstValue {
+			rules.allValues = value
+			rules.firstValue = false
+		} else {
+			if value != rules.allValues {
+				rules.allValuesMatch = false
+			}
+		}
+
+		// Check if this is 0.0.0.0/0 or ::/0
+		if maskBits == 0 {
+			rules.defaultSet = true
+		}
+	}
+
+	if !rules4.defaultSet {
+		if rules4.allValuesMatch {
+			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
+			tree.AddCIDR(zeroCIDR, !rules4.allValues)
+		} else {
+			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
+		}
+	}
+
+	if !rules6.defaultSet {
+		if rules6.allValuesMatch {
+			_, zeroCIDR, _ := net.ParseCIDR("::/0")
+			tree.AddCIDR(zeroCIDR, !rules6.allValues)
+		} else {
+			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
+		}
+	}
+
+	return &AllowList{cidrTree: tree}, nil
+}
+
+func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
+	var nameRules []AllowListNameRule
+
+	rawRules, ok := v.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
+	}
+
+	firstEntry := true
+	var allValues bool
+	for rawName, rawAllow := range rawRules {
+		name, ok := rawName.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
+		}
+		allow, ok := rawAllow.(bool)
+		if !ok {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
+		}
+
+		nameRE, err := regexp.Compile("^" + name + "$")
+		if err != nil {
+			return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
+		}
+
+		nameRules = append(nameRules, AllowListNameRule{
+			Name:  nameRE,
+			Allow: allow,
+		})
+
+		if firstEntry {
+			allValues = allow
+			firstEntry = false
+		} else {
+			if allow != allValues {
+				return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
+			}
+		}
+	}
+
+	return nameRules, nil
+}
+
+func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
+	value := c.Get(k)
+	if value == nil {
+		return nil, nil
+	}
+
+	remoteAllowRanges := cidr.NewTree6()
+
+	rawMap, ok := value.(map[interface{}]interface{})
+	if !ok {
+		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
+	}
+	for rawKey, rawValue := range rawMap {
+		rawCIDR, ok := rawKey.(string)
+		if !ok {
+			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
+		}
+
+		allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
+		if err != nil {
+			return nil, err
+		}
+
+		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		if err != nil {
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+		}
+
+		remoteAllowRanges.AddCIDR(ipNet, allowList)
+	}
+
+	return remoteAllowRanges, nil
+}
+
 func (al *AllowList) Allow(ip net.IP) bool {
 	if al == nil {
 		return true
@@ -45,7 +266,7 @@ func (al *AllowList) Allow(ip net.IP) bool {
 	}
 }
 
-func (al *AllowList) AllowIpV4(ip uint32) bool {
+func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
 	if al == nil {
 		return true
 	}
@@ -102,14 +323,14 @@ func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
 	return al.AllowList.Allow(ip)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp uint32, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
 	if !al.getInsideAllowList(vpnIp).Allow(ip) {
 		return false
 	}
 	return al.AllowList.Allow(ip)
 }
 
-func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
+func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
 	if al == nil {
 		return true
 	}
@@ -119,7 +340,7 @@ func (al *RemoteAllowList) AllowIpV4(vpnIp uint32, ip uint32) bool {
 	return al.AllowList.AllowIpV4(ip)
 }
 
-func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
+func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
 	if al == nil {
 		return true
 	}
@@ -129,7 +350,7 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp uint32, hi, lo uint64) bool {
 	return al.AllowList.AllowIpV6(hi, lo)
 }
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp uint32) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
 	if al.insideAllowLists != nil {
 		inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
 		if inside != nil {

+ 98 - 9
allow_list_test.go

@@ -5,21 +5,110 @@ import (
 	"regexp"
 	"testing"
 
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
+func TestNewAllowListFromConfig(t *testing.T) {
+	l := util.NewTestLogger()
+	c := config.NewC(l)
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0": true,
+	}
+	r, err := newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+	assert.Nil(t, r)
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0/16": "abc",
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"192.168.0.0/16": true,
+		"10.0.0.0/8":     false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":      true,
+		"10.0.0.0/8":     false,
+		"10.42.42.0/24":  true,
+		"fd00::/8":       true,
+		"fd00:fd00::/16": false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":     true,
+		"10.0.0.0/8":    false,
+		"10.42.42.0/24": true,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	if assert.NoError(t, err) {
+		assert.NotNil(t, r)
+	}
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"0.0.0.0/0":      true,
+		"10.0.0.0/8":     false,
+		"10.42.42.0/24":  true,
+		"::/0":           false,
+		"fd00::/8":       true,
+		"fd00:fd00::/16": false,
+	}
+	r, err = newAllowListFromConfig(c, "allowlist", nil)
+	if assert.NoError(t, err) {
+		assert.NotNil(t, r)
+	}
+
+	// Test interface names
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: "foo",
+		},
+	}
+	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
+	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: false,
+			`eth.*`:    true,
+		},
+	}
+	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
+	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
+
+	c.Settings["allowlist"] = map[interface{}]interface{}{
+		"interfaces": map[interface{}]interface{}{
+			`docker.*`: false,
+		},
+	}
+	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
+	if assert.NoError(t, err) {
+		assert.NotNil(t, lr)
+	}
+}
+
 func TestAllowList_Allow(t *testing.T) {
 	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
 
-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("0.0.0.0/0"), true)
-	tree.AddCIDR(getCIDR("10.0.0.0/8"), false)
-	tree.AddCIDR(getCIDR("10.42.42.42/32"), true)
-	tree.AddCIDR(getCIDR("10.42.0.0/16"), true)
-	tree.AddCIDR(getCIDR("10.42.42.0/24"), true)
-	tree.AddCIDR(getCIDR("10.42.42.0/24"), false)
-	tree.AddCIDR(getCIDR("::1/128"), true)
-	tree.AddCIDR(getCIDR("::2/128"), false)
+	tree := cidr.NewTree6()
+	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
+	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
+	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
+	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
+	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
+	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
+	tree.AddCIDR(cidr.Parse("::1/128"), true)
+	tree.AddCIDR(cidr.Parse("::2/128"), false)
 	al := &AllowList{cidrTree: tree}
 
 	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))

+ 5 - 4
bits_test.go

@@ -3,11 +3,12 @@ package nebula
 import (
 	"testing"
 
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestBits(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	b := NewBits(10)
 
 	// make sure it is the right size
@@ -75,7 +76,7 @@ func TestBits(t *testing.T) {
 }
 
 func TestBitsDupeCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
@@ -100,7 +101,7 @@ func TestBitsDupeCounter(t *testing.T) {
 }
 
 func TestBitsOutOfWindowCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()
@@ -130,7 +131,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
 }
 
 func TestBitsLostCounter(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	b := NewBits(10)
 	b.lostCounter.Clear()
 	b.dupeCounter.Clear()

+ 3 - 2
cert.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 )
 
 type CertState struct {
@@ -45,7 +46,7 @@ func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert
 	return cs, nil
 }
 
-func NewCertStateFromConfig(c *Config) (*CertState, error) {
+func NewCertStateFromConfig(c *config.C) (*CertState, error) {
 	var pemPrivateKey []byte
 	var err error
 
@@ -118,7 +119,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
 	return NewCertState(nebulaCert, rawKey)
 }
 
-func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
+func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
 	var rawCA []byte
 	var err error
 

+ 10 - 0
cidr/parse.go

@@ -0,0 +1,10 @@
+package cidr
+
+import "net"
+
+// Parse is a convenience function that returns only the IPNet
+// This function ignores errors since it is primarily a test helper, the result could be nil
+func Parse(s string) *net.IPNet {
+	_, c, _ := net.ParseCIDR(s)
+	return c
+}

+ 20 - 44
cidr_radix.go → cidr/tree4.go

@@ -1,39 +1,39 @@
-package nebula
+package cidr
 
 import (
-	"encoding/binary"
-	"fmt"
 	"net"
+
+	"github.com/slackhq/nebula/iputil"
 )
 
-type CIDRNode struct {
-	left   *CIDRNode
-	right  *CIDRNode
-	parent *CIDRNode
+type Node struct {
+	left   *Node
+	right  *Node
+	parent *Node
 	value  interface{}
 }
 
-type CIDRTree struct {
-	root *CIDRNode
+type Tree4 struct {
+	root *Node
 }
 
 const (
-	startbit = uint32(0x80000000)
+	startbit = iputil.VpnIp(0x80000000)
 )
 
-func NewCIDRTree() *CIDRTree {
-	tree := new(CIDRTree)
-	tree.root = &CIDRNode{}
+func NewTree4() *Tree4 {
+	tree := new(Tree4)
+	tree.root = &Node{}
 	return tree
 }
 
-func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
+func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 	bit := startbit
 	node := tree.root
 	next := tree.root
 
-	ip := ip2int(cidr.IP)
-	mask := ip2int(cidr.Mask)
+	ip := iputil.Ip2VpnIp(cidr.IP)
+	mask := iputil.Ip2VpnIp(cidr.Mask)
 
 	// Find our last ancestor in the tree
 	for bit&mask != 0 {
@@ -59,7 +59,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Build up the rest of the tree we don't already have
 	for bit&mask != 0 {
-		next = &CIDRNode{}
+		next = &Node{}
 		next.parent = node
 
 		if ip&bit != 0 {
@@ -77,7 +77,7 @@ func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
 }
 
 // Finds the first match, which may be the least specific
-func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
+func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
 
@@ -100,7 +100,7 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
 }
 
 // Finds the most specific match
-func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
+func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
 
@@ -122,7 +122,7 @@ func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
 }
 
 // Finds the most specific match
-func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
+func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root
 	lastNode := node
@@ -143,27 +143,3 @@ func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
 	}
 	return value
 }
-
-// A helper type to avoid converting to IP when logging
-type IntIp uint32
-
-func (ip IntIp) String() string {
-	return fmt.Sprintf("%v", int2ip(uint32(ip)))
-}
-
-func (ip IntIp) MarshalJSON() ([]byte, error) {
-	return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
-}
-
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
-}

+ 153 - 0
cidr/tree4_test.go

@@ -0,0 +1,153 @@
+package cidr
+
+import (
+	"net"
+	"testing"
+
+	"github.com/slackhq/nebula/iputil"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCIDRTree_Contains(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
+	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1", "1.0.0.0"},
+		{"1", "1.255.255.255"},
+		{"2", "2.1.0.0"},
+		{"2", "2.1.255.255"},
+		{"3", "3.1.1.0"},
+		{"3", "3.1.1.255"},
+		{"4a", "4.1.1.255"},
+		{"4a", "4.1.1.1"},
+		{"5", "240.0.0.0"},
+		{"5", "255.255.255.255"},
+		{nil, "239.0.0.0"},
+		{nil, "4.1.2.2"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func TestCIDRTree_MostSpecificContains(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1", "1.0.0.0"},
+		{"1", "1.255.255.255"},
+		{"2", "2.1.0.0"},
+		{"2", "2.1.255.255"},
+		{"3", "3.1.1.0"},
+		{"3", "3.1.1.255"},
+		{"4a", "4.1.1.255"},
+		{"4b", "4.1.1.2"},
+		{"4c", "4.1.1.1"},
+		{"5", "240.0.0.0"},
+		{"5", "255.255.255.255"},
+		{nil, "239.0.0.0"},
+		{nil, "4.1.2.2"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func TestCIDRTree_Match(t *testing.T) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1a", "4.1.1.0"},
+		{"1b", "4.1.1.1"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewTree4()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+}
+
+func BenchmarkCIDRTree_Contains(b *testing.B) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
+
+	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
+	b.Run("found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Contains(ip)
+		}
+	})
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
+	b.Run("not found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Contains(ip)
+		}
+	})
+}
+
+func BenchmarkCIDRTree_Match(b *testing.B) {
+	tree := NewTree4()
+	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
+	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
+	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
+
+	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
+	b.Run("found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Match(ip)
+		}
+	})
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
+	b.Run("not found", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			tree.Match(ip)
+		}
+	})
+}

+ 20 - 19
cidr6_radix.go → cidr/tree6.go

@@ -1,26 +1,27 @@
-package nebula
+package cidr
 
 import (
-	"encoding/binary"
 	"net"
+
+	"github.com/slackhq/nebula/iputil"
 )
 
 const startbit6 = uint64(1 << 63)
 
-type CIDR6Tree struct {
-	root4 *CIDRNode
-	root6 *CIDRNode
+type Tree6 struct {
+	root4 *Node
+	root6 *Node
 }
 
-func NewCIDR6Tree() *CIDR6Tree {
-	tree := new(CIDR6Tree)
-	tree.root4 = &CIDRNode{}
-	tree.root6 = &CIDRNode{}
+func NewTree6() *Tree6 {
+	tree := new(Tree6)
+	tree.root4 = &Node{}
+	tree.root6 = &Node{}
 	return tree
 }
 
-func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
-	var node, next *CIDRNode
+func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
+	var node, next *Node
 
 	cidrIP, ipv4 := isIPV4(cidr.IP)
 	if ipv4 {
@@ -33,8 +34,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
 	}
 
 	for i := 0; i < len(cidrIP); i += 4 {
-		ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
-		mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
+		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
+		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
 		bit := startbit
 
 		// Find our last ancestor in the tree
@@ -55,7 +56,7 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 		// Build up the rest of the tree we don't already have
 		for bit&mask != 0 {
-			next = &CIDRNode{}
+			next = &Node{}
 			next.parent = node
 
 			if ip&bit != 0 {
@@ -74,8 +75,8 @@ func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
 }
 
 // Finds the most specific match
-func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
-	var node *CIDRNode
+func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
+	var node *Node
 
 	wholeIP, ipv4 := isIPV4(ip)
 	if ipv4 {
@@ -85,7 +86,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
 	}
 
 	for i := 0; i < len(wholeIP); i += 4 {
-		ip := ip2int(wholeIP[i : i+4])
+		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
 		bit := startbit
 
 		for node != nil {
@@ -110,7 +111,7 @@ func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
 	return value
 }
 
-func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
+func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
 	bit := startbit
 	node := tree.root4
 
@@ -131,7 +132,7 @@ func (tree *CIDR6Tree) MostSpecificContainsIpV4(ip uint32) (value interface{}) {
 	return value
 }
 
-func (tree *CIDR6Tree) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
+func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 	ip := hi
 	node := tree.root6
 

+ 25 - 21
cidr6_radix_test.go → cidr/tree6_test.go

@@ -1,6 +1,7 @@
-package nebula
+package cidr
 
 import (
+	"encoding/binary"
 	"net"
 	"testing"
 
@@ -8,17 +9,17 @@ import (
 )
 
 func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.1/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.1/30"), "4b")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
+	tree := NewTree6()
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
 		Result interface{}
@@ -46,9 +47,9 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
 		assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
 	}
 
-	tree = NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	tree.AddCIDR(getCIDR("::/0"), "cool6")
+	tree = NewTree6()
+	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
+	tree.AddCIDR(Parse("::/0"), "cool6")
 	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
 	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
 	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
@@ -56,10 +57,10 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
 }
 
 func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewCIDR6Tree()
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(getCIDR("1:2:0:4:5:0:0:0/96"), "6c")
+	tree := NewTree6()
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
+	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
 		Result interface{}
@@ -71,7 +72,10 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
 	}
 
 	for _, tt := range tests {
-		ip := NewIp6AndPort(net.ParseIP(tt.IP), 0)
-		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(ip.Hi, ip.Lo))
+		ip := net.ParseIP(tt.IP)
+		hi := binary.BigEndian.Uint64(ip[:8])
+		lo := binary.BigEndian.Uint64(ip[8:])
+
+		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
 	}
 }

+ 0 - 157
cidr_radix_test.go

@@ -1,157 +0,0 @@
-package nebula
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4b")
-	tree.AddCIDR(getCIDR("4.1.2.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4a", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Contains(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
-	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
-	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
-	tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
-	tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
-	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func TestCIDRTree_Match(t *testing.T) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
-	tree.AddCIDR(getCIDR("4.1.1.1/32"), "1b")
-
-	tests := []struct {
-		Result interface{}
-		IP     string
-	}{
-		{"1a", "4.1.1.0"},
-		{"1b", "4.1.1.1"},
-	}
-
-	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Match(ip2int(net.ParseIP(tt.IP))))
-	}
-
-	tree = NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
-	tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
-
-	ip := ip2int(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-
-	ip = ip2int(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-}
-
-func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewCIDRTree()
-	tree.AddCIDR(getCIDR("1.1.0.0/16"), "1")
-	tree.AddCIDR(getCIDR("1.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("192.2.1.1/32"), "1")
-	tree.AddCIDR(getCIDR("172.2.1.1/32"), "1")
-
-	ip := ip2int(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-
-	ip = ip2int(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-}
-
-func getCIDR(s string) *net.IPNet {
-	_, c, _ := net.ParseCIDR(s)
-	return c
-}

+ 6 - 5
cmd/nebula-service/main.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
 )
 
 // A version string that can be set with
@@ -49,14 +50,14 @@ func main() {
 	l := logrus.New()
 	l.Out = os.Stdout
 
-	config := nebula.NewConfig(l)
-	err := config.Load(*configPath)
+	c := config.NewC(l)
+	err := c.Load(*configPath)
 	if err != nil {
 		fmt.Printf("failed to load config: %s", err)
 		os.Exit(1)
 	}
 
-	c, err := nebula.Main(config, *configTest, Build, l, nil)
+	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
 	case nebula.ContextualError:
@@ -68,8 +69,8 @@ func main() {
 	}
 
 	if !*configTest {
-		c.Start()
-		c.ShutdownBlock()
+		ctrl.Start()
+		ctrl.ShutdownBlock()
 	}
 
 	os.Exit(0)

+ 4 - 3
cmd/nebula-service/service.go

@@ -9,6 +9,7 @@ import (
 	"github.com/kardianos/service"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
 )
 
 var logger service.Logger
@@ -27,13 +28,13 @@ func (p *program) Start(s service.Service) error {
 	l := logrus.New()
 	HookLogger(l)
 
-	config := nebula.NewConfig(l)
-	err := config.Load(*p.configPath)
+	c := config.NewC(l)
+	err := c.Load(*p.configPath)
 	if err != nil {
 		return fmt.Errorf("failed to load config: %s", err)
 	}
 
-	p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
+	p.control, err = nebula.Main(c, *p.configTest, Build, l, nil)
 	if err != nil {
 		return err
 	}

+ 6 - 5
cmd/nebula/main.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
 )
 
 // A version string that can be set with
@@ -43,14 +44,14 @@ func main() {
 	l := logrus.New()
 	l.Out = os.Stdout
 
-	config := nebula.NewConfig(l)
-	err := config.Load(*configPath)
+	c := config.NewC(l)
+	err := c.Load(*configPath)
 	if err != nil {
 		fmt.Printf("failed to load config: %s", err)
 		os.Exit(1)
 	}
 
-	c, err := nebula.Main(config, *configTest, Build, l, nil)
+	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
 
 	switch v := err.(type) {
 	case nebula.ContextualError:
@@ -62,8 +63,8 @@ func main() {
 	}
 
 	if !*configTest {
-		c.Start()
-		c.ShutdownBlock()
+		ctrl.Start()
+		ctrl.ShutdownBlock()
 	}
 
 	os.Exit(0)

+ 0 - 611
config.go

@@ -1,611 +0,0 @@
-package nebula
-
-import (
-	"context"
-	"errors"
-	"fmt"
-	"io/ioutil"
-	"net"
-	"os"
-	"os/signal"
-	"path/filepath"
-	"regexp"
-	"sort"
-	"strconv"
-	"strings"
-	"syscall"
-	"time"
-
-	"github.com/imdario/mergo"
-	"github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v2"
-)
-
-type Config struct {
-	path        string
-	files       []string
-	Settings    map[interface{}]interface{}
-	oldSettings map[interface{}]interface{}
-	callbacks   []func(*Config)
-	l           *logrus.Logger
-}
-
-func NewConfig(l *logrus.Logger) *Config {
-	return &Config{
-		Settings: make(map[interface{}]interface{}),
-		l:        l,
-	}
-}
-
-// Load will find all yaml files within path and load them in lexical order
-func (c *Config) Load(path string) error {
-	c.path = path
-	c.files = make([]string, 0)
-
-	err := c.resolve(path, true)
-	if err != nil {
-		return err
-	}
-
-	if len(c.files) == 0 {
-		return fmt.Errorf("no config files found at %s", path)
-	}
-
-	sort.Strings(c.files)
-
-	err = c.parse()
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (c *Config) LoadString(raw string) error {
-	if raw == "" {
-		return errors.New("Empty configuration")
-	}
-	return c.parseRaw([]byte(raw))
-}
-
-// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
-// here should decide if they need to make a change to the current process before making the change. HasChanged can be
-// used to help decide if a change is necessary.
-// These functions should return quickly or spawn their own go routine if they will take a while
-func (c *Config) RegisterReloadCallback(f func(*Config)) {
-	c.callbacks = append(c.callbacks, f)
-}
-
-// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
-// k in both the old and new settings will be serialized, the result of the string comparison is returned.
-// If k is an empty string the entire config is tested.
-// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
-// there is change when there actually wasn't any.
-func (c *Config) HasChanged(k string) bool {
-	if c.oldSettings == nil {
-		return false
-	}
-
-	var (
-		nv interface{}
-		ov interface{}
-	)
-
-	if k == "" {
-		nv = c.Settings
-		ov = c.oldSettings
-		k = "all settings"
-	} else {
-		nv = c.get(k, c.Settings)
-		ov = c.get(k, c.oldSettings)
-	}
-
-	newVals, err := yaml.Marshal(nv)
-	if err != nil {
-		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
-	}
-
-	oldVals, err := yaml.Marshal(ov)
-	if err != nil {
-		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
-	}
-
-	return string(newVals) != string(oldVals)
-}
-
-// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
-// original path provided to Load. The old settings are shallow copied for change detection after the reload.
-func (c *Config) CatchHUP(ctx context.Context) {
-	ch := make(chan os.Signal, 1)
-	signal.Notify(ch, syscall.SIGHUP)
-
-	go func() {
-		for {
-			select {
-			case <-ctx.Done():
-				signal.Stop(ch)
-				close(ch)
-				return
-			case <-ch:
-				c.l.Info("Caught HUP, reloading config")
-				c.ReloadConfig()
-			}
-		}
-	}()
-}
-
-func (c *Config) ReloadConfig() {
-	c.oldSettings = make(map[interface{}]interface{})
-	for k, v := range c.Settings {
-		c.oldSettings[k] = v
-	}
-
-	err := c.Load(c.path)
-	if err != nil {
-		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
-		return
-	}
-
-	for _, v := range c.callbacks {
-		v(c)
-	}
-}
-
-// GetString will get the string for k or return the default d if not found or invalid
-func (c *Config) GetString(k, d string) string {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	return fmt.Sprintf("%v", r)
-}
-
-// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
-func (c *Config) GetStringSlice(k string, d []string) []string {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	rv, ok := r.([]interface{})
-	if !ok {
-		return d
-	}
-
-	v := make([]string, len(rv))
-	for i := 0; i < len(v); i++ {
-		v[i] = fmt.Sprintf("%v", rv[i])
-	}
-
-	return v
-}
-
-// GetMap will get the map for k or return the default d if not found or invalid
-func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
-	r := c.Get(k)
-	if r == nil {
-		return d
-	}
-
-	v, ok := r.(map[interface{}]interface{})
-	if !ok {
-		return d
-	}
-
-	return v
-}
-
-// GetInt will get the int for k or return the default d if not found or invalid
-func (c *Config) GetInt(k string, d int) int {
-	r := c.GetString(k, strconv.Itoa(d))
-	v, err := strconv.Atoi(r)
-	if err != nil {
-		return d
-	}
-
-	return v
-}
-
-// GetBool will get the bool for k or return the default d if not found or invalid
-func (c *Config) GetBool(k string, d bool) bool {
-	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
-	v, err := strconv.ParseBool(r)
-	if err != nil {
-		switch r {
-		case "y", "yes":
-			return true
-		case "n", "no":
-			return false
-		}
-		return d
-	}
-
-	return v
-}
-
-// GetDuration will get the duration for k or return the default d if not found or invalid
-func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
-	r := c.GetString(k, "")
-	v, err := time.ParseDuration(r)
-	if err != nil {
-		return d
-	}
-	return v
-}
-
-func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
-	var nameRules []AllowListNameRule
-	handleKey := func(key string, value interface{}) (bool, error) {
-		if key == "interfaces" {
-			var err error
-			nameRules, err = c.getAllowListInterfaces(k, value)
-			if err != nil {
-				return false, err
-			}
-
-			return true, nil
-		}
-		return false, nil
-	}
-
-	al, err := c.GetAllowList(k, handleKey)
-	if err != nil {
-		return nil, err
-	}
-	return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
-}
-
-func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
-	al, err := c.GetAllowList(k, nil)
-	if err != nil {
-		return nil, err
-	}
-	remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
-	if err != nil {
-		return nil, err
-	}
-	return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
-}
-
-func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
-	value := c.Get(k)
-	if value == nil {
-		return nil, nil
-	}
-
-	remoteAllowRanges := NewCIDR6Tree()
-
-	rawMap, ok := value.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
-	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
-		allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
-		if err != nil {
-			return nil, err
-		}
-
-		_, cidr, err := net.ParseCIDR(rawCIDR)
-		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
-		}
-
-		remoteAllowRanges.AddCIDR(cidr, allowList)
-	}
-
-	return remoteAllowRanges, nil
-}
-
-// If the handleKey func returns true, the rest of the parsing is skipped
-// for this key. This allows parsing of special values like `interfaces`.
-func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	r := c.Get(k)
-	if r == nil {
-		return nil, nil
-	}
-
-	return c.getAllowList(k, r, handleKey)
-}
-
-// If the handleKey func returns true, the rest of the parsing is skipped
-// for this key. This allows parsing of special values like `interfaces`.
-func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	rawMap, ok := raw.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
-	}
-
-	tree := NewCIDR6Tree()
-
-	// Keep track of the rules we have added for both ipv4 and ipv6
-	type allowListRules struct {
-		firstValue     bool
-		allValuesMatch bool
-		defaultSet     bool
-		allValues      bool
-	}
-	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
-	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
-
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
-		if handleKey != nil {
-			handled, err := handleKey(rawCIDR, rawValue)
-			if err != nil {
-				return nil, err
-			}
-			if handled {
-				continue
-			}
-		}
-
-		value, ok := rawValue.(bool)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
-		}
-
-		_, cidr, err := net.ParseCIDR(rawCIDR)
-		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
-		}
-
-		// TODO: should we error on duplicate CIDRs in the config?
-		tree.AddCIDR(cidr, value)
-
-		maskBits, maskSize := cidr.Mask.Size()
-
-		var rules *allowListRules
-		if maskSize == 32 {
-			rules = &rules4
-		} else {
-			rules = &rules6
-		}
-
-		if rules.firstValue {
-			rules.allValues = value
-			rules.firstValue = false
-		} else {
-			if value != rules.allValues {
-				rules.allValuesMatch = false
-			}
-		}
-
-		// Check if this is 0.0.0.0/0 or ::/0
-		if maskBits == 0 {
-			rules.defaultSet = true
-		}
-	}
-
-	if !rules4.defaultSet {
-		if rules4.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
-			tree.AddCIDR(zeroCIDR, !rules4.allValues)
-		} else {
-			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
-		}
-	}
-
-	if !rules6.defaultSet {
-		if rules6.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("::/0")
-			tree.AddCIDR(zeroCIDR, !rules6.allValues)
-		} else {
-			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
-		}
-	}
-
-	return &AllowList{cidrTree: tree}, nil
-}
-
-func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
-	var nameRules []AllowListNameRule
-
-	rawRules, ok := v.(map[interface{}]interface{})
-	if !ok {
-		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
-	}
-
-	firstEntry := true
-	var allValues bool
-	for rawName, rawAllow := range rawRules {
-		name, ok := rawName.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
-		}
-		allow, ok := rawAllow.(bool)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
-		}
-
-		nameRE, err := regexp.Compile("^" + name + "$")
-		if err != nil {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
-		}
-
-		nameRules = append(nameRules, AllowListNameRule{
-			Name:  nameRE,
-			Allow: allow,
-		})
-
-		if firstEntry {
-			allValues = allow
-			firstEntry = false
-		} else {
-			if allow != allValues {
-				return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
-			}
-		}
-	}
-
-	return nameRules, nil
-}
-
-func (c *Config) Get(k string) interface{} {
-	return c.get(k, c.Settings)
-}
-
-func (c *Config) IsSet(k string) bool {
-	return c.get(k, c.Settings) != nil
-}
-
-func (c *Config) get(k string, v interface{}) interface{} {
-	parts := strings.Split(k, ".")
-	for _, p := range parts {
-		m, ok := v.(map[interface{}]interface{})
-		if !ok {
-			return nil
-		}
-
-		v, ok = m[p]
-		if !ok {
-			return nil
-		}
-	}
-
-	return v
-}
-
-// direct signifies if this is the config path directly specified by the user,
-// versus a file/dir found by recursing into that path
-func (c *Config) resolve(path string, direct bool) error {
-	i, err := os.Stat(path)
-	if err != nil {
-		return nil
-	}
-
-	if !i.IsDir() {
-		c.addFile(path, direct)
-		return nil
-	}
-
-	paths, err := readDirNames(path)
-	if err != nil {
-		return fmt.Errorf("problem while reading directory %s: %s", path, err)
-	}
-
-	for _, p := range paths {
-		err := c.resolve(filepath.Join(path, p), false)
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-func (c *Config) addFile(path string, direct bool) error {
-	ext := filepath.Ext(path)
-
-	if !direct && ext != ".yaml" && ext != ".yml" {
-		return nil
-	}
-
-	ap, err := filepath.Abs(path)
-	if err != nil {
-		return err
-	}
-
-	c.files = append(c.files, ap)
-	return nil
-}
-
-func (c *Config) parseRaw(b []byte) error {
-	var m map[interface{}]interface{}
-
-	err := yaml.Unmarshal(b, &m)
-	if err != nil {
-		return err
-	}
-
-	c.Settings = m
-	return nil
-}
-
-func (c *Config) parse() error {
-	var m map[interface{}]interface{}
-
-	for _, path := range c.files {
-		b, err := ioutil.ReadFile(path)
-		if err != nil {
-			return err
-		}
-
-		var nm map[interface{}]interface{}
-		err = yaml.Unmarshal(b, &nm)
-		if err != nil {
-			return err
-		}
-
-		// We need to use WithAppendSlice so that firewall rules in separate
-		// files are appended together
-		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
-		m = nm
-		if err != nil {
-			return err
-		}
-	}
-
-	c.Settings = m
-	return nil
-}
-
-func readDirNames(path string) ([]string, error) {
-	f, err := os.Open(path)
-	if err != nil {
-		return nil, err
-	}
-
-	paths, err := f.Readdirnames(-1)
-	f.Close()
-	if err != nil {
-		return nil, err
-	}
-
-	sort.Strings(paths)
-	return paths, nil
-}
-
-func configLogger(c *Config) error {
-	// set up our logging level
-	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
-	if err != nil {
-		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
-	}
-	c.l.SetLevel(logLevel)
-
-	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
-	timestampFormat := c.GetString("logging.timestamp_format", "")
-	fullTimestamp := (timestampFormat != "")
-	if timestampFormat == "" {
-		timestampFormat = time.RFC3339
-	}
-
-	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
-	switch logFormat {
-	case "text":
-		c.l.Formatter = &logrus.TextFormatter{
-			TimestampFormat:  timestampFormat,
-			FullTimestamp:    fullTimestamp,
-			DisableTimestamp: disableTimestamp,
-		}
-	case "json":
-		c.l.Formatter = &logrus.JSONFormatter{
-			TimestampFormat:  timestampFormat,
-			DisableTimestamp: disableTimestamp,
-		}
-	default:
-		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
-	}
-
-	return nil
-}

+ 358 - 0
config/config.go

@@ -0,0 +1,358 @@
+package config
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/imdario/mergo"
+	"github.com/sirupsen/logrus"
+	"gopkg.in/yaml.v2"
+)
+
+type C struct {
+	path        string
+	files       []string
+	Settings    map[interface{}]interface{}
+	oldSettings map[interface{}]interface{}
+	callbacks   []func(*C)
+	l           *logrus.Logger
+}
+
+func NewC(l *logrus.Logger) *C {
+	return &C{
+		Settings: make(map[interface{}]interface{}),
+		l:        l,
+	}
+}
+
+// Load will find all yaml files within path and load them in lexical order
+func (c *C) Load(path string) error {
+	c.path = path
+	c.files = make([]string, 0)
+
+	err := c.resolve(path, true)
+	if err != nil {
+		return err
+	}
+
+	if len(c.files) == 0 {
+		return fmt.Errorf("no config files found at %s", path)
+	}
+
+	sort.Strings(c.files)
+
+	err = c.parse()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *C) LoadString(raw string) error {
+	if raw == "" {
+		return errors.New("Empty configuration")
+	}
+	return c.parseRaw([]byte(raw))
+}
+
+// RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
+// here should decide if they need to make a change to the current process before making the change. HasChanged can be
+// used to help decide if a change is necessary.
+// These functions should return quickly or spawn their own go routine if they will take a while
+func (c *C) RegisterReloadCallback(f func(*C)) {
+	c.callbacks = append(c.callbacks, f)
+}
+
+// HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
+// k in both the old and new settings will be serialized, the result of the string comparison is returned.
+// If k is an empty string the entire config is tested.
+// It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
+// there is change when there actually wasn't any.
+func (c *C) HasChanged(k string) bool {
+	if c.oldSettings == nil {
+		return false
+	}
+
+	var (
+		nv interface{}
+		ov interface{}
+	)
+
+	if k == "" {
+		nv = c.Settings
+		ov = c.oldSettings
+		k = "all settings"
+	} else {
+		nv = c.get(k, c.Settings)
+		ov = c.get(k, c.oldSettings)
+	}
+
+	newVals, err := yaml.Marshal(nv)
+	if err != nil {
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
+	}
+
+	oldVals, err := yaml.Marshal(ov)
+	if err != nil {
+		c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
+	}
+
+	return string(newVals) != string(oldVals)
+}
+
+// CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
+// original path provided to Load. The old settings are shallow copied for change detection after the reload.
+func (c *C) CatchHUP(ctx context.Context) {
+	ch := make(chan os.Signal, 1)
+	signal.Notify(ch, syscall.SIGHUP)
+
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				signal.Stop(ch)
+				close(ch)
+				return
+			case <-ch:
+				c.l.Info("Caught HUP, reloading config")
+				c.ReloadConfig()
+			}
+		}
+	}()
+}
+
+func (c *C) ReloadConfig() {
+	c.oldSettings = make(map[interface{}]interface{})
+	for k, v := range c.Settings {
+		c.oldSettings[k] = v
+	}
+
+	err := c.Load(c.path)
+	if err != nil {
+		c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
+		return
+	}
+
+	for _, v := range c.callbacks {
+		v(c)
+	}
+}
+
+// GetString will get the string for k or return the default d if not found or invalid
+func (c *C) GetString(k, d string) string {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	return fmt.Sprintf("%v", r)
+}
+
+// GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
+func (c *C) GetStringSlice(k string, d []string) []string {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	rv, ok := r.([]interface{})
+	if !ok {
+		return d
+	}
+
+	v := make([]string, len(rv))
+	for i := 0; i < len(v); i++ {
+		v[i] = fmt.Sprintf("%v", rv[i])
+	}
+
+	return v
+}
+
+// GetMap will get the map for k or return the default d if not found or invalid
+func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
+	r := c.Get(k)
+	if r == nil {
+		return d
+	}
+
+	v, ok := r.(map[interface{}]interface{})
+	if !ok {
+		return d
+	}
+
+	return v
+}
+
+// GetInt will get the int for k or return the default d if not found or invalid
+func (c *C) GetInt(k string, d int) int {
+	r := c.GetString(k, strconv.Itoa(d))
+	v, err := strconv.Atoi(r)
+	if err != nil {
+		return d
+	}
+
+	return v
+}
+
+// GetBool will get the bool for k or return the default d if not found or invalid
+func (c *C) GetBool(k string, d bool) bool {
+	r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
+	v, err := strconv.ParseBool(r)
+	if err != nil {
+		switch r {
+		case "y", "yes":
+			return true
+		case "n", "no":
+			return false
+		}
+		return d
+	}
+
+	return v
+}
+
+// GetDuration will get the duration for k or return the default d if not found or invalid
+func (c *C) GetDuration(k string, d time.Duration) time.Duration {
+	r := c.GetString(k, "")
+	v, err := time.ParseDuration(r)
+	if err != nil {
+		return d
+	}
+	return v
+}
+
+func (c *C) Get(k string) interface{} {
+	return c.get(k, c.Settings)
+}
+
+func (c *C) IsSet(k string) bool {
+	return c.get(k, c.Settings) != nil
+}
+
+func (c *C) get(k string, v interface{}) interface{} {
+	parts := strings.Split(k, ".")
+	for _, p := range parts {
+		m, ok := v.(map[interface{}]interface{})
+		if !ok {
+			return nil
+		}
+
+		v, ok = m[p]
+		if !ok {
+			return nil
+		}
+	}
+
+	return v
+}
+
+// direct signifies if this is the config path directly specified by the user,
+// versus a file/dir found by recursing into that path
+func (c *C) resolve(path string, direct bool) error {
+	i, err := os.Stat(path)
+	if err != nil {
+		return nil
+	}
+
+	if !i.IsDir() {
+		c.addFile(path, direct)
+		return nil
+	}
+
+	paths, err := readDirNames(path)
+	if err != nil {
+		return fmt.Errorf("problem while reading directory %s: %s", path, err)
+	}
+
+	for _, p := range paths {
+		err := c.resolve(filepath.Join(path, p), false)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (c *C) addFile(path string, direct bool) error {
+	ext := filepath.Ext(path)
+
+	if !direct && ext != ".yaml" && ext != ".yml" {
+		return nil
+	}
+
+	ap, err := filepath.Abs(path)
+	if err != nil {
+		return err
+	}
+
+	c.files = append(c.files, ap)
+	return nil
+}
+
+func (c *C) parseRaw(b []byte) error {
+	var m map[interface{}]interface{}
+
+	err := yaml.Unmarshal(b, &m)
+	if err != nil {
+		return err
+	}
+
+	c.Settings = m
+	return nil
+}
+
+func (c *C) parse() error {
+	var m map[interface{}]interface{}
+
+	for _, path := range c.files {
+		b, err := ioutil.ReadFile(path)
+		if err != nil {
+			return err
+		}
+
+		var nm map[interface{}]interface{}
+		err = yaml.Unmarshal(b, &nm)
+		if err != nil {
+			return err
+		}
+
+		// We need to use WithAppendSlice so that firewall rules in separate
+		// files are appended together
+		err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
+		m = nm
+		if err != nil {
+			return err
+		}
+	}
+
+	c.Settings = m
+	return nil
+}
+
+func readDirNames(path string) ([]string, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+
+	paths, err := f.Readdirnames(-1)
+	f.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	sort.Strings(paths)
+	return paths, nil
+}

+ 18 - 103
config_test.go → config/config_test.go

@@ -1,4 +1,4 @@
-package nebula
+package config
 
 import (
 	"io/ioutil"
@@ -7,19 +7,20 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestConfig_Load(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	dir, err := ioutil.TempDir("", "config-test")
 	// invalid yaml
-	c := NewConfig(l)
+	c := NewC(l)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 	// simple multi config merge
-	c = NewConfig(l)
+	c = NewC(l)
 	os.RemoveAll(dir)
 	os.Mkdir(dir, 0755)
 
@@ -41,9 +42,9 @@ func TestConfig_Load(t *testing.T) {
 }
 
 func TestConfig_Get(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// test simple type
-	c := NewConfig(l)
+	c := NewC(l)
 	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 
@@ -57,15 +58,15 @@ func TestConfig_Get(t *testing.T) {
 }
 
 func TestConfig_GetStringSlice(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := NewC(l)
 	c.Settings["slice"] = []interface{}{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 
 func TestConfig_GetBool(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := NewC(l)
 	c.Settings["bool"] = true
 	assert.Equal(t, true, c.GetBool("bool", false))
 
@@ -91,108 +92,22 @@ func TestConfig_GetBool(t *testing.T) {
 	assert.Equal(t, false, c.GetBool("bool", true))
 }
 
-func TestConfig_GetAllowList(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0": true,
-	}
-	r, err := c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
-	assert.Nil(t, r)
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0/16": "abc",
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"192.168.0.0/16": true,
-		"10.0.0.0/8":     false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":      true,
-		"10.0.0.0/8":     false,
-		"10.42.42.0/24":  true,
-		"fd00::/8":       true,
-		"fd00:fd00::/16": false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":     true,
-		"10.0.0.0/8":    false,
-		"10.42.42.0/24": true,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	if assert.NoError(t, err) {
-		assert.NotNil(t, r)
-	}
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"0.0.0.0/0":      true,
-		"10.0.0.0/8":     false,
-		"10.42.42.0/24":  true,
-		"::/0":           false,
-		"fd00::/8":       true,
-		"fd00:fd00::/16": false,
-	}
-	r, err = c.GetAllowList("allowlist", nil)
-	if assert.NoError(t, err) {
-		assert.NotNil(t, r)
-	}
-
-	// Test interface names
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: "foo",
-		},
-	}
-	lr, err := c.GetLocalAllowList("allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: false,
-			`eth.*`:    true,
-		},
-	}
-	lr, err = c.GetLocalAllowList("allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
-
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
-			`docker.*`: false,
-		},
-	}
-	lr, err = c.GetLocalAllowList("allowlist")
-	if assert.NoError(t, err) {
-		assert.NotNil(t, lr)
-	}
-}
-
 func TestConfig_HasChanged(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// No reload has occurred, return false
-	c := NewConfig(l)
+	c := NewC(l)
 	c.Settings["test"] = "hi"
 	assert.False(t, c.HasChanged(""))
 
 	// Test key change
-	c = NewConfig(l)
+	c = NewC(l)
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "no"}
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged(""))
 
 	// No key change
-	c = NewConfig(l)
+	c = NewC(l)
 	c.Settings["test"] = "hi"
 	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
 	assert.False(t, c.HasChanged("test"))
@@ -200,13 +115,13 @@ func TestConfig_HasChanged(t *testing.T) {
 }
 
 func TestConfig_ReloadConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	done := make(chan bool, 1)
 	dir, err := ioutil.TempDir("", "config-test")
 	assert.Nil(t, err)
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
-	c := NewConfig(l)
+	c := NewC(l)
 	assert.Nil(t, c.Load(dir))
 
 	assert.False(t, c.HasChanged("outer.inner"))
@@ -215,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
 
 	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
 
-	c.RegisterReloadCallback(func(c *Config) {
+	c.RegisterReloadCallback(func(c *C) {
 		done <- true
 	})
 

+ 51 - 49
connection_manager.go

@@ -6,6 +6,8 @@ import (
 	"time"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 )
 
 // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
@@ -13,16 +15,16 @@ import (
 
 type connectionManager struct {
 	hostMap      *HostMap
-	in           map[uint32]struct{}
+	in           map[iputil.VpnIp]struct{}
 	inLock       *sync.RWMutex
 	inCount      int
-	out          map[uint32]struct{}
+	out          map[iputil.VpnIp]struct{}
 	outLock      *sync.RWMutex
 	outCount     int
 	TrafficTimer *SystemTimerWheel
 	intf         *Interface
 
-	pendingDeletion      map[uint32]int
+	pendingDeletion      map[iputil.VpnIp]int
 	pendingDeletionLock  *sync.RWMutex
 	pendingDeletionTimer *SystemTimerWheel
 
@@ -36,15 +38,15 @@ type connectionManager struct {
 func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
 	nc := &connectionManager{
 		hostMap:                 intf.hostMap,
-		in:                      make(map[uint32]struct{}),
+		in:                      make(map[iputil.VpnIp]struct{}),
 		inLock:                  &sync.RWMutex{},
 		inCount:                 0,
-		out:                     make(map[uint32]struct{}),
+		out:                     make(map[iputil.VpnIp]struct{}),
 		outLock:                 &sync.RWMutex{},
 		outCount:                0,
 		TrafficTimer:            NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		intf:                    intf,
-		pendingDeletion:         make(map[uint32]int),
+		pendingDeletion:         make(map[iputil.VpnIp]int),
 		pendingDeletionLock:     &sync.RWMutex{},
 		pendingDeletionTimer:    NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
 		checkInterval:           checkInterval,
@@ -55,7 +57,7 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 	return nc
 }
 
-func (n *connectionManager) In(ip uint32) {
+func (n *connectionManager) In(ip iputil.VpnIp) {
 	n.inLock.RLock()
 	// If this already exists, return
 	if _, ok := n.in[ip]; ok {
@@ -68,7 +70,7 @@ func (n *connectionManager) In(ip uint32) {
 	n.inLock.Unlock()
 }
 
-func (n *connectionManager) Out(ip uint32) {
+func (n *connectionManager) Out(ip iputil.VpnIp) {
 	n.outLock.RLock()
 	// If this already exists, return
 	if _, ok := n.out[ip]; ok {
@@ -87,9 +89,9 @@ func (n *connectionManager) Out(ip uint32) {
 	n.outLock.Unlock()
 }
 
-func (n *connectionManager) CheckIn(vpnIP uint32) bool {
+func (n *connectionManager) CheckIn(vpnIp iputil.VpnIp) bool {
 	n.inLock.RLock()
-	if _, ok := n.in[vpnIP]; ok {
+	if _, ok := n.in[vpnIp]; ok {
 		n.inLock.RUnlock()
 		return true
 	}
@@ -97,7 +99,7 @@ func (n *connectionManager) CheckIn(vpnIP uint32) bool {
 	return false
 }
 
-func (n *connectionManager) ClearIP(ip uint32) {
+func (n *connectionManager) ClearIP(ip iputil.VpnIp) {
 	n.inLock.Lock()
 	n.outLock.Lock()
 	delete(n.in, ip)
@@ -106,13 +108,13 @@ func (n *connectionManager) ClearIP(ip uint32) {
 	n.outLock.Unlock()
 }
 
-func (n *connectionManager) ClearPendingDeletion(ip uint32) {
+func (n *connectionManager) ClearPendingDeletion(ip iputil.VpnIp) {
 	n.pendingDeletionLock.Lock()
 	delete(n.pendingDeletion, ip)
 	n.pendingDeletionLock.Unlock()
 }
 
-func (n *connectionManager) AddPendingDeletion(ip uint32) {
+func (n *connectionManager) AddPendingDeletion(ip iputil.VpnIp) {
 	n.pendingDeletionLock.Lock()
 	if _, ok := n.pendingDeletion[ip]; ok {
 		n.pendingDeletion[ip] += 1
@@ -123,7 +125,7 @@ func (n *connectionManager) AddPendingDeletion(ip uint32) {
 	n.pendingDeletionLock.Unlock()
 }
 
-func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
+func (n *connectionManager) checkPendingDeletion(ip iputil.VpnIp) bool {
 	n.pendingDeletionLock.RLock()
 	if _, ok := n.pendingDeletion[ip]; ok {
 
@@ -134,8 +136,8 @@ func (n *connectionManager) checkPendingDeletion(ip uint32) bool {
 	return false
 }
 
-func (n *connectionManager) AddTrafficWatch(vpnIP uint32, seconds int) {
-	n.TrafficTimer.Add(vpnIP, time.Second*time.Duration(seconds))
+func (n *connectionManager) AddTrafficWatch(vpnIp iputil.VpnIp, seconds int) {
+	n.TrafficTimer.Add(vpnIp, time.Second*time.Duration(seconds))
 }
 
 func (n *connectionManager) Start(ctx context.Context) {
@@ -169,23 +171,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			break
 		}
 
-		vpnIP := ep.(uint32)
+		vpnIp := ep.(iputil.VpnIp)
 
 		// Check for traffic coming back in from this host.
-		traf := n.CheckIn(vpnIP)
+		traf := n.CheckIn(vpnIp)
 
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", vpnIp)
 
 			if !n.intf.disconnectInvalid {
-				n.ClearIP(vpnIP)
-				n.ClearPendingDeletion(vpnIP)
+				n.ClearIP(vpnIp)
+				n.ClearPendingDeletion(vpnIp)
 				continue
 			}
 		}
 
-		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+		if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
 			continue
 		}
 
@@ -193,12 +195,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 		// expired, just ignore.
 		if traf {
 			if n.l.Level >= logrus.DebugLevel {
-				n.l.WithField("vpnIp", IntIp(vpnIP)).
+				n.l.WithField("vpnIp", vpnIp).
 					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 					Debug("Tunnel status")
 			}
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			continue
 		}
 
@@ -208,12 +210,12 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 
 		if hostinfo != nil && hostinfo.ConnectionState != nil {
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
+			n.intf.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, p, nb, out)
 
 		} else {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
+			hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", vpnIp)
 		}
-		n.AddPendingDeletion(vpnIP)
+		n.AddPendingDeletion(vpnIp)
 	}
 
 }
@@ -226,38 +228,38 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 			break
 		}
 
-		vpnIP := ep.(uint32)
+		vpnIp := ep.(iputil.VpnIp)
 
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+			n.l.Debugf("Not found in hostmap: %s", vpnIp)
 
 			if !n.intf.disconnectInvalid {
-				n.ClearIP(vpnIP)
-				n.ClearPendingDeletion(vpnIP)
+				n.ClearIP(vpnIp)
+				n.ClearPendingDeletion(vpnIp)
 				continue
 			}
 		}
 
-		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+		if n.handleInvalidCertificate(now, vpnIp, hostinfo) {
 			continue
 		}
 
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// expired, just ignore.
-		traf := n.CheckIn(vpnIP)
+		traf := n.CheckIn(vpnIp)
 		if traf {
-			n.l.WithField("vpnIp", IntIp(vpnIP)).
+			n.l.WithField("vpnIp", vpnIp).
 				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
 				Debug("Tunnel status")
 
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			continue
 		}
 
 		// If it comes around on deletion wheel and hasn't resolved itself, delete
-		if n.checkPendingDeletion(vpnIP) {
+		if n.checkPendingDeletion(vpnIp) {
 			cn := ""
 			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
 				cn = hostinfo.ConnectionState.peerCert.Details.Name
@@ -267,22 +269,22 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 				WithField("certName", cn).
 				Info("Tunnel status")
 
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 			// TODO: This is only here to let tests work. Should do proper mocking
 			if n.intf.lightHouse != nil {
-				n.intf.lightHouse.DeleteVpnIP(vpnIP)
+				n.intf.lightHouse.DeleteVpnIp(vpnIp)
 			}
 			n.hostMap.DeleteHostInfo(hostinfo)
 		} else {
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
+			n.ClearIP(vpnIp)
+			n.ClearPendingDeletion(vpnIp)
 		}
 	}
 }
 
 // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
-func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
+func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIp iputil.VpnIp, hostinfo *HostInfo) bool {
 	if !n.intf.disconnectInvalid {
 		return false
 	}
@@ -298,7 +300,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
 	}
 
 	fingerprint, _ := remoteCert.Sha256Sum()
-	n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
+	n.l.WithField("vpnIp", vpnIp).WithError(err).
 		WithField("certName", remoteCert.Details.Name).
 		WithField("fingerprint", fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
@@ -307,7 +309,7 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32
 	n.intf.sendCloseTunnel(hostinfo)
 	n.intf.closeTunnel(hostinfo, false)
 
-	n.ClearIP(vpnIP)
-	n.ClearPendingDeletion(vpnIP)
+	n.ClearIP(vpnIp)
+	n.ClearPendingDeletion(vpnIp)
 	return true
 }

+ 39 - 36
connection_manager_test.go

@@ -10,17 +10,20 @@ import (
 
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
-var vpnIP uint32
+var vpnIp iputil.VpnIp
 
 func Test_NewConnectionManagerTest(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
+	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
@@ -32,15 +35,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &Tun{},
-		outside:          &udpConn{},
+		outside:          &udp.Conn{},
 		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 	}
 	now := time.Now()
@@ -54,16 +57,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		H:         &noise.HandshakeState{},
 	}
 
-	// We saw traffic out to vpnIP
-	nc.Out(vpnIP)
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	// We saw traffic out to vpnIp
+	nc.Out(vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead 5s. Nothing should happen
 	next_tick := now.Add(5 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
@@ -73,20 +76,20 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	// This host should now be up for deletion
-	assert.Contains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.Contains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead some more
 	next_tick = now.Add(45 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	// The host should be evicted
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.NotContains(t, nc.hostMap.Hosts, vpnIP)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, vpnIp)
 
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@@ -101,15 +104,15 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 		hostMap:          hostMap,
 		inside:           &Tun{},
-		outside:          &udpConn{},
+		outside:          &udp.Conn{},
 		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
-		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 	}
 	now := time.Now()
@@ -123,16 +126,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		H:         &noise.HandshakeState{},
 	}
 
-	// We saw traffic out to vpnIP
-	nc.Out(vpnIP)
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	// We saw traffic out to vpnIp
+	nc.Out(vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// Move ahead 5s. Nothing should happen
 	next_tick := now.Add(5 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
@@ -142,17 +145,17 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	// This host should now be up for deletion
-	assert.Contains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.Contains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 	// We heard back this time
-	nc.In(vpnIP)
+	nc.In(vpnIp)
 	// Move ahead some more
 	next_tick = now.Add(45 * time.Second)
 	nc.HandleMonitorTick(next_tick, p, nb, out)
 	nc.HandleDeletionTick(next_tick)
 	// The host should be evicted
-	assert.NotContains(t, nc.pendingDeletion, vpnIP)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
+	assert.NotContains(t, nc.pendingDeletion, vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
 
 }
 
@@ -161,7 +164,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 // Disconnect only if disconnectInvalid: true is set.
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ipNet := net.IPNet{
 		IP:   net.IPv4(172, 1, 1, 2),
 		Mask: net.IPMask{255, 255, 255, 0},
@@ -210,15 +213,15 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		rawCertificateNoKey: []byte{},
 	}
 
-	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{}, 1000, 0, &udp.Conn{}, false, 1, false)
 	ifce := &Interface{
 		hostMap:           hostMap,
 		inside:            &Tun{},
-		outside:           &udpConn{},
+		outside:           &udp.Conn{},
 		certState:         cs,
 		firewall:          &Firewall{},
 		lightHouse:        lh,
-		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                 l,
 		disconnectInvalid: true,
 		caPool:            ncp,
@@ -229,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	defer cancel()
 	nc := newConnectionManager(ctx, l, ifce, 5, 10)
 	ifce.connectionManager = nc
-	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		peerCert:  &peerCert,
@@ -240,13 +243,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Check if to disconnect with invalid certificate.
 	// Should be alive.
 	nextTick := now.Add(45 * time.Second)
-	destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	destroyed := nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
 	assert.False(t, destroyed)
 
 	// Move ahead 61s.
 	// Check if to disconnect with invalid certificate.
 	// Should be disconnected.
 	nextTick = now.Add(61 * time.Second)
-	destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	destroyed = nc.handleInvalidCertificate(nextTick, vpnIp, hostinfo)
 	assert.True(t, destroyed)
 }

+ 18 - 15
control.go

@@ -10,6 +10,9 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -25,14 +28,14 @@ type Control struct {
 }
 
 type ControlHostInfo struct {
-	VpnIP          net.IP                  `json:"vpnIp"`
+	VpnIp          net.IP                  `json:"vpnIp"`
 	LocalIndex     uint32                  `json:"localIndex"`
 	RemoteIndex    uint32                  `json:"remoteIndex"`
-	RemoteAddrs    []*udpAddr              `json:"remoteAddrs"`
+	RemoteAddrs    []*udp.Addr             `json:"remoteAddrs"`
 	CachedPackets  int                     `json:"cachedPackets"`
 	Cert           *cert.NebulaCertificate `json:"cert"`
 	MessageCounter uint64                  `json:"messageCounter"`
-	CurrentRemote  *udpAddr                `json:"currentRemote"`
+	CurrentRemote  *udp.Addr               `json:"currentRemote"`
 }
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -95,8 +98,8 @@ func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
 	}
 }
 
-// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInfo {
+// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
 	var hm *HostMap
 	if pending {
 		hm = c.f.handshakeManager.pendingHostMap
@@ -104,7 +107,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 		hm = c.f.hostMap
 	}
 
-	h, err := hm.QueryVpnIP(vpnIP)
+	h, err := hm.QueryVpnIp(vpnIp)
 	if err != nil {
 		return nil
 	}
@@ -114,8 +117,8 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 }
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInfo {
-	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return nil
 	}
@@ -126,15 +129,15 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
 }
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIP uint32, localOnly bool) bool {
-	hostInfo, err := c.f.hostMap.QueryVpnIP(vpnIP)
+func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return false
 	}
 
 	if !localOnly {
 		c.f.send(
-			closeTunnel,
+			header.CloseTunnel,
 			0,
 			hostInfo.ConnectionState,
 			hostInfo,
@@ -156,16 +159,16 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	c.f.hostMap.Lock()
 	for _, h := range c.f.hostMap.Hosts {
 		if excludeLighthouses {
-			if _, ok := c.f.lightHouse.lighthouses[h.hostId]; ok {
+			if _, ok := c.f.lightHouse.lighthouses[h.vpnIp]; ok {
 				continue
 			}
 		}
 
 		if h.ConnectionState.ready {
-			c.f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+			c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 			c.f.closeTunnel(h, true)
 
-			c.l.WithField("vpnIp", IntIp(h.hostId)).WithField("udpAddr", h.remote).
+			c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
 				Debug("Sending close tunnel message")
 			closed++
 		}
@@ -176,7 +179,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	chi := ControlHostInfo{
-		VpnIP:         int2ip(h.hostId),
+		VpnIp:         h.vpnIp.ToIP(),
 		LocalIndex:    h.localIndexId,
 		RemoteIndex:   h.remoteIndexId,
 		RemoteAddrs:   h.remotes.CopyAddrs(preferredRanges),

+ 16 - 14
control_test.go

@@ -8,17 +8,19 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
-func TestControl_GetHostInfoByVpnIP(t *testing.T) {
-	l := NewTestLogger()
+func TestControl_GetHostInfoByVpnIp(t *testing.T) {
+	l := util.NewTestLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
-	remote1 := NewUDPAddr(int2ip(100), 4444)
-	remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
+	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
+	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
 		IP:   net.IPv4(1, 2, 3, 4),
 		Mask: net.IPMask{255, 255, 255, 0},
@@ -48,7 +50,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 	remotes := NewRemoteList()
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
-	hm.Add(ip2int(ipNet.IP), &HostInfo{
+	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -56,10 +58,10 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		hostId:        ip2int(ipNet.IP),
+		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
 	})
 
-	hm.Add(ip2int(ipNet2.IP), &HostInfo{
+	hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -67,7 +69,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		hostId:        ip2int(ipNet2.IP),
+		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
 	})
 
 	c := Control{
@@ -77,26 +79,26 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		l: logrus.New(),
 	}
 
-	thi := c.GetHostInfoByVpnIP(ip2int(ipNet.IP), false)
+	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
 
 	expectedInfo := ControlHostInfo{
-		VpnIP:          net.IPv4(1, 2, 3, 4).To4(),
+		VpnIp:          net.IPv4(1, 2, 3, 4).To4(),
 		LocalIndex:     201,
 		RemoteIndex:    200,
-		RemoteAddrs:    []*udpAddr{remote2, remote1},
+		RemoteAddrs:    []*udp.Addr{remote2, remote1},
 		CachedPackets:  0,
 		Cert:           crt.Copy(),
 		MessageCounter: 0,
-		CurrentRemote:  NewUDPAddr(int2ip(100), 4444),
+		CurrentRemote:  udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIP", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
+	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote"}, thi)
 	util.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIP(ip2int(ipNet2.IP), false)
+		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
 	})
 }
 

+ 15 - 12
control_tester.go

@@ -8,12 +8,15 @@ import (
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 // WaitForTypeByIndex will pipe all messages from this control device into the pipeTo control device
 // returning after a message matching the criteria has been piped
-func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
-	h := &Header{}
+func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
+	h := &header.H{}
 	for {
 		p := c.f.outside.Get(true)
 		if err := h.Parse(p.Data); err != nil {
@@ -28,8 +31,8 @@ func (c *Control) WaitForType(msgType NebulaMessageType, subType NebulaMessageSu
 
 // WaitForTypeByIndex is similar to WaitForType except it adds an index check
 // Useful if you have many nodes communicating and want to wait to find a specific nodes packet
-func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType, subType NebulaMessageSubType, pipeTo *Control) {
-	h := &Header{}
+func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) {
+	h := &header.H{}
 	for {
 		p := c.f.outside.Get(true)
 		if err := h.Parse(p.Data); err != nil {
@@ -46,12 +49,12 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	iVpnIp := ip2int(vpnIp)
+	iVpnIp := iputil.Ip2VpnIp(vpnIp)
 	if v4 := toAddr.IP.To4(); v4 != nil {
 		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
 	} else {
@@ -65,12 +68,12 @@ func (c *Control) GetFromTun(block bool) []byte {
 }
 
 // GetFromUDP will pull a udp packet off the udp side of nebula
-func (c *Control) GetFromUDP(block bool) *UdpPacket {
+func (c *Control) GetFromUDP(block bool) *udp.Packet {
 	return c.f.outside.Get(block)
 }
 
-func (c *Control) GetUDPTxChan() <-chan *UdpPacket {
-	return c.f.outside.txPackets
+func (c *Control) GetUDPTxChan() <-chan *udp.Packet {
+	return c.f.outside.TxPackets
 }
 
 func (c *Control) GetTunTxChan() <-chan []byte {
@@ -78,7 +81,7 @@ func (c *Control) GetTunTxChan() <-chan []byte {
 }
 
 // InjectUDPPacket will inject a packet into the udp side of nebula
-func (c *Control) InjectUDPPacket(p *UdpPacket) {
+func (c *Control) InjectUDPPacket(p *udp.Packet) {
 	c.f.outside.Send(p)
 }
 
@@ -115,11 +118,11 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 }
 
 func (c *Control) GetUDPAddr() string {
-	return c.f.outside.addr.String()
+	return c.f.outside.Addr.String()
 }
 
 func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
+	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
 	if !ok {
 		return false
 	}

+ 9 - 7
dns_server.go

@@ -8,6 +8,8 @@ import (
 
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 )
 
 // This whole thing should be rewritten to use context
@@ -44,8 +46,8 @@ func (d *dnsRecords) QueryCert(data string) string {
 	if ip == nil {
 		return ""
 	}
-	iip := ip2int(ip)
-	hostinfo, err := d.hostMap.QueryVpnIP(iip)
+	iip := iputil.Ip2VpnIp(ip)
+	hostinfo, err := d.hostMap.QueryVpnIp(iip)
 	if err != nil {
 		return ""
 	}
@@ -109,7 +111,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
+func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
 	dnsR = newDnsRecords(hostMap)
 
 	// attach request handler func
@@ -117,7 +119,7 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
 		handleDnsRequest(l, w, r)
 	})
 
-	c.RegisterReloadCallback(func(c *Config) {
+	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)
 	})
 
@@ -126,11 +128,11 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
 	}
 }
 
-func getDnsServerAddr(c *Config) string {
+func getDnsServerAddr(c *config.C) string {
 	return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
 }
 
-func startDns(l *logrus.Logger, c *Config) {
+func startDns(l *logrus.Logger, c *config.C) {
 	dnsAddr = getDnsServerAddr(c)
 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 	l.WithField("dnsListener", dnsAddr).Infof("Starting DNS responder")
@@ -141,7 +143,7 @@ func startDns(l *logrus.Logger, c *Config) {
 	}
 }
 
-func reloadDns(l *logrus.Logger, c *Config) {
+func reloadDns(l *logrus.Logger, c *config.C) {
 	if dnsAddr == getDnsServerAddr(c) {
 		l.Debug("No DNS server config change detected")
 		return

+ 8 - 5
e2e/handshakes_test.go

@@ -10,6 +10,9 @@ import (
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -37,7 +40,7 @@ func TestGoodHandshake(t *testing.T) {
 	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
 	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
 	badPacket := stage1Packet.Copy()
-	badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
+	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
 	myControl.InjectUDPPacket(badPacket)
 
 	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
@@ -87,8 +90,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
-		h := &nebula.Header{}
+	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
+		h := &header.H{}
 		err := h.Parse(p.Data)
 		if err != nil {
 			panic(err)
@@ -115,8 +118,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 	r.FlushAll()
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 	//TODO: assert hostmaps for everyone

+ 11 - 23
e2e/helpers_test.go

@@ -5,7 +5,6 @@ package e2e
 
 import (
 	"crypto/rand"
-	"encoding/binary"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -19,7 +18,9 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/ed25519"
@@ -82,10 +83,10 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	config := nebula.NewConfig(l)
-	config.LoadString(string(cb))
+	c := config.NewC(l)
+	c.LoadString(string(cb))
 
-	control, err := nebula.Main(config, false, "e2e-test", l, nil)
+	control, err := nebula.Main(c, false, "e2e-test", l, nil)
 
 	if err != nil {
 		panic(err)
@@ -200,19 +201,6 @@ func x25519Keypair() ([]byte, []byte) {
 	return pubkey, privkey
 }
 
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
-}
-
 type doneCb func()
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {
@@ -245,15 +233,15 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
 
 func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIP(ip2int(vpnIpB), false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIP in controlA")
+	hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
 
-	hAinB := controlB.GetHostInfoByVpnIP(ip2int(vpnIpA), false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIP in controlB")
+	hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
 
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIP, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIP, "Host A VpnIp is wrong in control B")
+	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
+	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
 
 	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
 	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")

+ 12 - 10
e2e/router/router.go

@@ -11,6 +11,8 @@ import (
 	"sync"
 
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )
 
 type R struct {
@@ -41,7 +43,7 @@ const (
 	RouteAndExit ExitType = 2
 )
 
-type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
+type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType
 
 func NewR(controls ...*nebula.Control) *R {
 	r := &R{
@@ -79,7 +81,7 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
 // OnceFrom will route a single packet from sender then return
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) OnceFrom(sender *nebula.Control) {
-	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
+	r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType {
 		return RouteAndExit
 	})
 }
@@ -119,7 +121,7 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 //   - routeAndExit: this call will return immediately after routing the last packet from sender
 //   - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
 func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
-	h := &nebula.Header{}
+	h := &header.H{}
 	for {
 		p := sender.GetFromUDP(true)
 		r.Lock()
@@ -159,9 +161,9 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 
 // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
-	h := &nebula.Header{}
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
+func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) {
+	h := &header.H{}
+	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 		}
@@ -181,7 +183,7 @@ func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr
 		finish = RouteAndExit
 	}
 
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
+	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
 		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
 			return finish
 		}
@@ -215,7 +217,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
 		x, rx, _ := reflect.Select(sc)
 		r.Lock()
 
-		p := rx.Interface().(*nebula.UdpPacket)
+		p := rx.Interface().(*udp.Packet)
 
 		outAddr := cm[x].GetUDPAddr()
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@@ -277,7 +279,7 @@ func (r *R) FlushAll() {
 		}
 		r.Lock()
 
-		p := rx.Interface().(*nebula.UdpPacket)
+		p := rx.Interface().(*udp.Packet)
 
 		outAddr := cm[x].GetUDPAddr()
 		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
@@ -292,7 +294,7 @@ func (r *R) FlushAll() {
 
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
+func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
 	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
 		p.FromIp = newAddr.IP
 		p.FromPort = uint16(newAddr.Port)

+ 45 - 147
firewall.go

@@ -4,7 +4,6 @@ import (
 	"crypto/sha256"
 	"encoding/binary"
 	"encoding/hex"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net"
@@ -12,22 +11,14 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-)
-
-const (
-	fwProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	fwProtoTCP  = 6
-	fwProtoUDP  = 17
-	fwProtoICMP = 1
-
-	fwPortAny      = 0  // Special value for matching `port: any`
-	fwPortFragment = -1 // Special value for matching `port: fragment`
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
 )
 
 const tcpACK = 0x10
@@ -63,7 +54,7 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *CIDRTree
+	localIps *cidr.Tree4
 
 	rules        string
 	rulesVersion uint16
@@ -85,7 +76,7 @@ type firewallMetrics struct {
 type FirewallConntrack struct {
 	sync.Mutex
 
-	Conns      map[FirewallPacket]*conn
+	Conns      map[firewall.Packet]*conn
 	TimerWheel *TimerWheel
 }
 
@@ -116,55 +107,13 @@ type FirewallRule struct {
 	Any    bool
 	Hosts  map[string]struct{}
 	Groups [][]string
-	CIDR   *CIDRTree
+	CIDR   *cidr.Tree4
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Plus we can use `-1` for fragment rules
 type firewallPort map[int32]*FirewallCA
 
-type FirewallPacket struct {
-	LocalIP    uint32
-	RemoteIP   uint32
-	LocalPort  uint16
-	RemotePort uint16
-	Protocol   uint8
-	Fragment   bool
-}
-
-func (fp *FirewallPacket) Copy() *FirewallPacket {
-	return &FirewallPacket{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
-		LocalPort:  fp.LocalPort,
-		RemotePort: fp.RemotePort,
-		Protocol:   fp.Protocol,
-		Fragment:   fp.Fragment,
-	}
-}
-
-func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
-	var proto string
-	switch fp.Protocol {
-	case fwProtoTCP:
-		proto = "tcp"
-	case fwProtoICMP:
-		proto = "icmp"
-	case fwProtoUDP:
-		proto = "udp"
-	default:
-		proto = fmt.Sprintf("unknown %v", fp.Protocol)
-	}
-	return json.Marshal(m{
-		"LocalIP":    int2ip(fp.LocalIP).String(),
-		"RemoteIP":   int2ip(fp.RemoteIP).String(),
-		"LocalPort":  fp.LocalPort,
-		"RemotePort": fp.RemotePort,
-		"Protocol":   proto,
-		"Fragment":   fp.Fragment,
-	})
-}
-
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
 	//TODO: error on 0 duration
@@ -184,7 +133,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 	}
 
-	localIps := NewCIDRTree()
+	localIps := cidr.NewTree4()
 	for _, ip := range c.Details.Ips {
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
@@ -195,7 +144,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
-			Conns:      make(map[FirewallPacket]*conn),
+			Conns:      make(map[firewall.Packet]*conn),
 			TimerWheel: NewTimerWheel(min, max),
 		},
 		InRules:        newFirewallTable(),
@@ -220,7 +169,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 	}
 }
 
-func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
 	fw := NewFirewall(
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
@@ -278,13 +227,13 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	}
 
 	switch proto {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		fp = ft.TCP
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		fp = ft.UDP
-	case fwProtoICMP:
+	case firewall.ProtoICMP:
 		fp = ft.ICMP
-	case fwProtoAny:
+	case firewall.ProtoAny:
 		fp = ft.AnyProto
 	default:
 		return fmt.Errorf("unknown protocol %v", proto)
@@ -299,7 +248,7 @@ func (f *Firewall) GetRuleHash() string {
 	return hex.EncodeToString(sum[:])
 }
 
-func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
+func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
 	var table string
 	if inbound {
 		table = "firewall.inbound"
@@ -307,7 +256,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
 		table = "firewall.outbound"
 	}
 
-	r := config.Get(table)
+	r := c.Get(table)
 	if r == nil {
 		return nil
 	}
@@ -362,13 +311,13 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config,
 		var proto uint8
 		switch r.Proto {
 		case "any":
-			proto = fwProtoAny
+			proto = firewall.ProtoAny
 		case "tcp":
-			proto = fwProtoTCP
+			proto = firewall.ProtoTCP
 		case "udp":
-			proto = fwProtoUDP
+			proto = firewall.ProtoUDP
 		case "icmp":
-			proto = fwProtoICMP
+			proto = firewall.ProtoICMP
 		default:
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
@@ -396,7 +345,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
+func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(packet, fp, incoming, h, caPool, localCache) {
 		return nil
@@ -410,7 +359,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *Host
 		}
 	} else {
 		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.hostId {
+		if fp.RemoteIP != h.vpnIp {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -462,7 +411,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
+func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 			return true
@@ -520,14 +469,14 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 	}
 
 	switch fp.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		c.Expires = time.Now().Add(f.TCPTimeout)
 		if incoming {
 			f.checkTCPRTT(c, packet)
 		} else {
 			setTCPRTTTracking(c, packet)
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		c.Expires = time.Now().Add(f.UDPTimeout)
 	default:
 		c.Expires = time.Now().Add(f.DefaultTimeout)
@@ -542,17 +491,17 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 	return true
 }
 
-func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
+func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
 	var timeout time.Duration
 	c := &conn{}
 
 	switch fp.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		timeout = f.TCPTimeout
 		if !incoming {
 			setTCPRTTTracking(c, packet)
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		timeout = f.UDPTimeout
 	default:
 		timeout = f.DefaultTimeout
@@ -575,7 +524,7 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
-func (f *Firewall) evict(p FirewallPacket) {
+func (f *Firewall) evict(p firewall.Packet) {
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
@@ -596,21 +545,21 @@ func (f *Firewall) evict(p FirewallPacket) {
 	delete(conntrack.Conns, p)
 }
 
-func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 		return true
 	}
 
 	switch p.Protocol {
-	case fwProtoTCP:
+	case firewall.ProtoTCP:
 		if ft.TCP.match(p, incoming, c, caPool) {
 			return true
 		}
-	case fwProtoUDP:
+	case firewall.ProtoUDP:
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 		}
-	case fwProtoICMP:
+	case firewall.ProtoICMP:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 		}
@@ -640,7 +589,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 	return nil
 }
 
-func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	// We don't have any allowed ports, bail
 	if fp == nil {
 		return false
@@ -649,7 +598,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 	var port int32
 
 	if p.Fragment {
-		port = fwPortFragment
+		port = firewall.PortFragment
 	} else if incoming {
 		port = int32(p.LocalPort)
 	} else {
@@ -660,7 +609,7 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 		return true
 	}
 
-	return fp[fwPortAny].match(p, c, caPool)
+	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
 func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
@@ -668,7 +617,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 		return &FirewallRule{
 			Hosts:  make(map[string]struct{}),
 			Groups: make([][]string, 0),
-			CIDR:   NewCIDRTree(),
+			CIDR:   cidr.NewTree4(),
 		}
 	}
 
@@ -703,7 +652,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 	return nil
 }
 
-func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 	if fc == nil {
 		return false
 	}
@@ -736,7 +685,7 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = NewCIDRTree()
+		fr.CIDR = cidr.NewTree4()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -776,7 +725,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 	return false
 }
 
-func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
+func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
 	if fr == nil {
 		return false
 	}
@@ -885,12 +834,12 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 
 func parsePort(s string) (startPort, endPort int32, err error) {
 	if s == "any" {
-		startPort = fwPortAny
-		endPort = fwPortAny
+		startPort = firewall.PortAny
+		endPort = firewall.PortAny
 
 	} else if s == "fragment" {
-		startPort = fwPortFragment
-		endPort = fwPortFragment
+		startPort = firewall.PortFragment
+		endPort = firewall.PortFragment
 
 	} else if strings.Contains(s, `-`) {
 		sPorts := strings.SplitN(s, `-`, 2)
@@ -914,8 +863,8 @@ func parsePort(s string) (startPort, endPort int32, err error) {
 		startPort = int32(rStartPort)
 		endPort = int32(rEndPort)
 
-		if startPort == fwPortAny {
-			endPort = fwPortAny
+		if startPort == firewall.PortAny {
+			endPort = firewall.PortAny
 		}
 
 	} else {
@@ -968,54 +917,3 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
 	c.Seq = 0
 	return true
 }
-
-// ConntrackCache is used as a local routine cache to know if a given flow
-// has been seen in the conntrack table.
-type ConntrackCache map[FirewallPacket]struct{}
-
-type ConntrackCacheTicker struct {
-	cacheV    uint64
-	cacheTick uint64
-
-	cache ConntrackCache
-}
-
-func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
-	if d == 0 {
-		return nil
-	}
-
-	c := &ConntrackCacheTicker{
-		cache: ConntrackCache{},
-	}
-
-	go c.tick(d)
-
-	return c
-}
-
-func (c *ConntrackCacheTicker) tick(d time.Duration) {
-	for {
-		time.Sleep(d)
-		atomic.AddUint64(&c.cacheTick, 1)
-	}
-}
-
-// Get checks if the cache ticker has moved to the next version before returning
-// the map. If it has moved, we reset the map.
-func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
-	if c == nil {
-		return nil
-	}
-	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
-		c.cacheV = tick
-		if ll := len(c.cache); ll > 0 {
-			if l.Level == logrus.DebugLevel {
-				l.WithField("len", ll).Debug("resetting conntrack cache")
-			}
-			c.cache = make(ConntrackCache, ll)
-		}
-	}
-
-	return c.cache
-}

+ 59 - 0
firewall/cache.go

@@ -0,0 +1,59 @@
+package firewall
+
+import (
+	"sync/atomic"
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+// ConntrackCache is used as a local routine cache to know if a given flow
+// has been seen in the conntrack table.
+type ConntrackCache map[Packet]struct{}
+
+type ConntrackCacheTicker struct {
+	cacheV    uint64
+	cacheTick uint64
+
+	cache ConntrackCache
+}
+
+func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
+	if d == 0 {
+		return nil
+	}
+
+	c := &ConntrackCacheTicker{
+		cache: ConntrackCache{},
+	}
+
+	go c.tick(d)
+
+	return c
+}
+
+func (c *ConntrackCacheTicker) tick(d time.Duration) {
+	for {
+		time.Sleep(d)
+		atomic.AddUint64(&c.cacheTick, 1)
+	}
+}
+
+// Get checks if the cache ticker has moved to the next version before returning
+// the map. If it has moved, we reset the map.
+func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
+	if c == nil {
+		return nil
+	}
+	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
+		c.cacheV = tick
+		if ll := len(c.cache); ll > 0 {
+			if l.Level == logrus.DebugLevel {
+				l.WithField("len", ll).Debug("resetting conntrack cache")
+			}
+			c.cache = make(ConntrackCache, ll)
+		}
+	}
+
+	return c.cache
+}

+ 62 - 0
firewall/packet.go

@@ -0,0 +1,62 @@
+package firewall
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/slackhq/nebula/iputil"
+)
+
+type m map[string]interface{}
+
+const (
+	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP  = 6
+	ProtoUDP  = 17
+	ProtoICMP = 1
+
+	PortAny      = 0  // Special value for matching `port: any`
+	PortFragment = -1 // Special value for matching `port: fragment`
+)
+
+type Packet struct {
+	LocalIP    iputil.VpnIp
+	RemoteIP   iputil.VpnIp
+	LocalPort  uint16
+	RemotePort uint16
+	Protocol   uint8
+	Fragment   bool
+}
+
+func (fp *Packet) Copy() *Packet {
+	return &Packet{
+		LocalIP:    fp.LocalIP,
+		RemoteIP:   fp.RemoteIP,
+		LocalPort:  fp.LocalPort,
+		RemotePort: fp.RemotePort,
+		Protocol:   fp.Protocol,
+		Fragment:   fp.Fragment,
+	}
+}
+
+func (fp Packet) MarshalJSON() ([]byte, error) {
+	var proto string
+	switch fp.Protocol {
+	case ProtoTCP:
+		proto = "tcp"
+	case ProtoICMP:
+		proto = "icmp"
+	case ProtoUDP:
+		proto = "udp"
+	default:
+		proto = fmt.Sprintf("unknown %v", fp.Protocol)
+	}
+	return json.Marshal(m{
+		"LocalIP":    fp.LocalIP.String(),
+		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalPort":  fp.LocalPort,
+		"RemotePort": fp.RemotePort,
+		"Protocol":   proto,
+		"Fragment":   fp.Fragment,
+	})
+}

+ 101 - 109
firewall_test.go

@@ -11,11 +11,15 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestNewFirewall(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
@@ -54,7 +58,7 @@ func TestNewFirewall(t *testing.T) {
 }
 
 func TestFirewall_AddRule(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -65,92 +69,80 @@ func TestFirewall_AddRule(t *testing.T) {
 
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
-	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
 	assert.False(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
 	assert.False(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
-	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	// Set any and clear fields
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
-	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
-	assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 	}
 
@@ -172,12 +164,12 @@ func TestFirewall_Drop(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -190,34 +182,34 @@ func TestFirewall_Drop(t *testing.T) {
 
 	// test remote mismatch
 	oldRemote := p.RemoteIP
-	p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
+	p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
 	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	p.RemoteIP = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
 	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 
@@ -237,14 +229,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
 		}
 	})
 
 	b.Run("fail on port", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
 		}
 	})
 
@@ -258,7 +250,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 	})
 
@@ -270,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 	})
 
@@ -282,12 +274,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 	})
 
 	b.Run("pass on ip", func(b *testing.B) {
-		ip := ip2int(net.IPv4(172, 1, 1, 1))
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
@@ -295,14 +287,14 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
 		}
 	})
 
 	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
-		ip := ip2int(net.IPv4(172, 1, 1, 1))
+		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
@@ -310,22 +302,22 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
 		}
 	})
 }
 
 func TestFirewall_Drop2(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 	}
 
@@ -345,7 +337,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h.CreateRemoteCIDR(&c)
 
@@ -364,7 +356,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
@@ -375,16 +367,16 @@ func TestFirewall_Drop2(t *testing.T) {
 }
 
 func TestFirewall_Drop3(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		1,
 		1,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 	}
 
@@ -411,7 +403,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h1.CreateRemoteCIDR(&c1)
 
@@ -426,7 +418,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h2.CreateRemoteCIDR(&c2)
 
@@ -441,13 +433,13 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h3.CreateRemoteCIDR(&c3)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
@@ -461,16 +453,16 @@ func TestFirewall_Drop3(t *testing.T) {
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
-	p := FirewallPacket{
-		ip2int(net.IPv4(1, 2, 3, 4)),
-		ip2int(net.IPv4(1, 2, 3, 4)),
+	p := firewall.Packet{
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
 		10,
 		90,
-		fwProtoUDP,
+		firewall.ProtoUDP,
 		false,
 	}
 
@@ -492,12 +484,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		hostId: ip2int(ipNet.IP),
+		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
 	}
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -510,7 +502,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -519,7 +511,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -643,28 +635,28 @@ func Test_parsePort(t *testing.T) {
 }
 
 func TestNewFirewallFromConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// Test a bad rule definition
 	c := &cert.NebulaCertificate{}
-	conf := NewConfig(l)
+	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	_, err := NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
@@ -674,91 +666,91 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
 
 	// Test both group and groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 
 func TestAddFirewallRulesFromConfig(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	// Test adding tcp rule
-	conf := NewConfig(l)
+	conf := config.NewC(l)
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding udp rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding icmp rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding any rule
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
 
 	// Test adding rule with ca_sha
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
 
 	// Test single group
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 	// Test single groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
 
 	// Test multiple AND groups
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
 
 	// Test Add error
-	conf = NewConfig(l)
+	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf.nextCallReturn = errors.New("test error")
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
@@ -857,7 +849,7 @@ func TestTCPRTTTracking(t *testing.T) {
 }
 
 func TestFirewall_convertRule(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
@@ -929,6 +921,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 
 func resetConntrack(fw *Firewall) {
 	fw.Conntrack.Lock()
-	fw.Conntrack.Conns = map[FirewallPacket]*conn{}
+	fw.Conntrack.Conns = map[firewall.Packet]*conn{}
 	fw.Conntrack.Unlock()
 }

+ 5 - 5
handshake.go

@@ -1,11 +1,11 @@
 package nebula
 
-const (
-	handshakeIXPSK0 = 0
-	handshakeXXPSK0 = 1
+import (
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )
 
-func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
+func HandleIncomingHandshake(f *Interface, addr *udp.Addr, packet []byte, h *header.H, hostinfo *HostInfo) {
 	// First remote allow list check before we know the vpnIp
 	if !f.lightHouse.remoteAllowList.AllowUnknownVpnIp(addr.IP) {
 		f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@@ -13,7 +13,7 @@ func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Head
 	}
 
 	switch h.Subtype {
-	case handshakeIXPSK0:
+	case header.HandshakeIXPSK0:
 		switch h.MessageCounter {
 		case 1:
 			ixHandshakeStage1(f, addr, packet, h)

+ 59 - 56
handshake_ix.go

@@ -6,13 +6,16 @@ import (
 
 	"github.com/flynn/noise"
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 // NOISE IX Handshakes
 
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
+func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	// This queries the lighthouse if we don't know a remote for the host
 	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
 	// more quickly, effect is a quicker handshake.
@@ -22,7 +25,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return
 	}
@@ -43,17 +46,17 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hsBytes, err = proto.Marshal(hs)
 
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return
 	}
 
-	header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, 0, 1)
+	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
 	atomic.AddUint64(&ci.atomicMessageCounter, 1)
 
-	msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
+	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	}
@@ -67,12 +70,12 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hostinfo.handshakeStart = time.Now()
 }
 
-func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
+func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
 	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
-	msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
+	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
@@ -97,13 +100,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			Info("Invalid certificate from host")
 		return
 	}
-	vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
+	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 
-	if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	if vpnIp == f.myVpnIp {
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -111,14 +114,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 	}
 
-	if !f.lightHouse.remoteAllowList.Allow(vpnIP, addr.IP) {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !f.lightHouse.remoteAllowList.Allow(vpnIp, addr.IP) {
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -130,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		hostId:            vpnIP,
+		vpnIp:             vpnIp,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 	}
@@ -138,7 +141,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
-	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -153,7 +156,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	hsBytes, err := proto.Marshal(hs)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -161,17 +164,17 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 	}
 
-	header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
-	msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
+	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
+	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -179,8 +182,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 	}
 
-	hostinfo.HandshakePacket[0] = make([]byte, len(packet[HeaderLen:]))
-	copy(hostinfo.HandshakePacket[0], packet[HeaderLen:])
+	hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
+	copy(hostinfo.HandshakePacket[0], packet[header.Len:])
 
 	// Regardless of whether you are the sender or receiver, you should arrive here
 	// and complete standing up the connection.
@@ -195,12 +198,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 	// Only overwrite existing record if we should win the handshake race
-	overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
+	overwrite := vpnIp > f.myVpnIp
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
 	if err != nil {
 		switch err {
@@ -214,27 +217,27 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
-				f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 			existing.Unlock()
 			hostinfo.Lock()
 
 			msg = existing.HandshakePacket[2]
-			f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
+			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 			err := f.outside.WriteTo(msg, addr)
 			if err != nil {
-				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithError(err).Error("Failed to send handshake message")
 			} else {
-				f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 			}
 			return
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -245,22 +248,22 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 				Info("Handshake too old")
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 		case ErrExistingHandshake:
 			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
-			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -271,7 +274,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -283,10 +286,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	}
 
 	// Do the send
-	f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
+	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 	err = f.outside.WriteTo(msg, addr)
 	if err != nil {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -294,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithError(err).Error("Failed to send handshake")
 	} else {
-		f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -309,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet []byte, h *Header) bool {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, hostinfo *HostInfo, packet []byte, h *header.H) bool {
 	if hostinfo == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -318,14 +321,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
-	if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 		return false
 	}
 
 	ci := hostinfo.ConnectionState
 	if ci.ready {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Info("Handshake is already complete")
 
@@ -333,16 +336,16 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
 			// Send a test packet to ensure the other side has also switched to
 			// the preferred remote
-			f.SendMessageToVpnIp(test, testRequest, hostinfo.hostId, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}
 
 		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
 		return false
 	}
 
-	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
+	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -351,7 +354,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -363,7 +366,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	hs := &NebulaHandshake{}
 	err = proto.Unmarshal(msg, hs)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -372,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Invalid certificate from host")
 
@@ -380,14 +383,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		return true
 	}
 
-	vpnIP := ip2int(remoteCert.Details.Ips[0].IP)
+	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 
 	// Ensure the right host responded
-	if vpnIP != hostinfo.hostId {
-		f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
+	if vpnIp != hostinfo.vpnIp {
+		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
@@ -397,7 +400,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 		// Create a new hostinfo/handshake for the intended vpn ip
 		//TODO: this adds it to the timer wheel in a way that aggressively retries
-		newHostInfo := f.getOrHandshake(hostinfo.hostId)
+		newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
 		newHostInfo.Lock()
 
 		// Block the current used address
@@ -405,9 +408,9 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		newHostInfo.remotes.BlockRemote(addr)
 
 		// Get the correct remote list for the host we did handshake with
-		hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+		hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
-		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
+		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
 			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 			Info("Blocked addresses for handshakes")
 
@@ -418,7 +421,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		hostinfo.ConnectionState.queueLock.Unlock()
 
 		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-		hostinfo.hostId = vpnIP
+		hostinfo.vpnIp = vpnIp
 		f.sendCloseTunnel(hostinfo)
 		newHostInfo.Unlock()
 
@@ -429,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
-	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).

+ 36 - 33
handshake_manager.go

@@ -11,6 +11,9 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 const (
@@ -39,7 +42,7 @@ type HandshakeManager struct {
 	pendingHostMap         *HostMap
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
-	outside                *udpConn
+	outside                *udp.Conn
 	config                 HandshakeConfig
 	OutboundHandshakeTimer *SystemTimerWheel
 	messageMetrics         *MessageMetrics
@@ -47,18 +50,18 @@ type HandshakeManager struct {
 	metricTimedOut         metrics.Counter
 	l                      *logrus.Logger
 
-	// can be used to trigger outbound handshake for the given vpnIP
-	trigger chan uint32
+	// can be used to trigger outbound handshake for the given vpnIp
+	trigger chan iputil.VpnIp
 }
 
-func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
+func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
 		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
 		config:                 config,
-		trigger:                make(chan uint32, config.triggerBuffer),
+		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
@@ -67,7 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 
@@ -76,7 +79,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 		case <-ctx.Done():
 			return
 		case vpnIP := <-c.trigger:
-			c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
+			c.l.WithField("vpnIp", vpnIP).Debug("HandshakeManager: triggered")
 			c.handleOutbound(vpnIP, f, true)
 		case now := <-clockSource.C:
 			c.NextOutboundHandshakeTimerTick(now, f)
@@ -84,20 +87,20 @@ func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
 	c.OutboundHandshakeTimer.advance(now)
 	for {
 		ep := c.OutboundHandshakeTimer.Purge()
 		if ep == nil {
 			break
 		}
-		vpnIP := ep.(uint32)
-		c.handleOutbound(vpnIP, f, false)
+		vpnIp := ep.(iputil.VpnIp)
+		c.handleOutbound(vpnIp, f, false)
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseTriggered bool) {
-	hostinfo, err := c.pendingHostMap.QueryVpnIP(vpnIP)
+func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
+	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return
 	}
@@ -115,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	if !hostinfo.HandshakeReady {
 		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
 		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
-		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 		return
 	}
 
@@ -143,21 +146,21 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	// Get a remotes object if we don't already have one.
 	// This is mainly to protect us as this should never be the case
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
+		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
 	}
 
 	//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
 	if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
-		// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
+		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIP, f)
+		c.lightHouse.QueryServer(vpnIp, f)
 	}
 
 	// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
-	var sentTo []*udpAddr
-	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
-		c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+	var sentTo []*udp.Addr
+	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
 			hostinfo.logger(c.l).WithField("udpAddr", addr).
@@ -184,16 +187,16 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
 		//TODO: feel like we dupe handshake real fast in a tight loop, why?
-		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 }
 
-func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
-	hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
+func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
 	// We lock here and use an array to insert items to prevent locking the
 	// main receive thread for very long by waiting to add items to the pending map
 	//TODO: what lock?
-	c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
+	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
 	c.metricInitiated.Inc(1)
 
 	return hostinfo
@@ -208,12 +211,12 @@ var (
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
 // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be:
-
+//
 // ErrAlreadySeen if we already have an entry in the hostmap that has seen the
 // exact same handshake packet
 //
 // ErrExistingHostInfo if we already have an entry in the hostmap for this
-// VpnIP and the new handshake was older than the one we currently have
+// VpnIp and the new handshake was older than the one we currently have
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
@@ -224,7 +227,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	defer c.mainHostMap.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
+	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
 		// Is it just a delayed handshake packet?
 		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
@@ -252,16 +255,16 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	}
 
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 	}
 
 	// Check if we are also handshaking with this vpn ip
-	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
+	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
 	if found && pendingHostInfo != nil {
 		if !overwrite {
 			// We won, let our pending handshake win
@@ -278,7 +281,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 
 	if existingHostInfo != nil {
 		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
+		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 	}
@@ -296,10 +299,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
+	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
 		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
+		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
 	}
@@ -309,7 +312,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
 			Info("New host shadows existing host remoteIndex")
 	}
 

+ 16 - 12
handshake_manager_test.go

@@ -5,25 +5,29 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
-func Test_NewHandshakeManagerVpnIP(t *testing.T) {
-	l := NewTestLogger()
+func Test_NewHandshakeManagerVpnIp(t *testing.T) {
+	l := util.NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := ip2int(net.ParseIP("172.1.1.2"))
+	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udp.Conn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
-	i := blah.AddVpnIP(ip)
+	i := blah.AddVpnIp(ip)
 	i.remotes = NewRemoteList()
 	i.HandshakeReady = true
 
@@ -50,24 +54,24 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 }
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := ip2int(net.ParseIP("172.1.1.2"))
+	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-	lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
+	lh := &LightHouse{addrMap: make(map[iputil.VpnIp]*RemoteList), l: l}
 
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
+	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.Conn{}, defaultHandshakeConfig)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
-	hi := blah.AddVpnIP(ip)
+	hi := blah.AddVpnIp(ip)
 	hi.HandshakeReady = true
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
@@ -80,7 +84,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	// Make sure the trigger doesn't double schedule the timer entry
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
-	uaddr := NewUDPAddrFromString("10.1.1.1:4242")
+	uaddr := udp.NewAddrFromString("10.1.1.1:4242")
 	hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
 
 	// We now have remotes but only the first trigger should have pushed things forward
@@ -103,6 +107,6 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
 type mockEncWriter struct {
 }
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
 	return
 }

+ 62 - 66
header.go → header/header.go

@@ -1,4 +1,4 @@
-package nebula
+package header
 
 import (
 	"encoding/binary"
@@ -19,82 +19,78 @@ import (
 // |-----------------------------------------------------------------------|
 // |                               payload...                              |
 
+type m map[string]interface{}
+
 const (
-	Version   uint8 = 1
-	HeaderLen       = 16
+	Version uint8 = 1
+	Len           = 16
 )
 
-type NebulaMessageType uint8
-type NebulaMessageSubType uint8
+type MessageType uint8
+type MessageSubType uint8
 
 const (
-	handshake   NebulaMessageType = 0
-	message     NebulaMessageType = 1
-	recvError   NebulaMessageType = 2
-	lightHouse  NebulaMessageType = 3
-	test        NebulaMessageType = 4
-	closeTunnel NebulaMessageType = 5
-
-	//TODO These are deprecated as of 06/12/2018 - NB
-	testRemote      NebulaMessageType = 6
-	testRemoteReply NebulaMessageType = 7
+	Handshake   MessageType = 0
+	Message     MessageType = 1
+	RecvError   MessageType = 2
+	LightHouse  MessageType = 3
+	Test        MessageType = 4
+	CloseTunnel MessageType = 5
 )
 
-var typeMap = map[NebulaMessageType]string{
-	handshake:   "handshake",
-	message:     "message",
-	recvError:   "recvError",
-	lightHouse:  "lightHouse",
-	test:        "test",
-	closeTunnel: "closeTunnel",
-
-	//TODO These are deprecated as of 06/12/2018 - NB
-	testRemote:      "testRemote",
-	testRemoteReply: "testRemoteReply",
+var typeMap = map[MessageType]string{
+	Handshake:   "handshake",
+	Message:     "message",
+	RecvError:   "recvError",
+	LightHouse:  "lightHouse",
+	Test:        "test",
+	CloseTunnel: "closeTunnel",
 }
 
 const (
-	testRequest NebulaMessageSubType = 0
-	testReply   NebulaMessageSubType = 1
+	TestRequest MessageSubType = 0
+	TestReply   MessageSubType = 1
+)
+
+const (
+	HandshakeIXPSK0 MessageSubType = 0
+	HandshakeXXPSK0 MessageSubType = 1
 )
 
-var eHeaderTooShort = errors.New("header is too short")
+var ErrHeaderTooShort = errors.New("header is too short")
 
-var subTypeTestMap = map[NebulaMessageSubType]string{
-	testRequest: "testRequest",
-	testReply:   "testReply",
+var subTypeTestMap = map[MessageSubType]string{
+	TestRequest: "testRequest",
+	TestReply:   "testReply",
 }
 
-var subTypeNoneMap = map[NebulaMessageSubType]string{0: "none"}
+var subTypeNoneMap = map[MessageSubType]string{0: "none"}
 
-var subTypeMap = map[NebulaMessageType]*map[NebulaMessageSubType]string{
-	message:     &subTypeNoneMap,
-	recvError:   &subTypeNoneMap,
-	lightHouse:  &subTypeNoneMap,
-	test:        &subTypeTestMap,
-	closeTunnel: &subTypeNoneMap,
-	handshake: {
-		handshakeIXPSK0: "ix_psk0",
+var subTypeMap = map[MessageType]*map[MessageSubType]string{
+	Message:     &subTypeNoneMap,
+	RecvError:   &subTypeNoneMap,
+	LightHouse:  &subTypeNoneMap,
+	Test:        &subTypeTestMap,
+	CloseTunnel: &subTypeNoneMap,
+	Handshake: {
+		HandshakeIXPSK0: "ix_psk0",
 	},
-	//TODO: these are deprecated
-	testRemote:      &subTypeNoneMap,
-	testRemoteReply: &subTypeNoneMap,
 }
 
-type Header struct {
+type H struct {
 	Version        uint8
-	Type           NebulaMessageType
-	Subtype        NebulaMessageSubType
+	Type           MessageType
+	Subtype        MessageSubType
 	Reserved       uint16
 	RemoteIndex    uint32
 	MessageCounter uint64
 }
 
-// HeaderEncode uses the provided byte array to encode the provided header values into.
+// Encode uses the provided byte array to encode the provided header values into.
 // Byte array must be capped higher than HeaderLen or this will panic
-func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []byte {
-	b = b[:HeaderLen]
-	b[0] = byte(v<<4 | (t & 0x0f))
+func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte {
+	b = b[:Len]
+	b[0] = v<<4 | byte(t&0x0f)
 	b[1] = byte(st)
 	binary.BigEndian.PutUint16(b[2:4], 0)
 	binary.BigEndian.PutUint32(b[4:8], ri)
@@ -103,7 +99,7 @@ func HeaderEncode(b []byte, v uint8, t uint8, st uint8, ri uint32, c uint64) []b
 }
 
 // String creates a readable string representation of a header
-func (h *Header) String() string {
+func (h *H) String() string {
 	if h == nil {
 		return "<nil>"
 	}
@@ -112,7 +108,7 @@ func (h *Header) String() string {
 }
 
 // MarshalJSON creates a json string representation of a header
-func (h *Header) MarshalJSON() ([]byte, error) {
+func (h *H) MarshalJSON() ([]byte, error) {
 	return json.Marshal(m{
 		"version":        h.Version,
 		"type":           h.TypeName(),
@@ -124,24 +120,24 @@ func (h *Header) MarshalJSON() ([]byte, error) {
 }
 
 // Encode turns header into bytes
-func (h *Header) Encode(b []byte) ([]byte, error) {
+func (h *H) Encode(b []byte) ([]byte, error) {
 	if h == nil {
 		return nil, errors.New("nil header")
 	}
 
-	return HeaderEncode(b, h.Version, uint8(h.Type), uint8(h.Subtype), h.RemoteIndex, h.MessageCounter), nil
+	return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil
 }
 
 // Parse is a helper function to parses given bytes into new Header struct
-func (h *Header) Parse(b []byte) error {
-	if len(b) < HeaderLen {
-		return eHeaderTooShort
+func (h *H) Parse(b []byte) error {
+	if len(b) < Len {
+		return ErrHeaderTooShort
 	}
 	// get upper 4 bytes
 	h.Version = uint8((b[0] >> 4) & 0x0f)
 	// get lower 4 bytes
-	h.Type = NebulaMessageType(b[0] & 0x0f)
-	h.Subtype = NebulaMessageSubType(b[1])
+	h.Type = MessageType(b[0] & 0x0f)
+	h.Subtype = MessageSubType(b[1])
 	h.Reserved = binary.BigEndian.Uint16(b[2:4])
 	h.RemoteIndex = binary.BigEndian.Uint32(b[4:8])
 	h.MessageCounter = binary.BigEndian.Uint64(b[8:16])
@@ -149,12 +145,12 @@ func (h *Header) Parse(b []byte) error {
 }
 
 // TypeName will transform the headers message type into a human string
-func (h *Header) TypeName() string {
+func (h *H) TypeName() string {
 	return TypeName(h.Type)
 }
 
 // TypeName will transform a nebula message type into a human string
-func TypeName(t NebulaMessageType) string {
+func TypeName(t MessageType) string {
 	if n, ok := typeMap[t]; ok {
 		return n
 	}
@@ -163,12 +159,12 @@ func TypeName(t NebulaMessageType) string {
 }
 
 // SubTypeName will transform the headers message sub type into a human string
-func (h *Header) SubTypeName() string {
+func (h *H) SubTypeName() string {
 	return SubTypeName(h.Type, h.Subtype)
 }
 
 // SubTypeName will transform a nebula message sub type into a human string
-func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
+func SubTypeName(t MessageType, s MessageSubType) string {
 	if n, ok := subTypeMap[t]; ok {
 		if x, ok := (*n)[s]; ok {
 			return x
@@ -179,8 +175,8 @@ func SubTypeName(t NebulaMessageType, s NebulaMessageSubType) string {
 }
 
 // NewHeader turns bytes into a header
-func NewHeader(b []byte) (*Header, error) {
-	h := new(Header)
+func NewHeader(b []byte) (*H, error) {
+	h := new(H)
 	if err := h.Parse(b); err != nil {
 		return nil, err
 	}

+ 115 - 0
header/header_test.go

@@ -0,0 +1,115 @@
+package header
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type headerTest struct {
+	expectedBytes []byte
+	*H
+}
+
+// 0001 0010 00010010
+var headerBigEndianTests = []headerTest{{
+	expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
+	// 1010 0000
+	H: &H{
+		// 1111 1+2+4+8 = 15
+		Version:        5,
+		Type:           4,
+		Subtype:        0,
+		Reserved:       0,
+		RemoteIndex:    10,
+		MessageCounter: 9,
+	},
+},
+}
+
+func TestEncode(t *testing.T) {
+	for _, tt := range headerBigEndianTests {
+		b, err := tt.Encode(make([]byte, Len))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		assert.Equal(t, tt.expectedBytes, b)
+	}
+}
+
+func TestParse(t *testing.T) {
+	for _, tt := range headerBigEndianTests {
+		b := tt.expectedBytes
+		parsedHeader := &H{}
+		parsedHeader.Parse(b)
+
+		if !reflect.DeepEqual(tt.H, parsedHeader) {
+			t.Fatalf("got %#v; want %#v", parsedHeader, tt.H)
+		}
+	}
+}
+
+func TestTypeName(t *testing.T) {
+	assert.Equal(t, "test", TypeName(Test))
+	assert.Equal(t, "test", (&H{Type: Test}).TypeName())
+
+	assert.Equal(t, "unknown", TypeName(99))
+	assert.Equal(t, "unknown", (&H{Type: 99}).TypeName())
+}
+
+func TestSubTypeName(t *testing.T) {
+	assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest))
+	assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName())
+
+	assert.Equal(t, "unknown", SubTypeName(99, TestRequest))
+	assert.Equal(t, "unknown", (&H{Type: 99, Subtype: TestRequest}).SubTypeName())
+
+	assert.Equal(t, "unknown", SubTypeName(Test, 99))
+	assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName())
+
+	assert.Equal(t, "none", SubTypeName(Message, 0))
+	assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName())
+}
+
+func TestTypeMap(t *testing.T) {
+	// Force people to document this stuff
+	assert.Equal(t, map[MessageType]string{
+		Handshake:   "handshake",
+		Message:     "message",
+		RecvError:   "recvError",
+		LightHouse:  "lightHouse",
+		Test:        "test",
+		CloseTunnel: "closeTunnel",
+	}, typeMap)
+
+	assert.Equal(t, map[MessageType]*map[MessageSubType]string{
+		Message:     &subTypeNoneMap,
+		RecvError:   &subTypeNoneMap,
+		LightHouse:  &subTypeNoneMap,
+		Test:        &subTypeTestMap,
+		CloseTunnel: &subTypeNoneMap,
+		Handshake: {
+			HandshakeIXPSK0: "ix_psk0",
+		},
+	}, subTypeMap)
+}
+
+func TestHeader_String(t *testing.T) {
+	assert.Equal(
+		t,
+		"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
+		(&H{100, Test, TestRequest, 99, 98, 97}).String(),
+	)
+}
+
+func TestHeader_MarshalJSON(t *testing.T) {
+	b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
+		string(b),
+	)
+}

+ 0 - 119
header_test.go

@@ -1,119 +0,0 @@
-package nebula
-
-import (
-	"reflect"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-type headerTest struct {
-	expectedBytes []byte
-	*Header
-}
-
-// 0001 0010 00010010
-var headerBigEndianTests = []headerTest{{
-	expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9},
-	// 1010 0000
-	Header: &Header{
-		// 1111 1+2+4+8 = 15
-		Version:        5,
-		Type:           4,
-		Subtype:        0,
-		Reserved:       0,
-		RemoteIndex:    10,
-		MessageCounter: 9,
-	},
-},
-}
-
-func TestEncode(t *testing.T) {
-	for _, tt := range headerBigEndianTests {
-		b, err := tt.Encode(make([]byte, HeaderLen))
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		assert.Equal(t, tt.expectedBytes, b)
-	}
-}
-
-func TestParse(t *testing.T) {
-	for _, tt := range headerBigEndianTests {
-		b := tt.expectedBytes
-		parsedHeader := &Header{}
-		parsedHeader.Parse(b)
-
-		if !reflect.DeepEqual(tt.Header, parsedHeader) {
-			t.Fatalf("got %#v; want %#v", parsedHeader, tt.Header)
-		}
-	}
-}
-
-func TestTypeName(t *testing.T) {
-	assert.Equal(t, "test", TypeName(test))
-	assert.Equal(t, "test", (&Header{Type: test}).TypeName())
-
-	assert.Equal(t, "unknown", TypeName(99))
-	assert.Equal(t, "unknown", (&Header{Type: 99}).TypeName())
-}
-
-func TestSubTypeName(t *testing.T) {
-	assert.Equal(t, "testRequest", SubTypeName(test, testRequest))
-	assert.Equal(t, "testRequest", (&Header{Type: test, Subtype: testRequest}).SubTypeName())
-
-	assert.Equal(t, "unknown", SubTypeName(99, testRequest))
-	assert.Equal(t, "unknown", (&Header{Type: 99, Subtype: testRequest}).SubTypeName())
-
-	assert.Equal(t, "unknown", SubTypeName(test, 99))
-	assert.Equal(t, "unknown", (&Header{Type: test, Subtype: 99}).SubTypeName())
-
-	assert.Equal(t, "none", SubTypeName(message, 0))
-	assert.Equal(t, "none", (&Header{Type: message, Subtype: 0}).SubTypeName())
-}
-
-func TestTypeMap(t *testing.T) {
-	// Force people to document this stuff
-	assert.Equal(t, map[NebulaMessageType]string{
-		handshake:       "handshake",
-		message:         "message",
-		recvError:       "recvError",
-		lightHouse:      "lightHouse",
-		test:            "test",
-		closeTunnel:     "closeTunnel",
-		testRemote:      "testRemote",
-		testRemoteReply: "testRemoteReply",
-	}, typeMap)
-
-	assert.Equal(t, map[NebulaMessageType]*map[NebulaMessageSubType]string{
-		message:     &subTypeNoneMap,
-		recvError:   &subTypeNoneMap,
-		lightHouse:  &subTypeNoneMap,
-		test:        &subTypeTestMap,
-		closeTunnel: &subTypeNoneMap,
-		handshake: {
-			handshakeIXPSK0: "ix_psk0",
-		},
-		testRemote:      &subTypeNoneMap,
-		testRemoteReply: &subTypeNoneMap,
-	}, subTypeMap)
-}
-
-func TestHeader_String(t *testing.T) {
-	assert.Equal(
-		t,
-		"ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97",
-		(&Header{100, test, testRequest, 99, 98, 97}).String(),
-	)
-}
-
-func TestHeader_MarshalJSON(t *testing.T) {
-	b, err := (&Header{100, test, testRequest, 99, 98, 97}).MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",
-		string(b),
-	)
-}

+ 64 - 93
hostmap.go

@@ -12,6 +12,10 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 //const ProbeLen = 100
@@ -28,10 +32,10 @@ type HostMap struct {
 	name            string
 	Indexes         map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
-	Hosts           map[uint32]*HostInfo
+	Hosts           map[iputil.VpnIp]*HostInfo
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
-	unsafeRoutes    *CIDRTree
+	unsafeRoutes    *cidr.Tree4
 	metricsEnabled  bool
 	l               *logrus.Logger
 }
@@ -39,7 +43,7 @@ type HostMap struct {
 type HostInfo struct {
 	sync.RWMutex
 
-	remote            *udpAddr
+	remote            *udp.Addr
 	remotes           *RemoteList
 	promoteCounter    uint32
 	ConnectionState   *ConnectionState
@@ -51,9 +55,9 @@ type HostInfo struct {
 	packetStore       []*cachedPacket  //todo: this is other handshake manager entry
 	remoteIndexId     uint32
 	localIndexId      uint32
-	hostId            uint32
+	vpnIp             iputil.VpnIp
 	recvError         int
-	remoteCidr        *CIDRTree
+	remoteCidr        *cidr.Tree4
 
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
@@ -66,17 +70,17 @@ type HostInfo struct {
 	lastHandshakeTime uint64
 
 	lastRoam       time.Time
-	lastRoamRemote *udpAddr
+	lastRoamRemote *udp.Addr
 }
 
 type cachedPacket struct {
-	messageType    NebulaMessageType
-	messageSubType NebulaMessageSubType
+	messageType    header.MessageType
+	messageSubType header.MessageSubType
 	callback       packetCallback
 	packet         []byte
 }
 
-type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
+type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte)
 
 type cachedPacketMetrics struct {
 	sent    metrics.Counter
@@ -84,7 +88,7 @@ type cachedPacketMetrics struct {
 }
 
 func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
-	h := map[uint32]*HostInfo{}
+	h := map[iputil.VpnIp]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
 	m := HostMap{
@@ -94,7 +98,7 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 		Hosts:           h,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
-		unsafeRoutes:    NewCIDRTree(),
+		unsafeRoutes:    cidr.NewTree4(),
 		l:               l,
 	}
 	return &m
@@ -113,9 +117,9 @@ func (hm *HostMap) EmitStats(name string) {
 	metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
 }
 
-func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
+func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
 	hm.RLock()
-	if i, ok := hm.Hosts[vpnIP]; ok {
+	if i, ok := hm.Hosts[vpnIp]; ok {
 		index := i.localIndexId
 		hm.RUnlock()
 		return index, nil
@@ -124,43 +128,43 @@ func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
 	return 0, errors.New("vpn IP not found")
 }
 
-func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
+func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
 	hm.Lock()
 	hm.Hosts[ip] = hostinfo
 	hm.Unlock()
 }
 
-func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
+func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 	h := &HostInfo{}
 	hm.RLock()
-	if _, ok := hm.Hosts[vpnIP]; !ok {
+	if _, ok := hm.Hosts[vpnIp]; !ok {
 		hm.RUnlock()
 		h = &HostInfo{
 			promoteCounter:  0,
-			hostId:          vpnIP,
+			vpnIp:           vpnIp,
 			HandshakePacket: make(map[uint8][]byte, 0),
 		}
 		hm.Lock()
-		hm.Hosts[vpnIP] = h
+		hm.Hosts[vpnIp] = h
 		hm.Unlock()
 		return h
 	} else {
-		h = hm.Hosts[vpnIP]
+		h = hm.Hosts[vpnIp]
 		hm.RUnlock()
 		return h
 	}
 }
 
-func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
+func (hm *HostMap) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	hm.Lock()
-	delete(hm.Hosts, vpnIP)
+	delete(hm.Hosts, vpnIp)
 	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[uint32]*HostInfo{}
+		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
 	}
 	hm.Unlock()
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts)}).
 			Debug("Hostmap vpnIp deleted")
 	}
 }
@@ -174,22 +178,22 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
 
 	if hm.l.Level > logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
+			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
 			Debug("Hostmap remoteIndex added")
 	}
 }
 
-func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
+func (hm *HostMap) AddVpnIpHostInfo(vpnIp iputil.VpnIp, h *HostInfo) {
 	hm.Lock()
-	h.hostId = vpnIP
-	hm.Hosts[vpnIP] = h
+	h.vpnIp = vpnIp
+	hm.Hosts[vpnIp] = h
 	hm.Indexes[h.localIndexId] = h
 	hm.RemoteIndexes[h.remoteIndexId] = h
 	hm.Unlock()
 
 	if hm.l.Level > logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": vpnIp, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "vpnIp": h.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 	}
 }
@@ -204,9 +208,9 @@ func (hm *HostMap) DeleteIndex(index uint32) {
 
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// instance. Clean it up as well if we do.
-		hostinfo2, ok := hm.Hosts[hostinfo.hostId]
+		hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
 		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.hostId)
+			delete(hm.Hosts, hostinfo.vpnIp)
 		}
 	}
 	hm.Unlock()
@@ -228,9 +232,9 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 		// Check if we have an entry under hostId that matches the same hostinfo
 		// instance. Clean it up as well if we do (they might not match in pendingHostmap)
 		var hostinfo2 *HostInfo
-		hostinfo2, ok = hm.Hosts[hostinfo.hostId]
+		hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
 		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.hostId)
+			delete(hm.Hosts, hostinfo.vpnIp)
 		}
 	}
 	hm.Unlock()
@@ -251,16 +255,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	// Check if this same hostId is in the hostmap with a different instance.
 	// This could happen if we have an entry in the pending hostmap with different
 	// index values than the one in the main hostmap.
-	hostinfo2, ok := hm.Hosts[hostinfo.hostId]
+	hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
 	if ok && hostinfo2 != hostinfo {
-		delete(hm.Hosts, hostinfo2.hostId)
+		delete(hm.Hosts, hostinfo2.vpnIp)
 		delete(hm.Indexes, hostinfo2.localIndexId)
 		delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
 	}
 
-	delete(hm.Hosts, hostinfo.hostId)
+	delete(hm.Hosts, hostinfo.vpnIp)
 	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[uint32]*HostInfo{}
+		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
 	}
 	delete(hm.Indexes, hostinfo.localIndexId)
 	if len(hm.Indexes) == 0 {
@@ -273,7 +277,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
-			"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
 }
@@ -301,17 +305,17 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
 	}
 }
 
-func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
-	return hm.queryVpnIP(vpnIp, nil)
+func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
+	return hm.queryVpnIp(vpnIp, nil)
 }
 
-// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
+// PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
 // `PromoteEvery` calls to this function for a given host.
-func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
-	return hm.queryVpnIP(vpnIp, ifce)
+func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
+	return hm.queryVpnIp(vpnIp, ifce)
 }
 
-func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
+func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -327,10 +331,10 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
 	return nil, errors.New("unable to find host")
 }
 
-func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
+func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
 	r := hm.unsafeRoutes.MostSpecificContains(ip)
 	if r != nil {
-		return r.(uint32)
+		return r.(iputil.VpnIp)
 	} else {
 		return 0
 	}
@@ -344,13 +348,13 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 
-	hm.Hosts[hostinfo.hostId] = hostinfo
+	hm.Hosts[hostinfo.vpnIp] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
+		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 	}
 }
@@ -370,7 +374,7 @@ func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
 }
 
 // Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
-func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
+func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
 	var metricsTxPunchy metrics.Counter
 	if hm.metricsEnabled {
 		metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
@@ -406,7 +410,7 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udpConn) {
 func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 	for _, r := range *routes {
 		hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
-		hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
+		hm.unsafeRoutes.AddCIDR(r.route, iputil.Ip2VpnIp(*r.via))
 	}
 }
 
@@ -431,24 +435,24 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 		}
 
-		i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
+		i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
 			if addr == nil || !preferred {
 				return
 			}
 
 			// Try to send a test packet to that host, this should
 			// cause it to detect a roaming event and switch remotes
-			ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			ifce.send(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		})
 	}
 
 	// Re query our lighthouses for new remotes occasionally
 	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
-		ifce.lightHouse.QueryServer(i.hostId, ifce)
+		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
 	}
 }
 
-func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+func (i *HostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
 	//TODO: return the error so we can log with more context
 	if len(i.packetStore) < 100 {
 		tempPacket := make([]byte, len(packet))
@@ -510,17 +514,17 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 }
 
-func (i *HostInfo) SetRemote(remote *udpAddr) {
+func (i *HostInfo) SetRemote(remote *udp.Addr) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	if !i.remote.Equals(remote) {
 		i.remote = remote.Copy()
-		i.remotes.LearnRemote(i.hostId, remote.Copy())
+		i.remotes.LearnRemote(i.vpnIp, remote.Copy())
 	}
 }
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udpAddr) bool {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	currentRemote := i.remote
 	if currentRemote == nil {
 		i.SetRemote(newRemote)
@@ -572,7 +576,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 	}
 
-	remoteCidr := NewCIDRTree()
+	remoteCidr := cidr.NewTree4()
 	for _, ip := range c.Details.Ips {
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
@@ -588,8 +592,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 	}
 
-	li := l.WithField("vpnIp", IntIp(i.hostId))
-
+	li := l.WithField("vpnIp", i.vpnIp)
 	if connState := i.ConnectionState; connState != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
 			li = li.WithField("certName", peerCert.Details.Name)
@@ -599,38 +602,6 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 	return li
 }
 
-//########################
-
-/*
-
-func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
-	s := "\n"
-	for _, h := range hm.Hosts {
-		for _, r := range h.Remotes {
-			s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
-		}
-	}
-	return s
-}
-
-func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
-	for _, r := range i.Remotes {
-		if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
-			r.ProbeReceived(counter)
-		}
-	}
-}
-
-func (i *HostInfo) Probes() []*Probe {
-	p := []*Probe{}
-	for _, d := range i.Remotes {
-		p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
-	}
-	return p
-}
-
-*/
-
 // Utility functions
 
 func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {

+ 28 - 23
inside.go

@@ -5,9 +5,13 @@ import (
 
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@@ -32,7 +36,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 	hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
 	if hostinfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
+			f.l.WithField("vpnIp", fwPacket.RemoteIP).
 				WithField("fwPacket", fwPacket).
 				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
 		}
@@ -45,7 +49,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		// the packet queue.
 		ci.queueLock.Lock()
 		if !ci.ready {
-			hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+			hostinfo.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 			ci.queueLock.Unlock()
 			return
 		}
@@ -54,7 +58,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 
 	} else if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).
@@ -65,20 +69,21 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 }
 
 // getOrHandshake returns nil if the vpnIp is not routable
-func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
-	if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
+	//TODO: we can find contains without converting back to bytes
+	if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
 		vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
 		if vpnIp == 0 {
 			return nil
 		}
 	}
-	hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
+	hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 
 	//if err != nil || hostinfo.ConnectionState == nil {
 	if err != nil {
-		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIp)
+		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
 		if err != nil {
-			hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
+			hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
 		}
 	}
 	ci := hostinfo.ConnectionState
@@ -126,8 +131,8 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 	return hostinfo
 }
 
-func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
-	fp := &FirewallPacket{}
+func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)
 	if err != nil {
 		f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
@@ -145,15 +150,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 		return
 	}
 
-	f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
 }
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
 	hostInfo := f.getOrHandshake(vpnIp)
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(vpnIp)).
+			f.l.WithField("vpnIp", vpnIp).
 				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
 		}
 		return
@@ -175,16 +180,16 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 	return
 }
 
-func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 }
 
-func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
+func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 
-func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 		//TODO: log warning
 		return
@@ -196,18 +201,18 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 	c := atomic.AddUint64(&ci.atomicMessageCounter, 1)
 
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
-	out = HeaderEncode(out, Version, uint8(t), uint8(st), hostinfo.remoteIndexId, c)
-	f.connectionManager.Out(hostinfo.hostId)
+	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
+	f.connectionManager.Out(hostinfo.vpnIp)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// all our IPs and enable a faster roaming.
-	if t != closeTunnel && hostinfo.lastRebindCount != f.rebindCount {
+	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.hostId, f)
+		f.lightHouse.QueryServer(hostinfo.vpnIp, f)
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 	}
 
@@ -230,7 +235,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 	return
 }
 
-func isMulticast(ip uint32) bool {
+func isMulticast(ip iputil.VpnIp) bool {
 	// Class D multicast
 	if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
 		return true

+ 26 - 21
interface.go

@@ -12,6 +12,10 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 const mtu = 9001
@@ -27,7 +31,7 @@ type Inside interface {
 
 type InterfaceConfig struct {
 	HostMap                 *HostMap
-	Outside                 *udpConn
+	Outside                 *udp.Conn
 	Inside                  Inside
 	certState               *CertState
 	Cipher                  string
@@ -39,7 +43,6 @@ type InterfaceConfig struct {
 	pendingDeletionInterval int
 	DropLocalBroadcast      bool
 	DropMulticast           bool
-	UDPBatchSize            int
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
@@ -52,7 +55,7 @@ type InterfaceConfig struct {
 
 type Interface struct {
 	hostMap            *HostMap
-	outside            *udpConn
+	outside            *udp.Conn
 	inside             Inside
 	certState          *CertState
 	cipher             string
@@ -62,11 +65,10 @@ type Interface struct {
 	serveDns           bool
 	createTime         time.Time
 	lightHouse         *LightHouse
-	localBroadcast     uint32
-	myVpnIp            uint32
+	localBroadcast     iputil.VpnIp
+	myVpnIp            iputil.VpnIp
 	dropLocalBroadcast bool
 	dropMulticast      bool
-	udpBatchSize       int
 	routines           int
 	caPool             *cert.NebulaCAPool
 	disconnectInvalid  bool
@@ -77,7 +79,7 @@ type Interface struct {
 
 	conntrackCacheTimeout time.Duration
 
-	writers []*udpConn
+	writers []*udp.Conn
 	readers []io.ReadWriteCloser
 
 	metricHandshakes    metrics.Histogram
@@ -101,6 +103,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 	}
 
+	myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
 	ifce := &Interface{
 		hostMap:            c.HostMap,
 		outside:            c.Outside,
@@ -112,17 +115,16 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
-		localBroadcast:     ip2int(c.certState.certificate.Details.Ips[0].IP) | ^ip2int(c.certState.certificate.Details.Ips[0].Mask),
+		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
-		udpBatchSize:       c.UDPBatchSize,
 		routines:           c.routines,
 		version:            c.version,
-		writers:            make([]*udpConn, c.routines),
+		writers:            make([]*udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		caPool:             c.caPool,
 		disconnectInvalid:  c.disconnectInvalid,
-		myVpnIp:            ip2int(c.certState.certificate.Details.Ips[0].IP),
+		myVpnIp:            myVpnIp,
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
@@ -190,14 +192,17 @@ func (f *Interface) run() {
 func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 
-	var li *udpConn
+	var li *udp.Conn
 	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 		li = f.writers[i]
 	} else {
 		li = f.outside
 	}
-	li.ListenOut(f, i)
+
+	lhh := f.lightHouse.NewRequestHandler()
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -205,10 +210,10 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
-	fwPacket := &FirewallPacket{}
+	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
 	for {
 		n, err := reader.Read(packet)
@@ -222,16 +227,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	}
 }
 
-func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
+func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadCA)
 	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	for _, udpConn := range f.writers {
-		c.RegisterReloadCallback(udpConn.reloadConfig)
+		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
 }
 
-func (f *Interface) reloadCA(c *Config) {
+func (f *Interface) reloadCA(c *config.C) {
 	// reload and check regardless
 	// todo: need mutex?
 	newCAs, err := loadCAFromConfig(f.l, c)
@@ -244,7 +249,7 @@ func (f *Interface) reloadCA(c *Config) {
 	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
 }
 
-func (f *Interface) reloadCertKey(c *Config) {
+func (f *Interface) reloadCertKey(c *config.C) {
 	// reload and check in all cases
 	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
@@ -264,7 +269,7 @@ func (f *Interface) reloadCertKey(c *Config) {
 	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
 }
 
-func (f *Interface) reloadFirewall(c *Config) {
+func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
 		f.l.Debug("No firewall config change detected")
@@ -307,7 +312,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	ticker := time.NewTicker(i)
 	defer ticker.Stop()
 
-	udpStats := NewUDPStatsEmitter(f.writers)
+	udpStats := udp.NewUDPStatsEmitter(f.writers)
 
 	for {
 		select {

+ 66 - 0
iputil/util.go

@@ -0,0 +1,66 @@
+package iputil
+
+import (
+	"encoding/binary"
+	"fmt"
+	"net"
+)
+
+type VpnIp uint32
+
+const maxIPv4StringLen = len("255.255.255.255")
+
+func (ip VpnIp) String() string {
+	b := make([]byte, maxIPv4StringLen)
+
+	n := ubtoa(b, 0, byte(ip>>24))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip>>16&255))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip>>8&255))
+	b[n] = '.'
+	n++
+
+	n += ubtoa(b, n, byte(ip&255))
+	return string(b[:n])
+}
+
+func (ip VpnIp) MarshalJSON() ([]byte, error) {
+	return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
+}
+
+func (ip VpnIp) ToIP() net.IP {
+	nip := make(net.IP, 4)
+	binary.BigEndian.PutUint32(nip, uint32(ip))
+	return nip
+}
+
+func Ip2VpnIp(ip []byte) VpnIp {
+	if len(ip) == 16 {
+		return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
+	}
+	return VpnIp(binary.BigEndian.Uint32(ip))
+}
+
+// ubtoa encodes the string form of the integer v to dst[start:] and
+// returns the number of bytes written to dst. The caller must ensure
+// that dst has sufficient length.
+func ubtoa(dst []byte, start int, v byte) int {
+	if v < 10 {
+		dst[start] = v + '0'
+		return 1
+	} else if v < 100 {
+		dst[start+1] = v%10 + '0'
+		dst[start] = v/10 + '0'
+		return 2
+	}
+
+	dst[start+2] = v%10 + '0'
+	dst[start+1] = (v/10)%10 + '0'
+	dst[start] = v/100 + '0'
+	return 3
+}

+ 17 - 0
iputil/util_test.go

@@ -0,0 +1,17 @@
+package iputil
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestVpnIp_String(t *testing.T) {
+	assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
+	assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
+	assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
+	assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
+	assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
+	assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
+}

+ 82 - 81
lighthouse.go

@@ -12,6 +12,9 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 //TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
@@ -23,13 +26,13 @@ type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	amLighthouse bool
-	myVpnIp      uint32
-	myVpnZeros   uint32
-	punchConn    *udpConn
+	myVpnIp      iputil.VpnIp
+	myVpnZeros   iputil.VpnIp
+	punchConn    *udp.Conn
 
 	// Local cache of answers from light houses
 	// map of vpn Ip to answers
-	addrMap map[uint32]*RemoteList
+	addrMap map[iputil.VpnIp]*RemoteList
 
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -42,12 +45,12 @@ type LightHouse struct {
 	localAllowList *LocalAllowList
 
 	// used to trigger the HandshakeManager when we receive HostQueryReply
-	handshakeTrigger chan<- uint32
+	handshakeTrigger chan<- iputil.VpnIp
 
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
-	staticList  map[uint32]struct{}
-	lighthouses map[uint32]struct{}
+	staticList  map[iputil.VpnIp]struct{}
+	lighthouses map[iputil.VpnIp]struct{}
 	interval    int
 	nebulaPort  uint32 // 32 bits because protobuf does not have a uint16
 	punchBack   bool
@@ -58,20 +61,16 @@ type LightHouse struct {
 	l                 *logrus.Logger
 }
 
-type EncWriter interface {
-	SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
-}
-
-func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
+func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, ips []iputil.VpnIp, interval int, nebulaPort uint32, pc *udp.Conn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
 	ones, _ := myVpnIpNet.Mask.Size()
 	h := LightHouse{
 		amLighthouse: amLighthouse,
-		myVpnIp:      ip2int(myVpnIpNet.IP),
-		myVpnZeros:   uint32(32 - ones),
-		addrMap:      make(map[uint32]*RemoteList),
+		myVpnIp:      iputil.Ip2VpnIp(myVpnIpNet.IP),
+		myVpnZeros:   iputil.VpnIp(32 - ones),
+		addrMap:      make(map[iputil.VpnIp]*RemoteList),
 		nebulaPort:   nebulaPort,
-		lighthouses:  make(map[uint32]struct{}),
-		staticList:   make(map[uint32]struct{}),
+		lighthouses:  make(map[iputil.VpnIp]struct{}),
+		staticList:   make(map[iputil.VpnIp]struct{}),
 		interval:     interval,
 		punchConn:    pc,
 		punchBack:    punchBack,
@@ -111,13 +110,13 @@ func (lh *LightHouse) SetLocalAllowList(allowList *LocalAllowList) {
 func (lh *LightHouse) ValidateLHStaticEntries() error {
 	for lhIP, _ := range lh.lighthouses {
 		if _, ok := lh.staticList[lhIP]; !ok {
-			return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", IntIp(lhIP))
+			return fmt.Errorf("Lighthouse %s does not have a static_host_map entry", lhIP)
 		}
 	}
 	return nil
 }
 
-func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 	}
@@ -131,7 +130,7 @@ func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
 }
 
 // This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
 	if lh.amLighthouse {
 		return
 	}
@@ -143,7 +142,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 	// Send a query to the lighthouses and hope for the best next time
 	query, err := proto.Marshal(NewLhQueryByInt(ip))
 	if err != nil {
-		lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
+		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
 		return
 	}
 
@@ -151,11 +150,11 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
 	for n := range lh.lighthouses {
-		f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
+		f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
 	}
 }
 
-func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
+func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
@@ -172,7 +171,7 @@ func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
 	// Do we have an entry in the main cache?
 	if v, ok := lh.addrMap[vpnIp]; ok {
@@ -195,18 +194,18 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, err
 	return false, 0, nil
 }
 
-func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 	// First we check the static mapping
 	// and do nothing if it is there
-	if _, ok := lh.staticList[vpnIP]; ok {
+	if _, ok := lh.staticList[vpnIp]; ok {
 		return
 	}
 	lh.Lock()
 	//l.Debugln(lh.addrMap)
-	delete(lh.addrMap, vpnIP)
+	delete(lh.addrMap, vpnIp)
 
 	if lh.l.Level >= logrus.DebugLevel {
-		lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
+		lh.l.Debugf("deleting %s from lighthouse.", vpnIp)
 	}
 
 	lh.Unlock()
@@ -215,7 +214,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
 // AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
-func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
+func (lh *LightHouse) AddStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr) {
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
@@ -242,23 +241,23 @@ func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
 }
 
 // unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
-	am, ok := lh.addrMap[vpnIP]
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+	am, ok := lh.addrMap[vpnIp]
 	if !ok {
 		am = NewRemoteList()
-		lh.addrMap[vpnIP] = am
+		lh.addrMap[vpnIp] = am
 	}
 	return am
 }
 
 // unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
-	allow := lh.remoteAllowList.AllowIpV4(vpnIp, to.Ip)
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
+	allow := lh.remoteAllowList.AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
 	if lh.l.Level >= logrus.TraceLevel {
-		lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
+		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 
-	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, to.Ip) {
+	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
 		return false
 	}
 
@@ -266,7 +265,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp uint32, to *Ip4AndPort) bool {
 }
 
 // unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp uint32, to *Ip6AndPort) bool {
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV6(vpnIp, to.Hi, to.Lo)
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@@ -287,25 +286,25 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 }
 
-func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
-	if _, ok := lh.lighthouses[vpnIP]; ok {
+func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+	if _, ok := lh.lighthouses[vpnIp]; ok {
 		return true
 	}
 	return false
 }
 
-func NewLhQueryByInt(VpnIp uint32) *NebulaMeta {
+func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
 	return &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
-			VpnIp: VpnIp,
+			VpnIp: uint32(VpnIp),
 		},
 	}
 }
 
 func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
 	ipp := Ip4AndPort{Port: port}
-	ipp.Ip = ip2int(ip)
+	ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
 	return &ipp
 }
 
@@ -317,19 +316,19 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
 	}
 }
 
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udpAddr {
+func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
 	ip := ipp.Ip
-	return NewUDPAddr(
+	return udp.NewAddr(
 		net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
 		uint16(ipp.Port),
 	)
 }
 
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udpAddr {
-	return NewUDPAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
+func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
+	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
-func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
+func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
 	if lh.amLighthouse || lh.interval == 0 {
 		return
 	}
@@ -349,12 +348,12 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
 	}
 }
 
-func (lh *LightHouse) SendUpdate(f EncWriter) {
+func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 
 	for _, e := range *localIps(lh.l, lh.localAllowList) {
-		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip2int(ip4)) {
+		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
 			continue
 		}
 
@@ -368,7 +367,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	m := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
-			VpnIp:       lh.myVpnIp,
+			VpnIp:       uint32(lh.myVpnIp),
 			Ip4AndPorts: v4,
 			Ip6AndPorts: v6,
 		},
@@ -385,7 +384,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
 	}
 
 	for vpnIp := range lh.lighthouses {
-		f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
+		f.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out)
 	}
 }
 
@@ -415,11 +414,11 @@ func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
 }
 
 func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) {
-	lh.metrics.Rx(NebulaMessageType(t), 0, i)
+	lh.metrics.Rx(header.MessageType(t), 0, i)
 }
 
 func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) {
-	lh.metrics.Tx(NebulaMessageType(t), 0, i)
+	lh.metrics.Tx(header.MessageType(t), 0, i)
 }
 
 // This method is similar to Reset(), but it re-uses the pointer structs
@@ -436,18 +435,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 }
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
 			Error("Failed to unmarshal lighthouse packet")
 		//TODO: send recv_error?
 		return
 	}
 
 	if n.Details == nil {
-		lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
+		lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr).
 			Error("Invalid lighthouse update")
 		//TODO: send recv_error?
 		return
@@ -471,7 +470,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr *udpAddr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -481,12 +480,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 
 	//TODO: we can DRY this further
-	reqVpnIP := n.Details.VpnIp
+	reqVpnIp := n.Details.VpnIp
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
-		n.Details.VpnIp = reqVpnIP
+		n.Details.VpnIp = reqVpnIp
 
 		lhh.coalesceAnswers(c, n)
 
@@ -498,18 +497,18 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply")
 		return
 	}
 
 	lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
-	w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
+	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 
 	// This signals the other side to punch some zero byte udp packets
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
-		n.Details.VpnIp = vpnIp
+		n.Details.VpnIp = uint32(vpnIp)
 
 		lhh.coalesceAnswers(c, n)
 
@@ -521,12 +520,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	}
 
 	if err != nil {
-		lhh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host was queried for")
+		lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for")
 		return
 	}
 
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
-	w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
+	w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -549,28 +548,29 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}
 
 	lhh.lh.Lock()
-	am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
+	am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
 	am.Lock()
 	lhh.lh.Unlock()
 
-	am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.Unlock()
 
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
-	case lhh.lh.handshakeTrigger <- n.Details.VpnIp:
+	case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
 	default:
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp uint32) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp) {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -579,9 +579,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 
 	//Simple check that the host sent this not someone else
-	if n.Details.VpnIp != vpnIp {
+	if n.Details.VpnIp != uint32(vpnIp) {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
 		}
 		return
 	}
@@ -591,18 +591,19 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Lock()
 	lhh.lh.Unlock()
 
-	am.unlockedSetV4(vpnIp, n.Details.VpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, n.Details.VpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.Unlock()
 }
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}
 
 	empty := []byte{0}
-	punch := func(vpnPeer *udpAddr) {
+	punch := func(vpnPeer *udp.Addr) {
 		if vpnPeer == nil {
 			return
 		}
@@ -615,7 +616,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 
 		if lhh.l.Level >= logrus.DebugLevel {
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, IntIp(n.Details.VpnIp))
+			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
 		}
 	}
 
@@ -634,18 +635,18 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 		go func() {
 			time.Sleep(time.Second * 5)
 			if lhh.l.Level >= logrus.DebugLevel {
-				lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
+				lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", iputil.VpnIp(n.Details.VpnIp))
 			}
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// managed by a channel.
-			w.SendMessageToVpnIp(test, testRequest, n.Details.VpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			w.SendMessageToVpnIp(header.Test, header.TestRequest, iputil.VpnIp(n.Details.VpnIp), []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}()
 	}
 }
 
 // ipMaskContains checks if testIp is contained by ip after applying a cidr
 // zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {
+func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
 	return (testIp^ip)>>zeros == 0
 }

+ 72 - 68
lighthouse_test.go

@@ -6,6 +6,10 @@ import (
 	"testing"
 
 	"github.com/golang/protobuf/proto"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -17,12 +21,12 @@ func TestOldIPv4Only(t *testing.T) {
 	var m Ip4AndPort
 	err := proto.Unmarshal(b, &m)
 	assert.NoError(t, err)
-	assert.Equal(t, "10.1.1.1", int2ip(m.GetIp()).String())
+	assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
 }
 
 func TestNewLhQuery(t *testing.T) {
 	myIp := net.ParseIP("192.1.1.1")
-	myIpint := ip2int(myIp)
+	myIpint := iputil.Ip2VpnIp(myIp)
 
 	// Generating a new lh query should work
 	a := NewLhQueryByInt(myIpint)
@@ -42,37 +46,37 @@ func TestNewLhQuery(t *testing.T) {
 }
 
 func Test_lhStaticMapping(t *testing.T) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
 
-	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
+	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
 	err := meh.ValidateLHStaticEntries()
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	lh2IP := net.ParseIP(lh2)
 
-	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
+	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP), iputil.Ip2VpnIp(lh2IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(iputil.Ip2VpnIp(lh1IP), udp.NewAddr(lh1IP, uint16(4242)))
 	err = meh.ValidateLHStaticEntries()
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
-	l := NewTestLogger()
+	l := util.NewTestLogger()
 	lh1 := "10.128.0.2"
 	lh1IP := net.ParseIP(lh1)
 
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
 
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []iputil.VpnIp{iputil.Ip2VpnIp(lh1IP)}, 10, 10003, udpServer, false, 1, false)
 
-	hAddr := NewUDPAddrFromString("4.5.6.7:12345")
-	hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
+	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
+	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
 	lh.addrMap[3] = NewRemoteList()
 	lh.addrMap[3].unlockedSetV4(
 		3,
@@ -81,11 +85,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
 			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 
-	rAddr := NewUDPAddrFromString("1.2.2.3:12345")
-	rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
+	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
+	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
 	lh.addrMap[2] = NewRemoteList()
 	lh.addrMap[2].unlockedSetV4(
 		3,
@@ -94,7 +98,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
 			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 
 	mw := &mockEncWriter{}
@@ -133,50 +137,50 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 }
 
 func TestLighthouse_Memory(t *testing.T) {
-	l := NewTestLogger()
-
-	myUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
-	myUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
-	myUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
-	myUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
-	myUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
-	myUdpAddr5 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
-	myUdpAddr6 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
-	myUdpAddr7 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
-	myUdpAddr8 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
-	myUdpAddr9 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
-	myUdpAddr10 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
-	myUdpAddr11 := &udpAddr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
-	myVpnIp := ip2int(net.ParseIP("10.128.0.2"))
-
-	theirUdpAddr0 := &udpAddr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
-	theirUdpAddr1 := &udpAddr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
-	theirUdpAddr2 := &udpAddr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
-	theirUdpAddr3 := &udpAddr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
-	theirUdpAddr4 := &udpAddr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
-	theirVpnIp := ip2int(net.ParseIP("10.128.0.3"))
-
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{}, 10, 10003, udpServer, false, 1, false)
+	l := util.NewTestLogger()
+
+	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
+	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
+	myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
+	myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
+	myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
+	myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
+	myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
+	myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
+	myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
+	myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
+	myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
+	myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
+	myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
+
+	theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
+	theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
+	theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
+	theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
+	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
+	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+
+	udpServer, _ := udp.NewListener(l, "0.0.0.0", 0, true, 2)
+	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []iputil.VpnIp{}, 10, 10003, udpServer, false, 1, false)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr2}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 
 	// Ensure we don't accumulate addresses
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr3}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 
 	// Grow it back to 2
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{myUdpAddr1, myUdpAddr4}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 	// Update a different host
-	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udpAddr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
@@ -189,7 +193,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	newLHHostUpdate(
 		myUdpAddr0,
 		myVpnIp,
-		[]*udpAddr{
+		[]*udp.Addr{
 			myUdpAddr1,
 			myUdpAddr2,
 			myUdpAddr3,
@@ -212,19 +216,19 @@ func TestLighthouse_Memory(t *testing.T) {
 	)
 
 	// Make sure we won't add ips in our vpn network
-	bad1 := &udpAddr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
-	bad2 := &udpAddr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
-	good := &udpAddr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udpAddr{bad1, bad2, good}, lhh)
+	bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
+	bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
+	good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 }
 
-func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
-			VpnIp: queryVpnIp,
+			VpnIp: uint32(queryVpnIp),
 		},
 	}
 
@@ -238,17 +242,17 @@ func newLHHostRequest(fromAddr *udpAddr, myVpnIp, queryVpnIp uint32, lhh *LightH
 	return w.lastReply
 }
 
-func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
-			VpnIp:       vpnIp,
+			VpnIp:       uint32(vpnIp),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 		},
 	}
 
 	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: ip2int(v.IP), Port: uint32(v.Port)}
+		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
 	}
 
 	b, err := req.Marshal()
@@ -327,15 +331,15 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
 //}
 
 func Test_ipMaskContains(t *testing.T) {
-	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
-	assert.False(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.1.1"))))
-	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32, ip2int(net.ParseIP("10.0.1.1"))))
+	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
+	assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
+	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
 }
 
 type testLhReply struct {
-	nebType    NebulaMessageType
-	nebSubType NebulaMessageSubType
-	vpnIp      uint32
+	nebType    header.MessageType
+	nebSubType header.MessageSubType
+	vpnIp      iputil.VpnIp
 	msg        *NebulaMeta
 }
 
@@ -343,7 +347,7 @@ type testEncWriter struct {
 	lastReply testLhReply
 }
 
-func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
 	tw.lastReply = testLhReply{
 		nebType:    t,
 		nebSubType: st,
@@ -358,17 +362,17 @@ func (tw *testEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessag
 }
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
 	assert.Len(t, have, len(want))
 	for k, w := range want {
-		if !(have[k].Ip == ip2int(w.IP) && have[k].Port == uint32(w.Port)) {
+		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
 		}
 	}
 }
 
 // assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
+func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
 	assert.Len(t, have, len(want))
 	for k, w := range want {
 		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
@@ -377,8 +381,8 @@ func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
 	}
 }
 
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
-	addrs := make([]*udpAddr, len(ips))
+func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
+	addrs := make([]*udp.Addr, len(ips))
 	for k, v := range ips {
 		addrs[k] = NewUDPAddrFromLH4(v)
 	}

+ 39 - 0
logger.go

@@ -2,8 +2,12 @@ package nebula
 
 import (
 	"errors"
+	"fmt"
+	"strings"
+	"time"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 )
 
 type ContextualError struct {
@@ -37,3 +41,38 @@ func (ce *ContextualError) Log(lr *logrus.Logger) {
 		lr.WithFields(ce.Fields).Error(ce.Context)
 	}
 }
+
+func configLogger(l *logrus.Logger, c *config.C) error {
+	// set up our logging level
+	logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
+	if err != nil {
+		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
+	}
+	l.SetLevel(logLevel)
+
+	disableTimestamp := c.GetBool("logging.disable_timestamp", false)
+	timestampFormat := c.GetString("logging.timestamp_format", "")
+	fullTimestamp := (timestampFormat != "")
+	if timestampFormat == "" {
+		timestampFormat = time.RFC3339
+	}
+
+	logFormat := strings.ToLower(c.GetString("logging.format", "text"))
+	switch logFormat {
+	case "text":
+		l.Formatter = &logrus.TextFormatter{
+			TimestampFormat:  timestampFormat,
+			FullTimestamp:    fullTimestamp,
+			DisableTimestamp: disableTimestamp,
+		}
+	case "json":
+		l.Formatter = &logrus.JSONFormatter{
+			TimestampFormat:  timestampFormat,
+			DisableTimestamp: disableTimestamp,
+		}
+	default:
+		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
+	}
+
+	return nil
+}

+ 71 - 68
main.go

@@ -8,14 +8,16 @@ import (
 	"time"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
+	"github.com/slackhq/nebula/udp"
 	"gopkg.in/yaml.v2"
 )
 
 type m map[string]interface{}
 
-func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
-
+func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	defer func() {
@@ -31,7 +33,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	// Print the config if in test, the exit comes later
 	if configTest {
-		b, err := yaml.Marshal(config.Settings)
+		b, err := yaml.Marshal(c.Settings)
 		if err != nil {
 			return nil, err
 		}
@@ -40,33 +42,33 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		l.Println(string(b))
 	}
 
-	err := configLogger(config)
+	err := configLogger(l, c)
 	if err != nil {
 		return nil, NewContextualError("Failed to configure the logger", nil, err)
 	}
 
-	config.RegisterReloadCallback(func(c *Config) {
-		err := configLogger(c)
+	c.RegisterReloadCallback(func(c *config.C) {
+		err := configLogger(l, c)
 		if err != nil {
 			l.WithError(err).Error("Failed to configure the logger")
 		}
 	})
 
-	caPool, err := loadCAFromConfig(l, config)
+	caPool, err := loadCAFromConfig(l, c)
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
 	}
 	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
-	cs, err := NewCertStateFromConfig(config)
+	cs, err := NewCertStateFromConfig(c)
 	if err != nil {
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
 		return nil, NewContextualError("Failed to load certificate from config", nil, err)
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
-	fw, err := NewFirewallFromConfig(l, cs.certificate, config)
+	fw, err := NewFirewallFromConfig(l, cs.certificate, c)
 	if err != nil {
 		return nil, NewContextualError("Error while loading firewall rules", nil, err)
 	}
@@ -74,20 +76,20 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	// TODO: make sure mask is 4 bytes
 	tunCidr := cs.certificate.Details.Ips[0]
-	routes, err := parseRoutes(config, tunCidr)
+	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
 		return nil, NewContextualError("Could not parse tun.routes", nil, err)
 	}
-	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
 	if err != nil {
 		return nil, NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
-	wireSSHReload(l, ssh, config)
+	wireSSHReload(l, ssh, c)
 	var sshStart func()
-	if config.GetBool("sshd.enabled", false) {
-		sshStart, err = configSSH(l, ssh, config)
+	if c.GetBool("sshd.enabled", false) {
+		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
 			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 		}
@@ -101,7 +103,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	var routines int
 
 	// If `routines` is set, use that and ignore the specific values
-	if routines = config.GetInt("routines", 0); routines != 0 {
+	if routines = c.GetInt("routines", 0); routines != 0 {
 		if routines < 1 {
 			routines = 1
 		}
@@ -110,8 +112,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		}
 	} else {
 		// deprecated and undocumented
-		tunQueues := config.GetInt("tun.routines", 1)
-		udpQueues := config.GetInt("listen.routines", 1)
+		tunQueues := c.GetInt("tun.routines", 1)
+		udpQueues := c.GetInt("listen.routines", 1)
 		if tunQueues > udpQueues {
 			routines = tunQueues
 		} else {
@@ -125,8 +127,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	// EXPERIMENTAL
 	// Intentionally not documented yet while we do more testing and determine
 	// a good default value.
-	conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
-	if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
+	conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
+	if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
 		// Use a different default if we are running with multiple routines
 		conntrackCacheTimeout = 1 * time.Second
 	}
@@ -136,30 +138,30 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	var tun Inside
 	if !configTest {
-		config.CatchHUP(ctx)
+		c.CatchHUP(ctx)
 
 		switch {
-		case config.GetBool("tun.disabled", false):
-			tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
+		case c.GetBool("tun.disabled", false):
+			tun = newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		case tunFd != nil:
 			tun, err = newTunFromFd(
 				l,
 				*tunFd,
 				tunCidr,
-				config.GetInt("tun.mtu", DEFAULT_MTU),
+				c.GetInt("tun.mtu", DEFAULT_MTU),
 				routes,
 				unsafeRoutes,
-				config.GetInt("tun.tx_queue", 500),
+				c.GetInt("tun.tx_queue", 500),
 			)
 		default:
 			tun, err = newTun(
 				l,
-				config.GetString("tun.dev", ""),
+				c.GetString("tun.dev", ""),
 				tunCidr,
-				config.GetInt("tun.mtu", DEFAULT_MTU),
+				c.GetInt("tun.mtu", DEFAULT_MTU),
 				routes,
 				unsafeRoutes,
-				config.GetInt("tun.tx_queue", 500),
+				c.GetInt("tun.tx_queue", 500),
 				routines > 1,
 			)
 		}
@@ -176,16 +178,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}()
 
 	// set up our UDP listener
-	udpConns := make([]*udpConn, routines)
-	port := config.GetInt("listen.port", 0)
+	udpConns := make([]*udp.Conn, routines)
+	port := c.GetInt("listen.port", 0)
 
 	if !configTest {
 		for i := 0; i < routines; i++ {
-			udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
+			udpServer, err := udp.NewListener(l, c.GetString("listen.host", "0.0.0.0"), port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
-			udpServer.reloadConfig(config)
+			udpServer.ReloadConfig(c)
 			udpConns[i] = udpServer
 
 			// If port is dynamic, discover it
@@ -201,7 +203,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	// Set up my internal host map
 	var preferredRanges []*net.IPNet
-	rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
+	rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 	// First, check if 'preferred_ranges' is set and fallback to 'local_range'
 	if len(rawPreferredRanges) > 0 {
 		for _, rawPreferredRange := range rawPreferredRanges {
@@ -216,7 +218,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	// local_range was superseded by preferred_ranges. If it is still present,
 	// merge the local_range setting into preferred_ranges. We will probably
 	// deprecate local_range and remove in the future.
-	rawLocalRange := config.GetString("local_range", "")
+	rawLocalRange := c.GetString("local_range", "")
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
@@ -240,7 +242,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
 
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
-	hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
+	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 
@@ -249,26 +251,26 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		go hostMap.Promoter(config.GetInt("promoter.interval"))
 	*/
 
-	punchy := NewPunchyFromConfig(config)
+	punchy := NewPunchyFromConfig(c)
 	if punchy.Punch && !configTest {
 		l.Info("UDP hole punching enabled")
 		go hostMap.Punchy(ctx, udpConns[0])
 	}
 
-	amLighthouse := config.GetBool("lighthouse.am_lighthouse", false)
+	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 
 	// fatal if am_lighthouse is enabled but we are using an ephemeral port
-	if amLighthouse && (config.GetInt("listen.port", 0) == 0) {
+	if amLighthouse && (c.GetInt("listen.port", 0) == 0) {
 		return nil, NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil)
 	}
 
 	// warn if am_lighthouse is enabled but upstream lighthouses exists
-	rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{})
+	rawLighthouseHosts := c.GetStringSlice("lighthouse.hosts", []string{})
 	if amLighthouse && len(rawLighthouseHosts) != 0 {
 		l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
 
-	lighthouseHosts := make([]uint32, len(rawLighthouseHosts))
+	lighthouseHosts := make([]iputil.VpnIp, len(rawLighthouseHosts))
 	for i, host := range rawLighthouseHosts {
 		ip := net.ParseIP(host)
 		if ip == nil {
@@ -277,7 +279,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		if !tunCidr.Contains(ip) {
 			return nil, NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 		}
-		lighthouseHosts[i] = ip2int(ip)
+		lighthouseHosts[i] = iputil.Ip2VpnIp(ip)
 	}
 
 	lightHouse := NewLightHouse(
@@ -286,47 +288,48 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		tunCidr,
 		lighthouseHosts,
 		//TODO: change to a duration
-		config.GetInt("lighthouse.interval", 10),
+		c.GetInt("lighthouse.interval", 10),
 		uint32(port),
 		udpConns[0],
 		punchy.Respond,
 		punchy.Delay,
-		config.GetBool("stats.lighthouse_metrics", false),
+		c.GetBool("stats.lighthouse_metrics", false),
 	)
 
-	remoteAllowList, err := config.GetRemoteAllowList("lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
+	remoteAllowList, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges")
 	if err != nil {
 		return nil, NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 	}
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 
-	localAllowList, err := config.GetLocalAllowList("lighthouse.local_allow_list")
+	localAllowList, err := NewLocalAllowListFromConfig(c, "lighthouse.local_allow_list")
 	if err != nil {
 		return nil, NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 	}
 	lightHouse.SetLocalAllowList(localAllowList)
 
 	//TODO: Move all of this inside functions in lighthouse.go
-	for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
-		vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
-		if !tunCidr.Contains(vpnIp) {
+	for k, v := range c.GetMap("static_host_map", map[interface{}]interface{}{}) {
+		ip := net.ParseIP(fmt.Sprintf("%v", k))
+		vpnIp := iputil.Ip2VpnIp(ip)
+		if !tunCidr.Contains(ip) {
 			return nil, NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 		}
 		vals, ok := v.([]interface{})
 		if ok {
 			for _, v := range vals {
-				ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
+				ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 				if err != nil {
 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
-				lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
+				lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 			}
 		} else {
-			ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
+			ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
 			if err != nil {
 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 			}
-			lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
+			lightHouse.AddStaticRemote(vpnIp, udp.NewAddr(ip, port))
 		}
 	}
 
@@ -336,16 +339,16 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	var messageMetrics *MessageMetrics
-	if config.GetBool("stats.message_metrics", false) {
+	if c.GetBool("stats.message_metrics", false) {
 		messageMetrics = newMessageMetrics()
 	} else {
 		messageMetrics = newMessageMetricsOnlyRecvError()
 	}
 
 	handshakeConfig := HandshakeConfig{
-		tryInterval:   config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
-		retries:       config.GetInt("handshakes.retries", DefaultHandshakeRetries),
-		triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
+		tryInterval:   c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
+		retries:       c.GetInt("handshakes.retries", DefaultHandshakeRetries),
+		triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
 
 		messageMetrics: messageMetrics,
 	}
@@ -358,36 +361,35 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
 
 	serveDns := false
-	if config.GetBool("lighthouse.serve_dns", false) {
-		if config.GetBool("lighthouse.am_lighthouse", false) {
+	if c.GetBool("lighthouse.serve_dns", false) {
+		if c.GetBool("lighthouse.am_lighthouse", false) {
 			serveDns = true
 		} else {
 			l.Warn("DNS server refusing to run because this host is not a lighthouse.")
 		}
 	}
 
-	checkInterval := config.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := config.GetInt("timers.pending_deletion_interval", 10)
+	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
+	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
 	ifConfig := &InterfaceConfig{
 		HostMap:                 hostMap,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		certState:               cs,
-		Cipher:                  config.GetString("cipher", "aes"),
+		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		lightHouse:              lightHouse,
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
-		DropLocalBroadcast:      config.GetBool("tun.drop_local_broadcast", false),
-		DropMulticast:           config.GetBool("tun.drop_multicast", false),
-		UDPBatchSize:            config.GetInt("listen.batch", 64),
+		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
+		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
 		caPool:                  caPool,
-		disconnectInvalid:       config.GetBool("pki.disconnect_invalid", false),
+		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
@@ -413,7 +415,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 
-		ifce.RegisterConfigChangeCallbacks(config)
+		ifce.RegisterConfigChangeCallbacks(c)
 
 		go handshakeManager.Run(ctx, ifce)
 		go lightHouse.LhUpdateWorker(ctx, ifce)
@@ -421,7 +423,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
 	// a context so that they can exit when the context is Done.
-	statsStart, err := startStats(l, config, buildVersion, configTest)
+	statsStart, err := startStats(l, c, buildVersion, configTest)
+
 	if err != nil {
 		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 	}
@@ -431,7 +434,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	//TODO: check if we _should_ be emitting stats
-	go ifce.emitStats(ctx, config.GetDuration("stats.interval", time.Second*10))
+	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 	attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
 
@@ -439,7 +442,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	var dnsStart func()
 	if amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, config)
+		dnsStart = dnsMain(l, hostMap, c)
 	}
 
 	return &Control{ifce, l, cancel, sshStart, statsStart, dnsStart}, nil

+ 5 - 2
message_metrics.go

@@ -4,8 +4,11 @@ import (
 	"fmt"
 
 	"github.com/rcrowley/go-metrics"
+	"github.com/slackhq/nebula/header"
 )
 
+//TODO: this can probably move into the header package
+
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter
@@ -14,7 +17,7 @@ type MessageMetrics struct {
 	txUnknown metrics.Counter
 }
 
-func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
+func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) {
 	if m != nil {
 		if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) {
 			m.rx[t][s].Inc(i)
@@ -23,7 +26,7 @@ func (m *MessageMetrics) Rx(t NebulaMessageType, s NebulaMessageSubType, i int64
 		}
 	}
 }
-func (m *MessageMetrics) Tx(t NebulaMessageType, s NebulaMessageSubType, i int64) {
+func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) {
 	if m != nil {
 		if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) {
 			m.tx[t][s].Inc(i)

+ 61 - 57
outside.go

@@ -10,6 +10,10 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 )
 
@@ -17,8 +21,8 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
-	err := header.Parse(packet)
+func (f *Interface) readOutsidePackets(addr *udp.Addr, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+	err := h.Parse(packet)
 	if err != nil {
 		// TODO: best if we return this and let caller log
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
@@ -32,30 +36,30 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 	//l.Error("in packet ", header, packet[HeaderLen:])
 
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
-	hostinfo, err := f.hostMap.QueryIndex(header.RemoteIndex)
+	hostinfo, err := f.hostMap.QueryIndex(h.RemoteIndex)
 
 	var ci *ConnectionState
 	if err == nil {
 		ci = hostinfo.ConnectionState
 	}
 
-	switch header.Type {
-	case message:
-		if !f.handleEncrypted(ci, addr, header) {
+	switch h.Type {
+	case header.Message:
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 		}
 
-		f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
+		f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
 
 		// Fallthrough to the bottom to record incoming traffic
 
-	case lightHouse:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.LightHouse:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 		}
 
-		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
@@ -66,17 +70,17 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 		}
 
-		lhh.HandleRequest(addr, hostinfo.hostId, d, f)
+		lhf(addr, hostinfo.vpnIp, d, f)
 
 		// Fallthrough to the bottom to record incoming traffic
 
-	case test:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.Test:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 		}
 
-		d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
 				WithField("packet", packet).
@@ -87,11 +91,11 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 		}
 
-		if header.Subtype == testRequest {
+		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
 			f.handleHostRoaming(hostinfo, addr)
-			f.send(test, testReply, ci, hostinfo, hostinfo.remote, d, nb, out)
+			f.send(header.Test, header.TestReply, ci, hostinfo, hostinfo.remote, d, nb, out)
 		}
 
 		// Fallthrough to the bottom to record incoming traffic
@@ -99,19 +103,19 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
 		// are unauthenticated
 
-	case handshake:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		HandleIncomingHandshake(f, addr, packet, header, hostinfo)
+	case header.Handshake:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		HandleIncomingHandshake(f, addr, packet, h, hostinfo)
 		return
 
-	case recvError:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		f.handleRecvError(addr, header)
+	case header.RecvError:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		f.handleRecvError(addr, h)
 		return
 
-	case closeTunnel:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, header) {
+	case header.CloseTunnel:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, addr, h) {
 			return
 		}
 
@@ -122,22 +126,22 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 		return
 
 	default:
-		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
 		return
 	}
 
 	f.handleHostRoaming(hostinfo, addr)
 
-	f.connectionManager.In(hostinfo.hostId)
+	f.connectionManager.In(hostinfo.vpnIp)
 }
 
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
-	f.connectionManager.ClearIP(hostInfo.hostId)
-	f.connectionManager.ClearPendingDeletion(hostInfo.hostId)
-	f.lightHouse.DeleteVpnIP(hostInfo.hostId)
+	f.connectionManager.ClearIP(hostInfo.vpnIp)
+	f.connectionManager.ClearPendingDeletion(hostInfo.vpnIp)
+	f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
 
 	if hasHostMapLock {
 		f.hostMap.unlockedDeleteHostInfo(hostInfo)
@@ -148,12 +152,12 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo, hasHostMapLock bool) {
 
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
 func (f *Interface) sendCloseTunnel(h *HostInfo) {
-	f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+	f.send(header.CloseTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
-	if hostDidRoam(hostinfo.remote, addr) {
-		if !f.lightHouse.remoteAllowList.Allow(hostinfo.hostId, addr.IP) {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
+	if !hostinfo.remote.Equals(addr) {
+		if !f.lightHouse.remoteAllowList.Allow(hostinfo.vpnIp, addr.IP) {
 			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
@@ -175,11 +179,11 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 
 }
 
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
 	// If connectionstate exists and the replay protector allows, process packet
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
-		f.sendRecvError(addr, header.RemoteIndex)
+	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
+		f.sendRecvError(addr, h.RemoteIndex)
 		return false
 	}
 
@@ -187,7 +191,7 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *
 }
 
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
-func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
+func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 	// Do we at least have an ipv4 header worth of data?
 	if len(data) < ipv4.HeaderLen {
 		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
@@ -215,7 +219,7 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 
 	// Accounting for a variable header length, do we have enough data for our src/dst tuples?
 	minLen := ihl
-	if !fp.Fragment && fp.Protocol != fwProtoICMP {
+	if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
 		minLen += minFwPacketLen
 	}
 	if len(data) < minLen {
@@ -224,9 +228,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 
 	// Firewall packets are locally oriented
 	if incoming {
-		fp.RemoteIP = binary.BigEndian.Uint32(data[12:16])
-		fp.LocalIP = binary.BigEndian.Uint32(data[16:20])
-		if fp.Fragment || fp.Protocol == fwProtoICMP {
+		fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
+		fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 		} else {
@@ -234,9 +238,9 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 	} else {
-		fp.LocalIP = binary.BigEndian.Uint32(data[12:16])
-		fp.RemoteIP = binary.BigEndian.Uint32(data[16:20])
-		if fp.Fragment || fp.Protocol == fwProtoICMP {
+		fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
+		fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 		} else {
@@ -248,15 +252,15 @@ func newPacket(data []byte, incoming bool, fp *FirewallPacket) error {
 	return nil
 }
 
-func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, header *Header, nb []byte) ([]byte, error) {
+func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
 	var err error
-	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], mc, nb)
+	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
 	if err != nil {
 		return nil, err
 	}
 
 	if !hostinfo.ConnectionState.window.Update(f.l, mc) {
-		hostinfo.logger(f.l).WithField("header", header).
+		hostinfo.logger(f.l).WithField("header", h).
 			Debugln("dropping out of window packet")
 		return nil, errors.New("out of window packet")
 	}
@@ -264,10 +268,10 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
 	var err error
 
-	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
+	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
@@ -298,18 +302,18 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return
 	}
 
-	f.connectionManager.In(hostinfo.hostId)
+	f.connectionManager.In(hostinfo.vpnIp)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
 }
 
-func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
-	f.messageMetrics.Tx(recvError, 0, 1)
+func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 	//TODO: this should be a signed message so we can trust that we should drop the index
-	b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
+	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
 	f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
@@ -318,7 +322,7 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
 	}
 }
 
-func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
+func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", h.RemoteIndex).
 			WithField("udpAddr", addr).

+ 9 - 7
outside_test.go

@@ -4,12 +4,14 @@ import (
 	"net"
 	"testing"
 
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 )
 
 func Test_newPacket(t *testing.T) {
-	p := &FirewallPacket{}
+	p := &firewall.Packet{}
 
 	// length fail
 	err := newPacket([]byte{0, 1}, true, p)
@@ -44,7 +46,7 @@ func Test_newPacket(t *testing.T) {
 		Src:      net.IPv4(10, 0, 0, 1),
 		Dst:      net.IPv4(10, 0, 0, 2),
 		Options:  []byte{0, 1, 0, 2},
-		Protocol: fwProtoTCP,
+		Protocol: firewall.ProtoTCP,
 	}
 
 	b, _ = h.Marshal()
@@ -52,9 +54,9 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(fwProtoTCP))
-	assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 2)))
-	assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
+	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.LocalPort, uint16(4))
 
@@ -74,8 +76,8 @@ func Test_newPacket(t *testing.T) {
 
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, ip2int(net.IPv4(10, 0, 0, 1)))
-	assert.Equal(t, p.RemoteIP, ip2int(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.LocalPort, uint16(5))
 }

+ 6 - 2
punchy.go

@@ -1,6 +1,10 @@
 package nebula
 
-import "time"
+import (
+	"time"
+
+	"github.com/slackhq/nebula/config"
+)
 
 type Punchy struct {
 	Punch   bool
@@ -8,7 +12,7 @@ type Punchy struct {
 	Delay   time.Duration
 }
 
-func NewPunchyFromConfig(c *Config) *Punchy {
+func NewPunchyFromConfig(c *config.C) *Punchy {
 	p := &Punchy{}
 
 	if c.IsSet("punchy.punch") {

+ 4 - 2
punchy_test.go

@@ -4,12 +4,14 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestNewPunchyFromConfig(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 
 	// Test defaults
 	p := NewPunchyFromConfig(c)

+ 31 - 28
remote_list.go

@@ -5,14 +5,17 @@ import (
 	"net"
 	"sort"
 	"sync"
+
+	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/udp"
 )
 
 // forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udpAddr, preferred bool)
+type forEachFunc func(addr *udp.Addr, preferred bool)
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp uint32, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp uint32, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
@@ -21,8 +24,8 @@ type CacheMap map[string]*Cache
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // We don't reason about ipv4 vs ipv6 here
 type Cache struct {
-	Learned  []*udpAddr `json:"learned,omitempty"`
-	Reported []*udpAddr `json:"reported,omitempty"`
+	Learned  []*udp.Addr `json:"learned,omitempty"`
+	Reported []*udp.Addr `json:"reported,omitempty"`
 }
 
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -53,16 +56,16 @@ type RemoteList struct {
 	sync.RWMutex
 
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
-	addrs []*udpAddr
+	addrs []*udp.Addr
 
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// For learned addresses, this is the vpnIp that sent the packet
-	cache map[uint32]*cache
+	cache map[iputil.VpnIp]*cache
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
-	badRemotes []*udpAddr
+	badRemotes []*udp.Addr
 
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	shouldRebuild bool
@@ -71,8 +74,8 @@ type RemoteList struct {
 // NewRemoteList creates a new empty RemoteList
 func NewRemoteList() *RemoteList {
 	return &RemoteList{
-		addrs: make([]*udpAddr, 0),
-		cache: make(map[uint32]*cache),
+		addrs: make([]*udp.Addr, 0),
+		cache: make(map[iputil.VpnIp]*cache),
 	}
 }
 
@@ -98,7 +101,7 @@ func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc)
 
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
+func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 	if r == nil {
 		return nil
 	}
@@ -107,7 +110,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
 
 	r.RLock()
 	defer r.RUnlock()
-	c := make([]*udpAddr, len(r.addrs))
+	c := make([]*udp.Addr, len(r.addrs))
 	for i, v := range r.addrs {
 		c[i] = v.Copy()
 	}
@@ -118,7 +121,7 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 //TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
 	r.Lock()
 	defer r.Unlock()
 	if v4 := addr.IP.To4(); v4 != nil {
@@ -139,8 +142,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
 		c := cm[vpnIp]
 		if c == nil {
 			c = &Cache{
-				Learned:  make([]*udpAddr, 0),
-				Reported: make([]*udpAddr, 0),
+				Learned:  make([]*udp.Addr, 0),
+				Reported: make([]*udp.Addr, 0),
 			}
 			cm[vpnIp] = c
 		}
@@ -148,7 +151,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
 	}
 
 	for owner, mc := range r.cache {
-		c := getOrMake(IntIp(owner).String())
+		c := getOrMake(owner.String())
 
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
@@ -175,7 +178,7 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udpAddr) {
+func (r *RemoteList) BlockRemote(bad *udp.Addr) {
 	r.Lock()
 	defer r.Unlock()
 
@@ -192,11 +195,11 @@ func (r *RemoteList) BlockRemote(bad *udpAddr) {
 }
 
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
+func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
 	r.RLock()
 	defer r.RUnlock()
 
-	c := make([]*udpAddr, len(r.badRemotes))
+	c := make([]*udp.Addr, len(r.badRemotes))
 	for i, v := range r.badRemotes {
 		c[i] = v.Copy()
 	}
@@ -228,7 +231,7 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
 }
 
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
+func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
 	for _, v := range r.badRemotes {
 		if v.Equals(remote) {
 			return true
@@ -239,14 +242,14 @@ func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
@@ -263,7 +266,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, vpnIp uint32, to []*Ip4And
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
@@ -276,14 +279,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
@@ -300,7 +303,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, vpnIp uint32, to []*Ip6And
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
@@ -313,7 +316,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
 
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 		am = &cache{}
@@ -328,7 +331,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
 
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 		am = &cache{}

+ 33 - 32
remote_list_test.go

@@ -4,6 +4,7 @@ import (
 	"net"
 	"testing"
 
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -13,18 +14,18 @@ func TestRemoteList_Rebuild(t *testing.T) {
 		0,
 		0,
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
@@ -37,7 +38,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 
 	rl.Rebuild([]*net.IPNet{})
@@ -106,16 +107,16 @@ func BenchmarkFullRebuild(b *testing.B) {
 		0,
 		0,
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
@@ -127,7 +128,7 @@ func BenchmarkFullRebuild(b *testing.B) {
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {
@@ -171,16 +172,16 @@ func BenchmarkSortRebuild(b *testing.B) {
 		0,
 		0,
 		[]*Ip4AndPort{
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
-			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
-			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
-			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
+			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
 		},
-		func(uint32, *Ip4AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
@@ -192,7 +193,7 @@ func BenchmarkSortRebuild(b *testing.B) {
 			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
 			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
 		},
-		func(uint32, *Ip6AndPort) bool { return true },
+		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {

+ 30 - 26
ssh.go

@@ -15,7 +15,11 @@ import (
 	"syscall"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
+	"github.com/slackhq/nebula/udp"
 )
 
 type sshListHostMapFlags struct {
@@ -45,8 +49,8 @@ type sshCreateTunnelFlags struct {
 	Address string
 }
 
-func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
-	c.RegisterReloadCallback(func(c *Config) {
+func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
+	c.RegisterReloadCallback(func(c *config.C) {
 		if c.GetBool("sshd.enabled", false) {
 			sshRun, err := configSSH(l, ssh, c)
 			if err != nil {
@@ -66,7 +70,7 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
 // updates the passed-in SSHServer. On success, it returns a function
 // that callers may invoke to run the configured ssh server. On
 // failure, it returns nil, error.
-func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
+func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
 	//TODO conntrack list
 	//TODO print firewall rules or hash?
 
@@ -351,7 +355,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 	hm := listHostMap(hostMap)
 	sort.Slice(hm, func(i, j int) bool {
-		return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
+		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -368,7 +372,7 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 	} else {
 		for _, v := range hm {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
 			if err != nil {
 				return err
 			}
@@ -386,7 +390,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	}
 
 	type lighthouseInfo struct {
-		VpnIP net.IP    `json:"vpnIp"`
+		VpnIp string    `json:"vpnIp"`
 		Addrs *CacheMap `json:"addrs"`
 	}
 
@@ -395,7 +399,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	x := 0
 	for k, v := range lightHouse.addrMap {
 		addrMap[x] = lighthouseInfo{
-			VpnIP: int2ip(k),
+			VpnIp: k.String(),
 			Addrs: v.CopyCache(),
 		}
 		x++
@@ -403,7 +407,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	lightHouse.RUnlock()
 
 	sort.Slice(addrMap, func(i, j int) bool {
-		return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
+		return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -424,7 +428,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			if err != nil {
 				return err
 			}
-			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
 			if err != nil {
 				return err
 			}
@@ -470,7 +474,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
@@ -499,19 +503,19 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 
 	if !flags.LocalOnly {
 		ifce.send(
-			closeTunnel,
+			header.CloseTunnel,
 			0,
 			hostInfo.ConnectionState,
 			hostInfo,
@@ -542,30 +546,30 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, _ := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 
-	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
 
-	var addr *udpAddr
+	var addr *udp.Addr
 	if flags.Address != "" {
-		addr = NewUDPAddrFromString(flags.Address)
+		addr = udp.NewAddrFromString(flags.Address)
 		if addr == nil {
 			return w.WriteLine("Address could not be parsed")
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.AddVpnIP(vpnIp)
+	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 	}
@@ -589,7 +593,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No address was provided")
 	}
 
-	addr := NewUDPAddrFromString(flags.Address)
+	addr := udp.NewAddrFromString(flags.Address)
 	if addr == nil {
 		return w.WriteLine("Address could not be parsed")
 	}
@@ -599,12 +603,12 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
@@ -680,12 +684,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
-		vpnIp := ip2int(parsedIp)
+		vpnIp := iputil.Ip2VpnIp(parsedIp)
 		if vpnIp == 0 {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
-		hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+		hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		}
@@ -742,12 +746,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := ip2int(parsedIp)
+	vpnIp := iputil.Ip2VpnIp(parsedIp)
 	if vpnIp == 0 {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
+	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}

+ 4 - 3
stats.go

@@ -15,12 +15,13 @@ import (
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
 )
 
 // startStats initializes stats from config. On success, if any futher work
 // is needed to serve stats, it returns a func to handle that work. If no
 // work is needed, it'll return nil. On failure, it returns nil, error.
-func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
+func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) {
 	mType := c.GetString("stats.type", "")
 	if mType == "" || mType == "none" {
 		return nil, nil
@@ -57,7 +58,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
 	return startFn, nil
 }
 
-func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
+func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error {
 	proto := c.GetString("stats.protocol", "tcp")
 	host := c.GetString("stats.host", "")
 	if host == "" {
@@ -77,7 +78,7 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
 	return nil
 }
 
-func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
+func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) {
 	namespace := c.GetString("stats.namespace", "")
 	subsystem := c.GetString("stats.subsystem", "")
 

+ 7 - 5
timeout.go

@@ -2,12 +2,14 @@ package nebula
 
 import (
 	"time"
+
+	"github.com/slackhq/nebula/firewall"
 )
 
 // How many timer objects should be cached
 const timerCacheMax = 50000
 
-var emptyFWPacket = FirewallPacket{}
+var emptyFWPacket = firewall.Packet{}
 
 type TimerWheel struct {
 	// Current tick
@@ -42,7 +44,7 @@ type TimeoutList struct {
 
 // Represents an item within a tick
 type TimeoutItem struct {
-	Packet FirewallPacket
+	Packet firewall.Packet
 	Next   *TimeoutItem
 }
 
@@ -73,8 +75,8 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
 	return &tw
 }
 
-// Add will add a FirewallPacket to the wheel in it's proper timeout
-func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem {
+// Add will add a firewall.Packet to the wheel in it's proper timeout
+func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
 	// Check and see if we should progress the tick
 	tw.advance(time.Now())
 
@@ -103,7 +105,7 @@ func (tw *TimerWheel) Add(v FirewallPacket, timeout time.Duration) *TimeoutItem
 	return ti
 }
 
-func (tw *TimerWheel) Purge() (FirewallPacket, bool) {
+func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
 	if tw.expired.Head == nil {
 		return emptyFWPacket, false
 	}

+ 4 - 2
timeout_system.go

@@ -3,6 +3,8 @@ package nebula
 import (
 	"sync"
 	"time"
+
+	"github.com/slackhq/nebula/iputil"
 )
 
 // How many timer objects should be cached
@@ -43,7 +45,7 @@ type SystemTimeoutList struct {
 
 // Represents an item within a tick
 type SystemTimeoutItem struct {
-	Item uint32
+	Item iputil.VpnIp
 	Next *SystemTimeoutItem
 }
 
@@ -74,7 +76,7 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
 	return &tw
 }
 
-func (tw *SystemTimerWheel) Add(v uint32, timeout time.Duration) *SystemTimeoutItem {
+func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
 	tw.lock.Lock()
 	defer tw.lock.Unlock()
 

+ 4 - 3
timeout_system_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -51,7 +52,7 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
 func TestSystemTimerWheel_Add(t *testing.T) {
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
 
-	fp1 := ip2int(net.ParseIP("1.2.3.4"))
+	fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
 	tw.Add(fp1, time.Second*1)
 
 	// Make sure we set head and tail properly
@@ -62,7 +63,7 @@ func TestSystemTimerWheel_Add(t *testing.T) {
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 	// Make sure we only modify head
-	fp2 := ip2int(net.ParseIP("1.2.3.4"))
+	fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
 	tw.Add(fp2, time.Second*1)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
@@ -85,7 +86,7 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
 	assert.NotNil(t, tw.lastTick)
 	assert.Equal(t, 0, tw.current)
 
-	fps := []uint32{9, 10, 11, 12}
+	fps := []iputil.VpnIp{9, 10, 11, 12}
 
 	//fp1 := ip2int(net.ParseIP("1.2.3.4"))
 

+ 4 - 3
timeout_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -50,7 +51,7 @@ func TestTimerWheel_findWheel(t *testing.T) {
 func TestTimerWheel_Add(t *testing.T) {
 	tw := NewTimerWheel(time.Second, time.Second*10)
 
-	fp1 := FirewallPacket{}
+	fp1 := firewall.Packet{}
 	tw.Add(fp1, time.Second*1)
 
 	// Make sure we set head and tail properly
@@ -61,7 +62,7 @@ func TestTimerWheel_Add(t *testing.T) {
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 	// Make sure we only modify head
-	fp2 := FirewallPacket{}
+	fp2 := firewall.Packet{}
 	tw.Add(fp2, time.Second*1)
 	assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
 	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
@@ -84,7 +85,7 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.NotNil(t, tw.lastTick)
 	assert.Equal(t, 0, tw.current)
 
-	fps := []FirewallPacket{
+	fps := []firewall.Packet{
 		{LocalIP: 1},
 		{LocalIP: 2},
 		{LocalIP: 3},

+ 7 - 5
tun_common.go

@@ -4,6 +4,8 @@ import (
 	"fmt"
 	"net"
 	"strconv"
+
+	"github.com/slackhq/nebula/config"
 )
 
 const DEFAULT_MTU = 1300
@@ -14,10 +16,10 @@ type route struct {
 	via   *net.IP
 }
 
-func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
+func parseRoutes(c *config.C, network *net.IPNet) ([]route, error) {
 	var err error
 
-	r := config.Get("tun.routes")
+	r := c.Get("tun.routes")
 	if r == nil {
 		return []route{}, nil
 	}
@@ -84,10 +86,10 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
 	return routes, nil
 }
 
-func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
+func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]route, error) {
 	var err error
 
-	r := config.Get("tun.unsafe_routes")
+	r := c.Get("tun.unsafe_routes")
 	if r == nil {
 		return []route{}, nil
 	}
@@ -110,7 +112,7 @@ func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
 
 		rMtu, ok := m["mtu"]
 		if !ok {
-			rMtu = config.GetInt("tun.mtu", DEFAULT_MTU)
+			rMtu = c.GetInt("tun.mtu", DEFAULT_MTU)
 		}
 
 		mtu, ok := rMtu.(int)

+ 6 - 4
tun_test.go

@@ -5,12 +5,14 @@ import (
 	"net"
 	"testing"
 
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 	"github.com/stretchr/testify/assert"
 )
 
 func Test_parseRoutes(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 	// test no routes config
@@ -105,8 +107,8 @@ func Test_parseRoutes(t *testing.T) {
 }
 
 func Test_parseUnsafeRoutes(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
+	l := util.NewTestLogger()
+	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 	// test no routes config

+ 20 - 0
udp/conn.go

@@ -0,0 +1,20 @@
+package udp
+
+import (
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+)
+
+const MTU = 9001
+
+type EncReader func(
+	addr *Addr,
+	out []byte,
+	packet []byte,
+	header *header.H,
+	fwPacket *firewall.Packet,
+	lhh LightHouseHandlerFunc,
+	nb []byte,
+	q int,
+	localCache firewall.ConntrackCache,
+)

+ 14 - 0
udp/temp.go

@@ -0,0 +1,14 @@
+package udp
+
+import (
+	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
+)
+
+//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
+
+type EncWriter interface {
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+}
+
+type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)

+ 15 - 13
udp_all.go → udp/udp_all.go

@@ -1,4 +1,4 @@
-package nebula
+package udp
 
 import (
 	"encoding/json"
@@ -7,32 +7,34 @@ import (
 	"strconv"
 )
 
-type udpAddr struct {
+type m map[string]interface{}
+
+type Addr struct {
 	IP   net.IP
 	Port uint16
 }
 
-func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
-	addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
+func NewAddr(ip net.IP, port uint16) *Addr {
+	addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
 	copy(addr.IP, ip.To16())
 	return &addr
 }
 
-func NewUDPAddrFromString(s string) *udpAddr {
-	ip, port, err := parseIPAndPort(s)
+func NewAddrFromString(s string) *Addr {
+	ip, port, err := ParseIPAndPort(s)
 	//TODO: handle err
 	_ = err
-	return &udpAddr{IP: ip.To16(), Port: port}
+	return &Addr{IP: ip.To16(), Port: port}
 }
 
-func (ua *udpAddr) Equals(t *udpAddr) bool {
+func (ua *Addr) Equals(t *Addr) bool {
 	if t == nil || ua == nil {
 		return t == nil && ua == nil
 	}
 	return ua.IP.Equal(t.IP) && ua.Port == t.Port
 }
 
-func (ua *udpAddr) String() string {
+func (ua *Addr) String() string {
 	if ua == nil {
 		return "<nil>"
 	}
@@ -40,7 +42,7 @@ func (ua *udpAddr) String() string {
 	return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
 }
 
-func (ua *udpAddr) MarshalJSON() ([]byte, error) {
+func (ua *Addr) MarshalJSON() ([]byte, error) {
 	if ua == nil {
 		return nil, nil
 	}
@@ -48,12 +50,12 @@ func (ua *udpAddr) MarshalJSON() ([]byte, error) {
 	return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
 }
 
-func (ua *udpAddr) Copy() *udpAddr {
+func (ua *Addr) Copy() *Addr {
 	if ua == nil {
 		return nil
 	}
 
-	nu := udpAddr{
+	nu := Addr{
 		Port: ua.Port,
 		IP:   make(net.IP, len(ua.IP)),
 	}
@@ -62,7 +64,7 @@ func (ua *udpAddr) Copy() *udpAddr {
 	return &nu
 }
 
-func parseIPAndPort(s string) (net.IP, uint16, error) {
+func ParseIPAndPort(s string) (net.IP, uint16, error) {
 	rIp, sPort, err := net.SplitHostPort(s)
 	if err != nil {
 		return nil, 0, err

+ 2 - 2
udp_android.go → udp/udp_android.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 // +build !e2e_testing
 
-package nebula
+package udp
 
 import (
 	"fmt"
@@ -34,6 +34,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 }

+ 2 - 2
udp_darwin.go → udp/udp_darwin.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 // +build !e2e_testing
 
-package nebula
+package udp
 
 // Darwin support is primarily implemented in udp_generic, besides NewListenConfig
 
@@ -37,7 +37,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	file, err := u.File()
 	if err != nil {
 		return err

+ 2 - 2
udp_freebsd.go → udp/udp_freebsd.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 // +build !e2e_testing
 
-package nebula
+package udp
 
 // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
 
@@ -36,6 +36,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 }

+ 20 - 25
udp_generic.go → udp/udp_generic.go

@@ -5,7 +5,7 @@
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.
 
-package nebula
+package udp
 
 import (
 	"context"
@@ -13,36 +13,39 @@ import (
 	"net"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 )
 
-type udpConn struct {
+type Conn struct {
 	*net.UDPConn
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
 	if err != nil {
 		return nil, err
 	}
 	if uc, ok := pc.(*net.UDPConn); ok {
-		return &udpConn{UDPConn: uc, l: l}, nil
+		return &Conn{UDPConn: uc, l: l}, nil
 	}
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
 
-func (uc *udpConn) WriteTo(b []byte, addr *udpAddr) error {
+func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
 	_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
 	return err
 }
 
-func (uc *udpConn) LocalAddr() (*udpAddr, error) {
+func (uc *Conn) LocalAddr() (*Addr, error) {
 	a := uc.UDPConn.LocalAddr()
 
 	switch v := a.(type) {
 	case *net.UDPAddr:
-		addr := &udpAddr{IP: make([]byte, len(v.IP))}
+		addr := &Addr{IP: make([]byte, len(v.IP))}
 		copy(addr.IP, v.IP)
 		addr.Port = uint16(v.Port)
 		return addr, nil
@@ -52,11 +55,11 @@ func (uc *udpConn) LocalAddr() (*udpAddr, error) {
 	}
 }
 
-func (u *udpConn) reloadConfig(c *Config) {
+func (u *Conn) ReloadConfig(c *config.C) {
 	// TODO
 }
 
-func NewUDPStatsEmitter(udpConns []*udpConn) func() {
+func NewUDPStatsEmitter(udpConns []*Conn) func() {
 	// No UDP stats for non-linux
 	return func() {}
 }
@@ -65,32 +68,24 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	buffer := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	udpAddr := &udpAddr{IP: make([]byte, 16)}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 
-	lhh := f.lightHouse.NewRequestHandler()
-
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
 		if err != nil {
-			f.l.WithError(err).Error("Failed to read packets")
+			u.l.WithError(err).Error("Failed to read packets")
 			continue
 		}
 
 		udpAddr.IP = rua.IP
 		udpAddr.Port = uint16(rua.Port)
-		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}

+ 29 - 33
udp_linux.go → udp/udp_linux.go

@@ -1,7 +1,7 @@
 //go:build !android && !e2e_testing
 // +build !android,!e2e_testing
 
-package nebula
+package udp
 
 import (
 	"encoding/binary"
@@ -12,14 +12,18 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 	"golang.org/x/sys/unix"
 )
 
 //TODO: make it support reload as best you can!
 
-type udpConn struct {
+type Conn struct {
 	sysFd int
 	l     *logrus.Logger
+	batch int
 }
 
 var x int
@@ -41,7 +45,7 @@ const (
 
 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
 
-func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
+func NewListener(l *logrus.Logger, ip string, port int, multi bool, batch int) (*Conn, error) {
 	syscall.ForkLock.RLock()
 	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
 	if err == nil {
@@ -73,36 +77,36 @@ func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, e
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 
-	return &udpConn{sysFd: fd, l: l}, err
+	return &Conn{sysFd: fd, l: l, batch: batch}, err
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 }
 
-func (u *udpConn) SetRecvBuffer(n int) error {
+func (u *Conn) SetRecvBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
 }
 
-func (u *udpConn) SetSendBuffer(n int) error {
+func (u *Conn) SetSendBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
 }
 
-func (u *udpConn) GetRecvBuffer() (int, error) {
+func (u *Conn) GetRecvBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
 }
 
-func (u *udpConn) GetSendBuffer() (int, error) {
+func (u *Conn) GetSendBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 }
 
-func (u *udpConn) LocalAddr() (*udpAddr, error) {
+func (u *Conn) LocalAddr() (*Addr, error) {
 	sa, err := unix.Getsockname(u.sysFd)
 	if err != nil {
 		return nil, err
 	}
 
-	addr := &udpAddr{}
+	addr := &Addr{}
 	switch sa := sa.(type) {
 	case *unix.SockaddrInet4:
 		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
@@ -115,25 +119,21 @@ func (u *udpConn) LocalAddr() (*udpAddr, error) {
 	return addr, nil
 }
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	udpAddr := &udpAddr{}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{}
 	nb := make([]byte, 12, 12)
 
-	lhh := f.lightHouse.NewRequestHandler()
-
 	//TODO: should we track this?
 	//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
-	msgs, buffers, names := u.PrepareRawMessages(f.udpBatchSize)
+	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
-	if f.udpBatchSize == 1 {
+	if u.batch == 1 {
 		read = u.ReadSingle
 	}
 
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
 		n, err := read(msgs)
 		if err != nil {
@@ -145,12 +145,12 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		for i := 0; i < n; i++ {
 			udpAddr.IP = names[i][8:24]
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
+			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 		}
 	}
 }
 
-func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
+func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
 	for {
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMSG,
@@ -171,7 +171,7 @@ func (u *udpConn) ReadSingle(msgs []rawMessage) (int, error) {
 	}
 }
 
-func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
+func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
 	for {
 		n, _, err := unix.Syscall6(
 			unix.SYS_RECVMMSG,
@@ -191,7 +191,7 @@ func (u *udpConn) ReadMulti(msgs []rawMessage) (int, error) {
 	}
 }
 
-func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
+func (u *Conn) WriteTo(b []byte, addr *Addr) error {
 
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
@@ -221,7 +221,7 @@ func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
 	}
 }
 
-func (u *udpConn) reloadConfig(c *Config) {
+func (u *Conn) ReloadConfig(c *config.C) {
 	b := c.GetInt("listen.read_buffer", 0)
 	if b > 0 {
 		err := u.SetRecvBuffer(b)
@@ -253,7 +253,7 @@ func (u *udpConn) reloadConfig(c *Config) {
 	}
 }
 
-func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
+func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	var vallen uint32 = 4 * _SK_MEMINFO_VARS
 	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
 	if err != 0 {
@@ -262,7 +262,7 @@ func (u *udpConn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	return nil
 }
 
-func NewUDPStatsEmitter(udpConns []*udpConn) func() {
+func NewUDPStatsEmitter(udpConns []*Conn) func() {
 	// Check if our kernel supports SO_MEMINFO before registering the gauges
 	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
 	var meminfo _SK_MEMINFO
@@ -293,7 +293,3 @@ func NewUDPStatsEmitter(udpConns []*udpConn) func() {
 		}
 	}
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}

+ 3 - 3
udp_linux_32.go → udp/udp_linux_32.go

@@ -4,7 +4,7 @@
 // +build !android
 // +build !e2e_testing
 
-package nebula
+package udp
 
 import (
 	"golang.org/x/sys/unix"
@@ -30,13 +30,13 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 
 	for i := range msgs {
-		buffers[i] = make([]byte, mtu)
+		buffers[i] = make([]byte, MTU)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 		//TODO: this is still silly, no need for an array

+ 3 - 3
udp_linux_64.go → udp/udp_linux_64.go

@@ -4,7 +4,7 @@
 // +build !android
 // +build !e2e_testing
 
-package nebula
+package udp
 
 import (
 	"golang.org/x/sys/unix"
@@ -33,13 +33,13 @@ type rawMessage struct {
 	Pad0 [4]byte
 }
 
-func (u *udpConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 
 	for i := range msgs {
-		buffers[i] = make([]byte, mtu)
+		buffers[i] = make([]byte, MTU)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 		//TODO: this is still silly, no need for an array

+ 39 - 43
udp_tester.go → udp/udp_tester.go

@@ -1,16 +1,19 @@
 //go:build e2e_testing
 // +build e2e_testing
 
-package nebula
+package udp
 
 import (
 	"fmt"
 	"net"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 )
 
-type UdpPacket struct {
+type Packet struct {
 	ToIp     net.IP
 	ToPort   uint16
 	FromIp   net.IP
@@ -18,8 +21,8 @@ type UdpPacket struct {
 	Data     []byte
 }
 
-func (u *UdpPacket) Copy() *UdpPacket {
-	n := &UdpPacket{
+func (u *Packet) Copy() *Packet {
+	n := &Packet{
 		ToIp:     make(net.IP, len(u.ToIp)),
 		ToPort:   u.ToPort,
 		FromIp:   make(net.IP, len(u.FromIp)),
@@ -33,20 +36,20 @@ func (u *UdpPacket) Copy() *UdpPacket {
 	return n
 }
 
-type udpConn struct {
-	addr *udpAddr
+type Conn struct {
+	Addr *Addr
 
-	rxPackets chan *UdpPacket // Packets to receive into nebula
-	txPackets chan *UdpPacket // Packets transmitted outside by nebula
+	RxPackets chan *Packet // Packets to receive into nebula
+	TxPackets chan *Packet // Packets transmitted outside by nebula
 
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error) {
-	return &udpConn{
-		addr:      &udpAddr{net.ParseIP(ip), uint16(port)},
-		rxPackets: make(chan *UdpPacket, 1),
-		txPackets: make(chan *UdpPacket, 1),
+func NewListener(l *logrus.Logger, ip string, port int, _ bool, _ int) (*Conn, error) {
+	return &Conn{
+		Addr:      &Addr{net.ParseIP(ip), uint16(port)},
+		RxPackets: make(chan *Packet, 1),
+		TxPackets: make(chan *Packet, 1),
 		l:         l,
 	}, nil
 }
@@ -54,8 +57,8 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
 // Send will place a UdpPacket onto the receive queue for nebula to consume
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
-func (u *udpConn) Send(packet *UdpPacket) {
-	h := &Header{}
+func (u *Conn) Send(packet *Packet) {
+	h := &header.H{}
 	if err := h.Parse(packet.Data); err != nil {
 		panic(err)
 	}
@@ -63,19 +66,19 @@ func (u *udpConn) Send(packet *UdpPacket) {
 		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 		WithField("dataLen", len(packet.Data)).
 		Info("UDP receiving injected packet")
-	u.rxPackets <- packet
+	u.RxPackets <- packet
 }
 
 // Get will pull a UdpPacket from the transmit queue
 // nebula meant to send this message on the network, it will be encrypted
 // packets were ingested from the tun side (in most cases), you can send them with Tun.Send
-func (u *udpConn) Get(block bool) *UdpPacket {
+func (u *Conn) Get(block bool) *Packet {
 	if block {
-		return <-u.txPackets
+		return <-u.TxPackets
 	}
 
 	select {
-	case p := <-u.txPackets:
+	case p := <-u.TxPackets:
 		return p
 	default:
 		return nil
@@ -86,56 +89,49 @@ func (u *udpConn) Get(block bool) *UdpPacket {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (u *udpConn) WriteTo(b []byte, addr *udpAddr) error {
-	p := &UdpPacket{
+func (u *Conn) WriteTo(b []byte, addr *Addr) error {
+	p := &Packet{
 		Data:     make([]byte, len(b), len(b)),
 		FromIp:   make([]byte, 16),
-		FromPort: u.addr.Port,
+		FromPort: u.Addr.Port,
 		ToIp:     make([]byte, 16),
 		ToPort:   addr.Port,
 	}
 
 	copy(p.Data, b)
 	copy(p.ToIp, addr.IP.To16())
-	copy(p.FromIp, u.addr.IP.To16())
+	copy(p.FromIp, u.Addr.IP.To16())
 
-	u.txPackets <- p
+	u.TxPackets <- p
 	return nil
 }
 
-func (u *udpConn) ListenOut(f *Interface, q int) {
-	plaintext := make([]byte, mtu)
-	header := &Header{}
-	fwPacket := &FirewallPacket{}
-	ua := &udpAddr{IP: make([]byte, 16)}
+func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	ua := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 
-	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
-
 	for {
-		p := <-u.rxPackets
+		p := <-u.RxPackets
 		ua.Port = p.FromPort
 		copy(ua.IP, p.FromIp.To16())
-		f.readOutsidePackets(ua, plaintext[:0], p.Data, header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
+		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }
 
-func (u *udpConn) reloadConfig(*Config) {}
+func (u *Conn) ReloadConfig(*config.C) {}
 
-func NewUDPStatsEmitter(_ []*udpConn) func() {
+func NewUDPStatsEmitter(_ []*Conn) func() {
 	// No UDP stats for non-linux
 	return func() {}
 }
 
-func (u *udpConn) LocalAddr() (*udpAddr, error) {
-	return u.addr, nil
+func (u *Conn) LocalAddr() (*Addr, error) {
+	return u.Addr, nil
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 }
-
-func hostDidRoam(addr *udpAddr, newaddr *udpAddr) bool {
-	return !addr.Equals(newaddr)
-}

+ 2 - 2
udp_windows.go → udp/udp_windows.go

@@ -1,7 +1,7 @@
 //go:build !e2e_testing
 // +build !e2e_testing
 
-package nebula
+package udp
 
 // Windows support is primarily implemented in udp_generic, besides NewListenConfig
 
@@ -24,6 +24,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
 	}
 }
 
-func (u *udpConn) Rebind() error {
+func (u *Conn) Rebind() error {
 	return nil
 }

+ 3 - 4
main_test.go → util/main.go

@@ -1,4 +1,4 @@
-package nebula
+package util
 
 import (
 	"io/ioutil"
@@ -17,13 +17,12 @@ func NewTestLogger() *logrus.Logger {
 	}
 
 	switch v {
-	case "1":
-		// This is the default level but we are being explicit
-		l.SetLevel(logrus.InfoLevel)
 	case "2":
 		l.SetLevel(logrus.DebugLevel)
 	case "3":
 		l.SetLevel(logrus.TraceLevel)
+	default:
+		l.SetLevel(logrus.InfoLevel)
 	}
 
 	return l