瀏覽代碼

add gvisor based service library (#965)

* add service/ library
Tristan Rice 1 年之前
父節點
當前提交
1083279a45
共有 13 個文件被更改,包括 808 次插入143 次删除
  1. 10 0
      control.go
  2. 18 18
      e2e/handshakes_test.go
  3. 118 0
      e2e/helpers.go
  4. 1 110
      e2e/helpers_test.go
  5. 100 0
      examples/go_service/main.go
  6. 5 1
      go.mod
  7. 12 4
      go.sum
  8. 8 2
      main.go
  9. 24 8
      overlay/tun.go
  10. 63 0
      overlay/user.go
  11. 36 0
      service/listener.go
  12. 248 0
      service/service.go
  13. 165 0
      service/service_test.go

+ 10 - 0
control.go

@@ -11,6 +11,7 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -29,6 +30,7 @@ type controlHostLister interface {
 type Control struct {
 	f               *Interface
 	l               *logrus.Logger
+	ctx             context.Context
 	cancel          context.CancelFunc
 	sshStart        func()
 	statsStart      func()
@@ -71,6 +73,10 @@ func (c *Control) Start() {
 	c.f.run()
 }
 
+func (c *Control) Context() context.Context {
+	return c.ctx
+}
+
 // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
 func (c *Control) Stop() {
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
@@ -226,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	return
 }
 
+func (c *Control) Device() overlay.Device {
+	return c.f.inside
+}
+
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 
 	chi := ControlHostInfo{

+ 18 - 18
e2e/handshakes_test.go

@@ -20,7 +20,7 @@ import (
 )
 
 func BenchmarkHotPath(b *testing.B) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -44,7 +44,7 @@ func BenchmarkHotPath(b *testing.B) {
 }
 
 func TestGoodHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -95,7 +95,7 @@ func TestGoodHandshake(t *testing.T) {
 }
 
 func TestWrongResponderHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
@@ -164,7 +164,7 @@ func TestStage1Race(t *testing.T) {
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// But will eventually collapse down to a single tunnel
 
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -241,7 +241,7 @@ func TestStage1Race(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -290,7 +290,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -341,7 +341,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 }
 
 func TestRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -372,7 +372,7 @@ func TestRelays(t *testing.T) {
 
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -421,7 +421,7 @@ func TestStage1RaceRelays(t *testing.T) {
 
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -508,7 +508,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 	////TODO: assert hostmaps
 }
 func TestRehandshakingRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -538,7 +538,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -612,7 +612,7 @@ func TestRehandshakingRelays(t *testing.T) {
 
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -642,7 +642,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -715,7 +715,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 }
 
 func TestRehandshaking(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -737,7 +737,7 @@ func TestRehandshaking(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew my certificate and spin until their sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -811,7 +811,7 @@ func TestRehandshaking(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -837,7 +837,7 @@ func TestRehandshakingLoser(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew their certificate and spin until mine sees it")
-	_, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
+	_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -912,7 +912,7 @@ func TestRaceRegression(t *testing.T) {
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 

+ 118 - 0
e2e/helpers.go

@@ -0,0 +1,118 @@
+package e2e
+
+import (
+	"crypto/rand"
+	"io"
+	"net"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "test ca",
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           true,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	if len(ips) > 0 {
+		nc.Details.Ips = ips
+	}
+
+	if len(subnets) > 0 {
+		nc.Details.Subnets = subnets
+	}
+
+	if len(groups) > 0 {
+		nc.Details.Groups = groups
+	}
+
+	err = nc.Sign(cert.Curve_CURVE25519, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	issuer, err := ca.Sha256Sum()
+	if err != nil {
+		panic(err)
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	pub, rawPriv := x25519Keypair()
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           name,
+			Ips:            []*net.IPNet{ip},
+			Subnets:        subnets,
+			Groups:         groups,
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           false,
+			Issuer:         issuer,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	err = nc.Sign(ca.Details.Curve, key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
+}
+
+func x25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}

+ 1 - 110
e2e/helpers_test.go

@@ -4,7 +4,6 @@
 package e2e
 
 import (
-	"crypto/rand"
 	"fmt"
 	"io"
 	"net"
@@ -22,8 +21,6 @@ import (
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
 	"gopkg.in/yaml.v2"
 )
 
@@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		IP:   udpIp,
 		Port: 4242,
 	}
-	_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
@@ -108,112 +105,6 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 	return control, vpnIpNet, &udpAddr, c
 }
 
-// newTestCaCert will generate a CA cert
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// newTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
 type doneCb func()
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {

+ 100 - 0
examples/go_service/main.go

@@ -0,0 +1,100 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"log"
+
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/service"
+)
+
+func main() {
+	if err := run(); err != nil {
+		log.Fatalf("%+v", err)
+	}
+}
+
+func run() error {
+	configStr := `
+tun:
+  user: true
+
+static_host_map:
+  '192.168.100.1': ['localhost:4242']
+
+listen:
+  host: 0.0.0.0
+  port: 4241
+
+lighthouse:
+  am_lighthouse: false
+  interval: 60
+  hosts:
+    - '192.168.100.1'
+
+firewall:
+  outbound:
+    # Allow all outbound traffic from this node
+    - port: any
+      proto: any
+      host: any
+
+  inbound:
+    # Allow icmp between any nebula hosts
+    - port: any
+      proto: icmp
+      host: any
+    - port: any
+      proto: any
+      host: any
+
+pki:
+  ca: /home/rice/Developer/nebula-config/ca.crt
+  cert: /home/rice/Developer/nebula-config/app.crt
+  key: /home/rice/Developer/nebula-config/app.key
+`
+	var config config.C
+	if err := config.LoadString(configStr); err != nil {
+		return err
+	}
+	service, err := service.New(&config)
+	if err != nil {
+		return err
+	}
+
+	ln, err := service.Listen("tcp", ":1234")
+	if err != nil {
+		return err
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			log.Printf("accept error: %s", err)
+			break
+		}
+		defer conn.Close()
+
+		log.Printf("got connection")
+
+		conn.Write([]byte("hello world\n"))
+
+		scanner := bufio.NewScanner(conn)
+		for scanner.Scan() {
+			message := scanner.Text()
+			fmt.Fprintf(conn, "echo: %q\n", message)
+			log.Printf("got message %q", message)
+		}
+
+		if err := scanner.Err(); err != nil {
+			log.Printf("scanner error: %s", err)
+			break
+		}
+	}
+
+	service.Close()
+	if err := service.Wait(); err != nil {
+		return err
+	}
+	return nil
+}

+ 5 - 1
go.mod

@@ -19,10 +19,11 @@ require (
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stretchr/testify v1.8.4
-	github.com/vishvananda/netlink v1.1.0
+	github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
 	golang.org/x/crypto v0.14.0
 	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
 	golang.org/x/net v0.17.0
+	golang.org/x/sync v0.3.0
 	golang.org/x/sys v0.14.0
 	golang.org/x/term v0.13.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
@@ -30,6 +31,7 @@ require (
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	google.golang.org/protobuf v1.31.0
 	gopkg.in/yaml.v2 v2.4.0
+	gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
 )
 
 require (
@@ -37,6 +39,7 @@ require (
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
+	github.com/google/btree v1.0.1 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
@@ -44,6 +47,7 @@ require (
 	github.com/prometheus/procfs v0.11.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	golang.org/x/mod v0.12.0 // indirect
+	golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
 	golang.org/x/tools v0.13.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 12 - 4
go.sum

@@ -47,6 +47,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -135,9 +137,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
-github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
-github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
+github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -177,16 +179,18 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -201,6 +205,8 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -244,3 +250,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q=

+ 8 - 2
main.go

@@ -18,7 +18,7 @@ import (
 
 type m map[string]interface{}
 
-func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
+func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	defer func() {
@@ -128,7 +128,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if !configTest {
 		c.CatchHUP(ctx)
 
-		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
+		if deviceFactory == nil {
+			deviceFactory = overlay.NewDeviceFromConfig
+		}
+
+		tun, err = deviceFactory(c, l, tunCidr, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -159,6 +163,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 
 		for i := 0; i < routines; i++ {
+			l.Infof("listening %q %d", listenHost.IP, port)
 			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
@@ -335,6 +340,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	return &Control{
 		ifce,
 		l,
+		ctx,
 		cancel,
 		sshStart,
 		statsStart,

+ 24 - 8
overlay/tun.go

@@ -10,7 +10,9 @@ import (
 
 const DefaultMTU = 1300
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) {
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
 	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
 		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
@@ -27,27 +29,41 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 
-	case fd != nil:
-		return newTunFromFd(
+	default:
+		return newTun(
 			l,
-			*fd,
+			c.GetString("tun.dev", ""),
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
+			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+	}
+}
 
-	default:
-		return newTun(
+func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+	return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+		routes, err := parseRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
+		}
+
+		unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+		}
+		routes = append(routes, unsafeRoutes...)
+		return newTunFromFd(
 			l,
-			c.GetString("tun.dev", ""),
+			*fd,
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
-			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+
 	}
 }

+ 63 - 0
overlay/user.go

@@ -0,0 +1,63 @@
+package overlay
+
+import (
+	"io"
+	"net"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
+)
+
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+	return NewUserDevice(tunCidr)
+}
+
+func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+	// these pipes guarantee each write/read will match 1:1
+	or, ow := io.Pipe()
+	ir, iw := io.Pipe()
+	return &UserDevice{
+		tunCidr:        tunCidr,
+		outboundReader: or,
+		outboundWriter: ow,
+		inboundReader:  ir,
+		inboundWriter:  iw,
+	}, nil
+}
+
+type UserDevice struct {
+	tunCidr *net.IPNet
+
+	outboundReader *io.PipeReader
+	outboundWriter *io.PipeWriter
+
+	inboundReader *io.PipeReader
+	inboundWriter *io.PipeWriter
+}
+
+func (d *UserDevice) Activate() error {
+	return nil
+}
+func (d *UserDevice) Cidr() *net.IPNet                      { return d.tunCidr }
+func (d *UserDevice) Name() string                          { return "faketun0" }
+func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return d, nil
+}
+
+func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
+	return d.inboundReader, d.outboundWriter
+}
+
+func (d *UserDevice) Read(p []byte) (n int, err error) {
+	return d.outboundReader.Read(p)
+}
+func (d *UserDevice) Write(p []byte) (n int, err error) {
+	return d.inboundWriter.Write(p)
+}
+func (d *UserDevice) Close() error {
+	d.inboundWriter.Close()
+	d.outboundWriter.Close()
+	return nil
+}

+ 36 - 0
service/listener.go

@@ -0,0 +1,36 @@
+package service
+
+import (
+	"io"
+	"net"
+)
+
+type tcpListener struct {
+	port   uint16
+	s      *Service
+	addr   *net.TCPAddr
+	accept chan net.Conn
+}
+
+func (l *tcpListener) Accept() (net.Conn, error) {
+	conn, ok := <-l.accept
+	if !ok {
+		return nil, io.EOF
+	}
+	return conn, nil
+}
+
+func (l *tcpListener) Close() error {
+	l.s.mu.Lock()
+	defer l.s.mu.Unlock()
+	delete(l.s.mu.listeners, uint16(l.addr.Port))
+
+	close(l.accept)
+
+	return nil
+}
+
+// Addr returns the listener's network address.
+func (l *tcpListener) Addr() net.Addr {
+	return l.addr
+}

+ 248 - 0
service/service.go

@@ -0,0 +1,248 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"math"
+	"net"
+	"os"
+	"strings"
+	"sync"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay"
+	"golang.org/x/sync/errgroup"
+	"gvisor.dev/gvisor/pkg/bufferv2"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const nicID = 1
+
+type Service struct {
+	eg      *errgroup.Group
+	control *nebula.Control
+	ipstack *stack.Stack
+
+	mu struct {
+		sync.Mutex
+
+		listeners map[uint16]*tcpListener
+	}
+}
+
+func New(config *config.C) (*Service, error) {
+	logger := logrus.New()
+	logger.Out = os.Stdout
+
+	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
+	if err != nil {
+		return nil, err
+	}
+	control.Start()
+
+	ctx := control.Context()
+	eg, ctx := errgroup.WithContext(ctx)
+	s := Service{
+		eg:      eg,
+		control: control,
+	}
+	s.mu.listeners = map[uint16]*tcpListener{}
+
+	device, ok := control.Device().(*overlay.UserDevice)
+	if !ok {
+		return nil, errors.New("must be using user device")
+	}
+
+	s.ipstack = stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
+	})
+	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+	tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+	if tcpipErr != nil {
+		return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+	}
+	linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
+	if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
+		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
+	}
+	ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
+	s.ipstack.SetRouteTable([]tcpip.Route{
+		{
+			Destination: ipv4Subnet,
+			NIC:         nicID,
+		},
+	})
+
+	ipNet := device.Cidr()
+	pa := tcpip.ProtocolAddress{
+		AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(),
+		Protocol:          ipv4.ProtocolNumber,
+	}
+	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
+		PEB:        stack.CanBePrimaryEndpoint, // zero value default
+		ConfigType: stack.AddressConfigStatic,  // zero value default
+	}); err != nil {
+		return nil, fmt.Errorf("error creating IP: %s", err)
+	}
+
+	const tcpReceiveBufferSize = 0
+	const maxInFlightConnectionAttempts = 1024
+	tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
+	s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
+
+	reader, writer := device.Pipe()
+
+	go func() {
+		<-ctx.Done()
+		reader.Close()
+		writer.Close()
+	}()
+
+	// create Goroutines to forward packets between Nebula and Gvisor
+	eg.Go(func() error {
+		buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
+		for {
+			// this will read exactly one packet
+			n, err := reader.Read(buf)
+			if err != nil {
+				return err
+			}
+			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+				Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])),
+			})
+			linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
+
+			if err := ctx.Err(); err != nil {
+				return err
+			}
+		}
+	})
+	eg.Go(func() error {
+		for {
+			packet := linkEP.ReadContext(ctx)
+			if packet.IsNil() {
+				if err := ctx.Err(); err != nil {
+					return err
+				}
+				continue
+			}
+			bufView := packet.ToView()
+			if _, err := bufView.WriteTo(writer); err != nil {
+				return err
+			}
+			bufView.Release()
+		}
+	})
+
+	return &s, nil
+}
+
+// DialContext dials the provided address. Currently only TCP is supported.
+func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+
+	fullAddr := tcpip.FullAddress{
+		NIC:  nicID,
+		Addr: tcpip.Address(addr.IP),
+		Port: uint16(addr.Port),
+	}
+
+	return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
+}
+
+// Listen listens on the provided address. Currently only TCP with wildcard
+// addresses are supported.
+func (s *Service) Listen(network, address string) (net.Listener, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+	if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
+		return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
+	}
+	if addr.Port == 0 {
+		return nil, errors.New("specific port required, got 0")
+	}
+	if addr.Port < 0 || addr.Port >= math.MaxUint16 {
+		return nil, fmt.Errorf("invalid port %d", addr.Port)
+	}
+	port := uint16(addr.Port)
+
+	l := &tcpListener{
+		port:   port,
+		s:      s,
+		addr:   addr,
+		accept: make(chan net.Conn),
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, ok := s.mu.listeners[port]; ok {
+		return nil, fmt.Errorf("already listening on port %d", port)
+	}
+	s.mu.listeners[port] = l
+
+	return l, nil
+}
+
+func (s *Service) Wait() error {
+	return s.eg.Wait()
+}
+
+func (s *Service) Close() error {
+	s.control.Stop()
+	return nil
+}
+
+func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
+	endpointID := r.ID()
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	l, ok := s.mu.listeners[endpointID.LocalPort]
+	if !ok {
+		r.Complete(true)
+		return
+	}
+
+	var wq waiter.Queue
+	ep, err := r.CreateEndpoint(&wq)
+	if err != nil {
+		log.Printf("got error creating endpoint %q", err)
+		r.Complete(true)
+		return
+	}
+	r.Complete(false)
+	ep.SocketOptions().SetKeepAlive(true)
+
+	conn := gonet.NewTCPConn(&wq, ep)
+	l.accept <- conn
+}

+ 165 - 0
service/service_test.go

@@ -0,0 +1,165 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"net"
+	"testing"
+	"time"
+
+	"dario.cat/mergo"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/e2e"
+	"golang.org/x/sync/errgroup"
+	"gopkg.in/yaml.v2"
+)
+
+type m map[string]interface{}
+
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
+
+	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
+	copy(vpnIpNet.IP, udpIp)
+
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	caB, err := caCrt.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	mc := m{
+		"pki": m{
+			"ca":   string(caB),
+			"cert": string(myPEM),
+			"key":  string(myPrivKey),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": m{
+			"outbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+			"inbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+		},
+		"timers": m{
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
+		},
+		"handshakes": m{
+			"try_interval": "200ms",
+		},
+	}
+
+	if overrides != nil {
+		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = overrides
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	var c config.C
+	if err := c.LoadString(string(cb)); err != nil {
+		panic(err)
+	}
+
+	s, err := New(&c)
+	if err != nil {
+		panic(err)
+	}
+	return s
+}
+
+func TestService(t *testing.T) {
+	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+		"static_host_map": m{},
+		"lighthouse": m{
+			"am_lighthouse": true,
+		},
+		"listen": m{
+			"host": "0.0.0.0",
+			"port": 4243,
+		},
+	})
+	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+		"static_host_map": m{
+			"10.0.0.1": []string{"localhost:4243"},
+		},
+		"lighthouse": m{
+			"hosts":    []string{"10.0.0.1"},
+			"interval": 1,
+		},
+	})
+
+	ln, err := a.Listen("tcp", ":1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var eg errgroup.Group
+	eg.Go(func() error {
+		conn, err := ln.Accept()
+		if err != nil {
+			return err
+		}
+		defer conn.Close()
+
+		t.Log("accepted connection")
+
+		if _, err := conn.Write([]byte("server msg")); err != nil {
+			return err
+		}
+
+		t.Log("server: wrote message")
+
+		data := make([]byte, 100)
+		n, err := conn.Read(data)
+		if err != nil {
+			return err
+		}
+		data = data[:n]
+		if !bytes.Equal(data, []byte("client msg")) {
+			return errors.New("got invalid message from client")
+		}
+		t.Log("server: read message")
+		return conn.Close()
+	})
+
+	c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := c.Write([]byte("client msg")); err != nil {
+		t.Fatal(err)
+	}
+
+	data := make([]byte, 100)
+	n, err := c.Read(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+	data = data[:n]
+	if !bytes.Equal(data, []byte("server msg")) {
+		t.Fatal("got invalid message from client")
+	}
+
+	if err := c.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := eg.Wait(); err != nil {
+		t.Fatal(err)
+	}
+}