瀏覽代碼

subnet support

Ryan Huber 5 年之前
父節點
當前提交
9333a8e3b7
共有 16 個文件被更改,包括 225 次插入66 次删除
  1. 2 1
      connection_manager.go
  2. 2 1
      connection_manager_test.go
  3. 7 0
      examples/config.yml
  4. 7 2
      firewall.go
  5. 40 15
      firewall_test.go
  6. 2 0
      handshake_ix.go
  7. 10 7
      handshake_manager_test.go
  8. 31 7
      hostmap.go
  9. 13 13
      hostmap_test.go
  10. 5 2
      inside.go
  11. 4 0
      main.go
  12. 1 8
      outside.go
  13. 69 0
      tun_common.go
  14. 4 1
      tun_darwin.go
  15. 24 8
      tun_linux.go
  16. 4 1
      tun_windows.go

+ 2 - 1
connection_manager.go

@@ -1,9 +1,10 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"github.com/sirupsen/logrus"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/sirupsen/logrus"
 )
 )
 
 
 // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
 // TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet

+ 2 - 1
connection_manager_test.go

@@ -10,12 +10,13 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-var vpnIP uint32 = uint32(12341234)
+var vpnIP uint32
 
 
 func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
+	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects

+ 7 - 0
examples/config.yml

@@ -100,6 +100,13 @@ tun:
   routes:
   routes:
     #- mtu: 8800
     #- mtu: 8800
     #  route: 10.0.0.0/16
     #  route: 10.0.0.0/16
+  # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
+  # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
+  # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
+  unsafe_routes:
+    - route: 172.16.1.0/24
+      via: 192.168.100.99
+
 
 
 # TODO
 # TODO
 # Configure logging level
 # Configure logging level

+ 7 - 2
firewall.go

@@ -343,12 +343,17 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
 	return nil
 	return nil
 }
 }
 
 
-func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(packet, fp, incoming) {
 	if f.inConns(packet, fp, incoming) {
 		return false
 		return false
 	}
 	}
 
 
+	// Make sure remote address matches nebula certificate
+	if h.remoteCidr.Contains(fp.RemoteIP) == nil {
+		return true
+	}
+
 	// Make sure we are supposed to be handling this local ip address
 	// Make sure we are supposed to be handling this local ip address
 	if f.localIps.Contains(fp.LocalIP) == nil {
 	if f.localIps.Contains(fp.LocalIP) == nil {
 		return true
 		return true
@@ -360,7 +365,7 @@ func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, c *cert
 	}
 	}
 
 
 	// We now know which firewall table to check against
 	// We now know which firewall table to check against
-	if !table.match(fp, incoming, c, caPool) {
+	if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
 		return true
 		return true
 	}
 	}
 
 

+ 40 - 15
firewall_test.go

@@ -3,13 +3,14 @@ package nebula
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
-	"github.com/rcrowley/go-metrics"
-	"github.com/slackhq/nebula/cert"
-	"github.com/stretchr/testify/assert"
 	"math"
 	"math"
 	"net"
 	"net"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"github.com/rcrowley/go-metrics"
+	"github.com/slackhq/nebula/cert"
+	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestNewFirewall(t *testing.T) {
 func TestNewFirewall(t *testing.T) {
@@ -134,7 +135,7 @@ func TestFirewall_AddRule(t *testing.T) {
 func TestFirewall_Drop(t *testing.T) {
 func TestFirewall_Drop(t *testing.T) {
 	p := FirewallPacket{
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
 		ip2int(net.IPv4(1, 2, 3, 4)),
-		101,
+		ip2int(net.IPv4(1, 2, 3, 4)),
 		10,
 		10,
 		90,
 		90,
 		fwProtoUDP,
 		fwProtoUDP,
@@ -154,39 +155,51 @@ func TestFirewall_Drop(t *testing.T) {
 			Issuer: "signer-shasum",
 			Issuer: "signer-shasum",
 		},
 		},
 	}
 	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+	}
+	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
-	assert.True(t, fw.Drop([]byte{}, p, false, &c, cp))
+	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
 	// Allow inbound
 	// Allow inbound
-	assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 	// Allow outbound because conntrack
 	// Allow outbound because conntrack
-	assert.False(t, fw.Drop([]byte{}, p, false, &c, cp))
+	assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	// test remote mismatch
+	oldRemote := p.RemoteIP
+	p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
+	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
+	p.RemoteIP = oldRemote
 
 
 	// test caSha assertions true
 	// test caSha assertions true
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
-	assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 
 
 	// test caSha assertions false
 	// test caSha assertions false
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
-	assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
 
 
 	// test caName true
 	// test caName true
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
-	assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 
 
 	// test caName false
 	// test caName false
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
-	assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
 }
 }
 
 
 func BenchmarkFirewallTable_match(b *testing.B) {
 func BenchmarkFirewallTable_match(b *testing.B) {
@@ -286,7 +299,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 func TestFirewall_Drop2(t *testing.T) {
 func TestFirewall_Drop2(t *testing.T) {
 	p := FirewallPacket{
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
 		ip2int(net.IPv4(1, 2, 3, 4)),
-		101,
+		ip2int(net.IPv4(1, 2, 3, 4)),
 		10,
 		10,
 		90,
 		90,
 		fwProtoUDP,
 		fwProtoUDP,
@@ -305,6 +318,12 @@ func TestFirewall_Drop2(t *testing.T) {
 			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
 			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
 		},
 		},
 	}
 	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+	}
+	h.CreateRemoteCIDR(&c)
 
 
 	c1 := cert.NebulaCertificate{
 	c1 := cert.NebulaCertificate{
 		Details: cert.NebulaCertificateDetails{
 		Details: cert.NebulaCertificateDetails{
@@ -313,15 +332,21 @@ func TestFirewall_Drop2(t *testing.T) {
 			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 		},
 		},
 	}
 	}
+	h1 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c1,
+		},
+	}
+	h1.CreateRemoteCIDR(&c1)
 
 
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
-	// c1 lacks the proper groups
-	assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp))
+	// h1/c1 lacks the proper groups
+	assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
 	// c has the proper groups
 	// c has the proper groups
-	assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
+	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 }
 }
 
 
 func BenchmarkLookup(b *testing.B) {
 func BenchmarkLookup(b *testing.B) {

+ 2 - 0
handshake_ix.go

@@ -205,6 +205,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 
 			//hostinfo.ClearRemotes()
 			//hostinfo.ClearRemotes()
 			hostinfo.AddRemote(*addr)
 			hostinfo.AddRemote(*addr)
+			hostinfo.CreateRemoteCIDR(remoteCert)
 			f.lightHouse.AddRemoteAndReset(ip, addr)
 			f.lightHouse.AddRemoteAndReset(ip, addr)
 			if f.serveDns {
 			if f.serveDns {
 				dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 				dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
@@ -314,6 +315,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 
 
 		//hostinfo.ClearRemotes()
 		//hostinfo.ClearRemotes()
 		f.hostMap.AddRemote(ip, addr)
 		f.hostMap.AddRemote(ip, addr)
+		hostinfo.CreateRemoteCIDR(remoteCert)
 		f.lightHouse.AddRemoteAndReset(ip, addr)
 		f.lightHouse.AddRemoteAndReset(ip, addr)
 		if f.serveDns {
 		if f.serveDns {
 			dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 			dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())

+ 10 - 7
handshake_manager_test.go

@@ -11,12 +11,13 @@ import (
 var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
 var indexes []uint32 = []uint32{1000, 2000, 3000, 4000}
 
 
 //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
 //var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
-var ips []uint32 = []uint32{9000}
+var ips []uint32
 
 
 func Test_NewHandshakeManagerIndex(t *testing.T) {
 func Test_NewHandshakeManagerIndex(t *testing.T) {
-	_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
+	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
+	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
 
 
@@ -54,9 +55,10 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
 }
 }
 
 
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
-	_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
+	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
+	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
@@ -102,9 +104,10 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 }
 }
 
 
 func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
 func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
-	_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
+	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
+	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
 	mainHM := NewHostMap("test", vpncidr, preferredRanges)
@@ -114,7 +117,7 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
 	now := time.Now()
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 
-	hostinfo := blah.AddVpnIP(101010)
+	hostinfo := blah.AddVpnIP(vpnIP)
 	// Pretned we have an index too
 	// Pretned we have an index too
 	blah.AddIndexHostInfo(12341234, hostinfo)
 	blah.AddIndexHostInfo(12341234, hostinfo)
 	assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
 	assert.Contains(t, blah.pendingHostMap.Indexes, uint32(12341234))
@@ -147,12 +150,12 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
 		l.Infoln(cumulative, next_tick)
 		l.Infoln(cumulative, next_tick)
 		blah.NextOutboundHandshakeTimerTick(next_tick)
 		blah.NextOutboundHandshakeTimerTick(next_tick)
 	*/
 	*/
-	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
+	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
 	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
 	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
 }
 }
 
 
 func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
 func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
-	_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
+	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}

+ 31 - 7
hostmap.go

@@ -29,6 +29,7 @@ type HostMap struct {
 	preferredRanges []*net.IPNet
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
 	vpnCIDR         *net.IPNet
 	defaultRoute    uint32
 	defaultRoute    uint32
+	unsafeRoutes    *CIDRTree
 }
 }
 
 
 type HostInfo struct {
 type HostInfo struct {
@@ -46,6 +47,7 @@ type HostInfo struct {
 	localIndexId      uint32
 	localIndexId      uint32
 	hostId            uint32
 	hostId            uint32
 	recvError         int
 	recvError         int
+	remoteCidr        *CIDRTree
 
 
 	lastRoam       time.Time
 	lastRoam       time.Time
 	lastRoamRemote *udpAddr
 	lastRoamRemote *udpAddr
@@ -82,6 +84,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
 		preferredRanges: preferredRanges,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
 		vpnCIDR:         vpnCIDR,
 		defaultRoute:    0,
 		defaultRoute:    0,
+		unsafeRoutes:    NewCIDRTree(),
 	}
 	}
 	return &m
 	return &m
 }
 }
@@ -286,13 +289,6 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
 }
 }
 
 
 func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
 func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
-	if hm.vpnCIDR.Contains(int2ip(vpnIp)) == false && hm.defaultRoute != 0 {
-		// FIXME: this shouldn't ship
-		d := hm.Hosts[hm.defaultRoute]
-		if d != nil {
-			return hm.Hosts[hm.defaultRoute], nil
-		}
-	}
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		if promoteIfce != nil {
 		if promoteIfce != nil {
@@ -314,6 +310,15 @@ func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo,
 	}
 	}
 }
 }
 
 
+func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
+	r := hm.unsafeRoutes.MostSpecificContains(ip)
+	if r != nil {
+		return r.(uint32)
+	} else {
+		return 0
+	}
+}
+
 func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
 func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
 	hm.RLock()
 	hm.RLock()
 	if i, ok := hm.Hosts[vpnIP]; ok {
 	if i, ok := hm.Hosts[vpnIP]; ok {
@@ -387,6 +392,13 @@ func (hm *HostMap) Punchy(conn *udpConn) {
 	}
 	}
 }
 }
 
 
+func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
+	for _, r := range *routes {
+		l.WithField("route", r.route).WithField("via", r.via).Error("Adding UNSAFE Route")
+		hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
+	}
+}
+
 func (i *HostInfo) MarshalJSON() ([]byte, error) {
 func (i *HostInfo) MarshalJSON() ([]byte, error) {
 	return json.Marshal(m{
 	return json.Marshal(m{
 		"remote":             i.remote,
 		"remote":             i.remote,
@@ -610,6 +622,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 	return true
 }
 }
 
 
+func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
+	remoteCidr := NewCIDRTree()
+	for _, ip := range c.Details.Ips {
+		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+	}
+
+	for _, n := range c.Details.Subnets {
+		remoteCidr.AddCIDR(n, struct{}{})
+	}
+	i.remoteCidr = remoteCidr
+}
+
 //########################
 //########################
 
 
 func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
 func NewHostInfoDest(addr *udpAddr) *HostInfoDest {

+ 13 - 13
hostmap_test.go

@@ -74,26 +74,26 @@ func TestHostmap(t *testing.T) {
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
 	y := NewUDPAddrFromString("10.128.0.3:11111")
 	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
 
 
-	info, _ := m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
+	info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
 
 
 	// There should be three remotes in the host map
 	// There should be three remotes in the host map
 	assert.Equal(t, 3, len(info.Remotes))
 	assert.Equal(t, 3, len(info.Remotes))
 
 
 	// Adding an identical remote should not change the count
 	// Adding an identical remote should not change the count
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
 	assert.Equal(t, 3, len(info.Remotes))
 	assert.Equal(t, 3, len(info.Remotes))
 
 
 	// Adding a fresh remote should add one
 	// Adding a fresh remote should add one
 	y = NewUDPAddrFromString("10.18.0.3:11111")
 	y = NewUDPAddrFromString("10.18.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
 	assert.Equal(t, 4, len(info.Remotes))
 	assert.Equal(t, 4, len(info.Remotes))
 
 
 	// Query and reference remote should get the first one (and not nil)
 	// Query and reference remote should get the first one (and not nil)
-	info, _ = m.QueryVpnIP(ip2int(net.ParseIP("127.0.0.1")))
+	info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
 	assert.NotNil(t, info.remote)
 	assert.NotNil(t, info.remote)
 
 
 	// Promotion should ensure that the best remote is chosen (y)
 	// Promotion should ensure that the best remote is chosen (y)
@@ -111,9 +111,9 @@ func TestHostmapdebug(t *testing.T) {
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	a := NewUDPAddrFromString("10.127.0.3:11111")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
 	b := NewUDPAddrFromString("1.0.0.1:22222")
 	y := NewUDPAddrFromString("10.128.0.3:11111")
 	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
+	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
 
 
 	//t.Errorf("%s", m.DebugRemotes(1))
 	//t.Errorf("%s", m.DebugRemotes(1))
 }
 }
@@ -157,9 +157,9 @@ func BenchmarkHostmappromote2(b *testing.B) {
 		y := NewUDPAddrFromString("10.128.0.3:11111")
 		y := NewUDPAddrFromString("10.128.0.3:11111")
 		a := NewUDPAddrFromString("10.127.0.3:11111")
 		a := NewUDPAddrFromString("10.127.0.3:11111")
 		g := NewUDPAddrFromString("1.0.0.1:22222")
 		g := NewUDPAddrFromString("1.0.0.1:22222")
-		m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), a)
-		m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), g)
-		m.AddRemote(ip2int(net.ParseIP("127.0.0.1")), y)
+		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
+		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
+		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
 	}
 	}
 	b.Errorf("hi")
 	b.Errorf("hi")
 
 

+ 5 - 2
inside.go

@@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		ci.queueLock.Unlock()
 		ci.queueLock.Unlock()
 	}
 	}
 
 
-	if !f.firewall.Drop(packet, *fwPacket, false, ci.peerCert, trustedCAs) {
+	if !f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs) {
 		f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
 		f.send(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out)
 		if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
 		if f.lightHouse != nil && *ci.messageCounter%5000 == 0 {
 			f.lightHouse.Query(fwPacket.RemoteIP, f)
 			f.lightHouse.Query(fwPacket.RemoteIP, f)
@@ -52,6 +52,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 }
 }
 
 
 func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
+	if f.hostMap.vpnCIDR.Contains(int2ip(vpnIp)) == false {
+		vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
+	}
 	hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
 	hostinfo, err := f.hostMap.PromoteBestQueryVpnIP(vpnIp, f)
 
 
 	//if err != nil || hostinfo.ConnectionState == nil {
 	//if err != nil || hostinfo.ConnectionState == nil {
@@ -97,7 +100,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 	}
 	}
 
 
 	// check if packet is in outbound fw rules
 	// check if packet is in outbound fw rules
-	if f.firewall.Drop(p, *fp, false, hostInfo.ConnectionState.peerCert, trustedCAs) {
+	if f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs) {
 		l.WithField("fwPacket", fp).Debugln("dropping cached packet")
 		l.WithField("fwPacket", fp).Debugln("dropping cached packet")
 		return
 		return
 	}
 	}

+ 4 - 0
main.go

@@ -79,6 +79,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	// TODO: make sure mask is 4 bytes
 	// TODO: make sure mask is 4 bytes
 	tunCidr := cs.certificate.Details.Ips[0]
 	tunCidr := cs.certificate.Details.Ips[0]
 	routes, err := parseRoutes(config, tunCidr)
 	routes, err := parseRoutes(config, tunCidr)
+	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
 	if err != nil {
 	if err != nil {
 		l.WithError(err).Fatal("Could not parse tun.routes")
 		l.WithError(err).Fatal("Could not parse tun.routes")
 	}
 	}
@@ -109,6 +110,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 		tunCidr,
 		tunCidr,
 		config.GetInt("tun.mtu", 1300),
 		config.GetInt("tun.mtu", 1300),
 		routes,
 		routes,
+		unsafeRoutes,
 		config.GetInt("tun.tx_queue", 500),
 		config.GetInt("tun.tx_queue", 500),
 	)
 	)
 	if err != nil {
 	if err != nil {
@@ -163,6 +165,8 @@ func Main(configPath string, configTest bool, buildVersion string) {
 
 
 	hostMap := NewHostMap("main", tunCidr, preferredRanges)
 	hostMap := NewHostMap("main", tunCidr, preferredRanges)
 	hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
 	hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
+	hostMap.addUnsafeRoutes(&unsafeRoutes)
+
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 
 
 	/*
 	/*

+ 1 - 8
outside.go

@@ -255,13 +255,6 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
 	var err error
 	var err error
 
 
-	// TODO: This breaks subnet routing and needs to also check range of ip subnet
-	/*
-		if len(res) > 16 && binary.BigEndian.Uint32(res[12:16]) != ip2int(ci.peerCert.Details.Ips[0].IP) {
-			l.Debugf("Host %s tried to spoof packet as %s.", ci.peerCert.Details.Ips[0].IP, IntIp(binary.BigEndian.Uint32(res[12:16])))
-		}
-	*/
-
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
 	if err != nil {
 	if err != nil {
 		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
 		l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).Error("Failed to decrypt packet")
@@ -283,7 +276,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return
 		return
 	}
 	}
 
 
-	if f.firewall.Drop(out, *fwPacket, true, hostinfo.ConnectionState.peerCert, trustedCAs) {
+	if f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs) {
 		l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
 		l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("fwPacket", fwPacket).
 			Debugln("dropping inbound packet")
 			Debugln("dropping inbound packet")
 		return
 		return

+ 69 - 0
tun_common.go

@@ -9,6 +9,7 @@ import (
 type route struct {
 type route struct {
 	mtu   int
 	mtu   int
 	route *net.IPNet
 	route *net.IPNet
+	via   *net.IP
 }
 }
 
 
 func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
 func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
@@ -81,6 +82,74 @@ func parseRoutes(config *Config, network *net.IPNet) ([]route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
+func parseUnsafeRoutes(config *Config, network *net.IPNet) ([]route, error) {
+	var err error
+
+	r := config.Get("tun.unsafe_routes")
+	if r == nil {
+		return []route{}, nil
+	}
+
+	rawRoutes, ok := r.([]interface{})
+	if !ok {
+		return nil, fmt.Errorf("tun.unsafe_routes is not an array")
+	}
+
+	if len(rawRoutes) < 1 {
+		return []route{}, nil
+	}
+
+	routes := make([]route, len(rawRoutes))
+	for i, r := range rawRoutes {
+		m, ok := r.(map[interface{}]interface{})
+		if !ok {
+			return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
+		}
+
+		rVia, ok := m["via"]
+		if !ok {
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
+		}
+
+		via, ok := rVia.(string)
+		if !ok {
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: %v", i+1, err)
+		}
+
+		nVia := net.ParseIP(via)
+		if nVia == nil {
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
+		}
+
+		rRoute, ok := m["route"]
+		if !ok {
+			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
+		}
+
+		r := route{
+			via: &nVia,
+		}
+
+		_, r.route, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+		if err != nil {
+			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
+		}
+
+		if ipWithin(network, r.route) {
+			return nil, fmt.Errorf(
+				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
+				i+1,
+				r.route.String(),
+				network.String(),
+			)
+		}
+
+		routes[i] = r
+	}
+
+	return routes, nil
+}
+
 func ipWithin(o *net.IPNet, i *net.IPNet) bool {
 func ipWithin(o *net.IPNet, i *net.IPNet) bool {
 	// Make sure o contains the lowest form of i
 	// Make sure o contains the lowest form of i
 	if !o.Contains(i.IP.Mask(i.Mask)) {
 	if !o.Contains(i.IP.Mask(i.Mask)) {

+ 4 - 1
tun_darwin.go

@@ -17,10 +17,13 @@ type Tun struct {
 	*water.Interface
 	*water.Interface
 }
 }
 
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("Route MTU not supported in Darwin")
 		return nil, fmt.Errorf("Route MTU not supported in Darwin")
 	}
 	}
+	if len(unsafeRoutes) > 0 {
+		return nil, fmt.Errorf("unsafeRoutes not supported in Darwin")
+	}
 	// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
 	// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
 	return &Tun{
 	return &Tun{
 		Cidr: cidr,
 		Cidr: cidr,

+ 24 - 8
tun_linux.go

@@ -14,13 +14,14 @@ import (
 
 
 type Tun struct {
 type Tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	fd         int
-	Device     string
-	Cidr       *net.IPNet
-	MaxMTU     int
-	DefaultMTU int
-	TXQueueLen int
-	Routes     []route
+	fd           int
+	Device       string
+	Cidr         *net.IPNet
+	MaxMTU       int
+	DefaultMTU   int
+	TXQueueLen   int
+	Routes       []route
+	UnsafeRoutes []route
 }
 }
 
 
 type ifReq struct {
 type ifReq struct {
@@ -74,7 +75,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -106,6 +107,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		Routes:          routes,
+		UnsafeRoutes:    unsafeRoutes,
 	}
 	}
 	return
 	return
 }
 }
@@ -238,6 +240,20 @@ func (c Tun) Activate() error {
 		}
 		}
 	}
 	}
 
 
+	// Unsafe path routes
+	for _, r := range c.UnsafeRoutes {
+		nr := netlink.Route{
+			LinkIndex: link.Attrs().Index,
+			Dst:       r.route,
+			Scope:     unix.RT_SCOPE_LINK,
+		}
+
+		err = netlink.RouteAdd(&nr)
+		if err != nil {
+			return fmt.Errorf("failed to set mtu %v on route %v; %v", r.mtu, r.route, err)
+		}
+	}
+
 	// Run the interface
 	// Run the interface
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
 	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {

+ 4 - 1
tun_windows.go

@@ -16,10 +16,13 @@ type Tun struct {
 	*water.Interface
 	*water.Interface
 }
 }
 
 
-func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, txQueueLen int) (ifce *Tun, err error) {
+func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("Route MTU not supported in Windows")
 		return nil, fmt.Errorf("Route MTU not supported in Windows")
 	}
 	}
+	if len(unsafeRoutes) > 0 {
+		return nil, fmt.Errorf("unsafeRoutes not supported in Windows")
+	}
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	return &Tun{
 	return &Tun{
 		Cidr: cidr,
 		Cidr: cidr,