3
0
Эх сурвалжийг харах

Be more like a library to support mobile (#247)

Nathan Brown 5 жил өмнө
parent
commit
41578ca971

+ 2 - 2
bits_test.go

@@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) {
 func BenchmarkBits(b *testing.B) {
 	z := NewBits(10)
 	for n := 0; n < b.N; n++ {
-		for i, _ := range z.bits {
+		for i := range z.bits {
 			z.bits[i] = true
 		}
-		for i, _ := range z.bits {
+		for i := range z.bits {
 			z.bits[i] = false
 		}
 

+ 23 - 3
cmd/nebula-service/main.go

@@ -3,9 +3,9 @@ package main
 import (
 	"flag"
 	"fmt"
-	"os"
-
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
+	"os"
 )
 
 // A version string that can be set with
@@ -45,5 +45,25 @@ func main() {
 		os.Exit(1)
 	}
 
-	nebula.Main(*configPath, *configTest, Build)
+	config := nebula.NewConfig()
+	err := config.Load(*configPath)
+	if err != nil {
+		fmt.Printf("failed to load config: %s", err)
+		os.Exit(1)
+	}
+
+	l := logrus.New()
+	l.Out = os.Stdout
+	err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
+
+	switch v := err.(type) {
+	case nebula.ContextualError:
+		v.Log(l)
+		os.Exit(1)
+	case error:
+		l.WithError(err).Error("Failed to start")
+		os.Exit(1)
+	}
+
+	os.Exit(0)
 }

+ 11 - 2
cmd/nebula-service/service.go

@@ -1,6 +1,8 @@
 package main
 
 import (
+	"fmt"
+	"github.com/sirupsen/logrus"
 	"log"
 	"os"
 	"path/filepath"
@@ -27,8 +29,15 @@ func (p *program) Start(s service.Service) error {
 }
 
 func (p *program) run() error {
-	nebula.Main(*p.configPath, *p.configTest, Build)
-	return nil
+	config := nebula.NewConfig()
+	err := config.Load(*p.configPath)
+	if err != nil {
+		return fmt.Errorf("failed to load config: %s", err)
+	}
+
+	l := logrus.New()
+	l.Out = os.Stdout
+	return nebula.Main(config, *p.configTest, true, Build, l, nil, nil)
 }
 
 func (p *program) Stop(s service.Service) error {

+ 22 - 1
cmd/nebula/main.go

@@ -3,6 +3,7 @@ package main
 import (
 	"flag"
 	"fmt"
+	"github.com/sirupsen/logrus"
 	"os"
 
 	"github.com/slackhq/nebula"
@@ -39,5 +40,25 @@ func main() {
 		os.Exit(1)
 	}
 
-	nebula.Main(*configPath, *configTest, Build)
+	config := nebula.NewConfig()
+	err := config.Load(*configPath)
+	if err != nil {
+		fmt.Printf("failed to load config: %s", err)
+		os.Exit(1)
+	}
+
+	l := logrus.New()
+	l.Out = os.Stdout
+	err = nebula.Main(config, *configTest, true, Build, l, nil, nil)
+
+	switch v := err.(type) {
+	case nebula.ContextualError:
+		v.Log(l)
+		os.Exit(1)
+	case error:
+		l.WithError(err).Error("Failed to start")
+		os.Exit(1)
+	}
+
+	os.Exit(0)
 }

+ 20 - 0
config.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"errors"
 	"fmt"
 	"github.com/imdario/mergo"
 	"github.com/sirupsen/logrus"
@@ -56,6 +57,13 @@ func (c *Config) Load(path string) error {
 	return nil
 }
 
+func (c *Config) LoadString(raw string) error {
+	if raw == "" {
+		return errors.New("Empty configuration")
+	}
+	return c.parseRaw([]byte(raw))
+}
+
 // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
 // here should decide if they need to make a change to the current process before making the change. HasChanged can be
 // used to help decide if a change is necessary.
@@ -407,6 +415,18 @@ func (c *Config) addFile(path string, direct bool) error {
 	return nil
 }
 
+func (c *Config) parseRaw(b []byte) error {
+	var m map[interface{}]interface{}
+
+	err := yaml.Unmarshal(b, &m)
+	if err != nil {
+		return err
+	}
+
+	c.Settings = m
+	return nil
+}
+
 func (c *Config) parse() error {
 	var m map[interface{}]interface{}
 

+ 31 - 0
logger.go

@@ -0,0 +1,31 @@
+package nebula
+
+import (
+	"github.com/sirupsen/logrus"
+)
+
+type ContextualError struct {
+	RealError error
+	Fields    map[string]interface{}
+	Context   string
+}
+
+func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
+	return ContextualError{Context: msg, Fields: fields, RealError: realError}
+}
+
+func (ce ContextualError) Error() string {
+	return ce.RealError.Error()
+}
+
+func (ce ContextualError) Unwrap() error {
+	return ce.RealError
+}
+
+func (ce *ContextualError) Log(lr *logrus.Logger) {
+	if ce.RealError != nil {
+		lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context)
+	} else {
+		lr.WithFields(ce.Fields).Error(ce.Context)
+	}
+}

+ 66 - 0
logger_test.go

@@ -0,0 +1,66 @@
+package nebula
+
+import (
+	"errors"
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+type TestLogWriter struct {
+	Logs []string
+}
+
+func NewTestLogWriter() *TestLogWriter {
+	return &TestLogWriter{Logs: make([]string, 0)}
+}
+
+func (tl *TestLogWriter) Write(p []byte) (n int, err error) {
+	tl.Logs = append(tl.Logs, string(p))
+	return len(p), nil
+}
+
+func (tl *TestLogWriter) Reset() {
+	tl.Logs = tl.Logs[:0]
+}
+
+func TestContextualError_Log(t *testing.T) {
+	l := logrus.New()
+	l.Formatter = &logrus.TextFormatter{
+		DisableTimestamp: true,
+		DisableColors:    true,
+	}
+
+	tl := NewTestLogWriter()
+	l.Out = tl
+
+	// Test a full context line
+	tl.Reset()
+	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
+	e.Log(l)
+	assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
+
+	// Test a line with an error and msg but no fields
+	tl.Reset()
+	e = NewContextualError("test message", nil, errors.New("error"))
+	e.Log(l)
+	assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs)
+
+	// Test just a context and fields
+	tl.Reset()
+	e = NewContextualError("test message", m{"field": "1"}, nil)
+	e.Log(l)
+	assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs)
+
+	// Test just a context
+	tl.Reset()
+	e = NewContextualError("test message", nil, nil)
+	e.Log(l)
+	assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs)
+
+	// Test just an error
+	tl.Reset()
+	e = NewContextualError("", nil, errors.New("error"))
+	e.Log(l)
+	assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
+}

+ 99 - 55
main.go

@@ -3,6 +3,9 @@ package nebula
 import (
 	"encoding/binary"
 	"fmt"
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/sshd"
+	"gopkg.in/yaml.v2"
 	"net"
 	"os"
 	"os/signal"
@@ -10,42 +13,38 @@ import (
 	"strings"
 	"syscall"
 	"time"
-
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/sshd"
-	"gopkg.in/yaml.v2"
 )
 
+// The caller should provide a real logger, we have one just in case
 var l = logrus.New()
 
 type m map[string]interface{}
 
-func Main(configPath string, configTest bool, buildVersion string) {
-	l.Out = os.Stdout
+type CommandRequest struct {
+	Command  string
+	Callback chan error
+}
+
+func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error {
+	l = logger
 	l.Formatter = &logrus.TextFormatter{
 		FullTimestamp: true,
 	}
 
-	config := NewConfig()
-	err := config.Load(configPath)
-	if err != nil {
-		l.WithError(err).Error("Failed to load config")
-		os.Exit(1)
-	}
-
 	// Print the config if in test, the exit comes later
 	if configTest {
 		b, err := yaml.Marshal(config.Settings)
 		if err != nil {
-			l.Println(err)
-			os.Exit(1)
+			return err
 		}
+
+		// Print the final config
 		l.Println(string(b))
 	}
 
-	err = configLogger(config)
+	err := configLogger(config)
 	if err != nil {
-		l.WithError(err).Error("Failed to configure the logger")
+		return NewContextualError("Failed to configure the logger", nil, err)
 	}
 
 	config.RegisterReloadCallback(func(c *Config) {
@@ -59,20 +58,20 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	trustedCAs, err = loadCAFromConfig(config)
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
-		l.WithError(err).Fatal("Failed to load ca from config")
+		return NewContextualError("Failed to load ca from config", nil, err)
 	}
 	l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
 
 	cs, err := NewCertStateFromConfig(config)
 	if err != nil {
 		//The errors coming out of NewCertStateFromConfig are already nicely formatted
-		l.WithError(err).Fatal("Failed to load certificate from config")
+		return NewContextualError("Failed to load certificate from config", nil, err)
 	}
 	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
 
 	fw, err := NewFirewallFromConfig(cs.certificate, config)
 	if err != nil {
-		l.WithError(err).Fatal("Error while loading firewall rules")
+		return NewContextualError("Error while loading firewall rules", nil, err)
 	}
 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
 
@@ -80,11 +79,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	tunCidr := cs.certificate.Details.Ips[0]
 	routes, err := parseRoutes(config, tunCidr)
 	if err != nil {
-		l.WithError(err).Fatal("Could not parse tun.routes")
+		return NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr)
 	if err != nil {
-		l.WithError(err).Fatal("Could not parse tun.unsafe_routes")
+		return NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
@@ -92,7 +91,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	if config.GetBool("sshd.enabled", false) {
 		err = configSSH(ssh, config)
 		if err != nil {
-			l.WithError(err).Fatal("Error while configuring the sshd")
+			return NewContextualError("Error while configuring the sshd", nil, err)
 		}
 	}
 
@@ -105,17 +104,28 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	if !configTest {
 		config.CatchHUP()
 
-		// set up our tun dev
-		tun, err = newTun(
-			config.GetString("tun.dev", ""),
-			tunCidr,
-			config.GetInt("tun.mtu", DEFAULT_MTU),
-			routes,
-			unsafeRoutes,
-			config.GetInt("tun.tx_queue", 500),
-		)
+		if tunFd != nil {
+			tun, err = newTunFromFd(
+				*tunFd,
+				tunCidr,
+				config.GetInt("tun.mtu", DEFAULT_MTU),
+				routes,
+				unsafeRoutes,
+				config.GetInt("tun.tx_queue", 500),
+			)
+		} else {
+			tun, err = newTun(
+				config.GetString("tun.dev", ""),
+				tunCidr,
+				config.GetInt("tun.mtu", DEFAULT_MTU),
+				routes,
+				unsafeRoutes,
+				config.GetInt("tun.tx_queue", 500),
+			)
+		}
+
 		if err != nil {
-			l.WithError(err).Fatal("Failed to get a tun/tap device")
+			return NewContextualError("Failed to get a tun/tap device", nil, err)
 		}
 	}
 
@@ -126,11 +136,28 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	if !configTest {
 		udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1)
 		if err != nil {
-			l.WithError(err).Fatal("Failed to open udp listener")
+			return NewContextualError("Failed to open udp listener", nil, err)
 		}
 		udpServer.reloadConfig(config)
 	}
 
+	sigChan := make(chan os.Signal)
+	killChan := make(chan CommandRequest)
+	if commandChan != nil {
+		go func() {
+			cmd := CommandRequest{}
+			for {
+				cmd = <-commandChan
+				switch cmd.Command {
+				case "rebind":
+					udpServer.Rebind()
+				case "exit":
+					killChan <- cmd
+				}
+			}
+		}()
+	}
+
 	// Set up my internal host map
 	var preferredRanges []*net.IPNet
 	rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{})
@@ -139,7 +166,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 		for _, rawPreferredRange := range rawPreferredRanges {
 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
 			if err != nil {
-				l.WithError(err).Fatal("Failed to parse preferred ranges")
+				return NewContextualError("Failed to parse preferred ranges", nil, err)
 			}
 			preferredRanges = append(preferredRanges, preferredRange)
 		}
@@ -152,7 +179,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
-			l.WithError(err).Fatal("Failed to parse local range")
+			return NewContextualError("Failed to parse local_range", nil, err)
 		}
 
 		// Check if the entry for local_range was already specified in
@@ -192,7 +219,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	if port == 0 && !configTest {
 		uPort, err := udpServer.LocalAddr()
 		if err != nil {
-			l.WithError(err).Fatal("Failed to get listening port")
+			return NewContextualError("Failed to get listening port", nil, err)
 		}
 		port = int(uPort.Port)
 	}
@@ -209,10 +236,10 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	for i, host := range rawLighthouseHosts {
 		ip := net.ParseIP(host)
 		if ip == nil {
-			l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1)
+			return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
 		}
 		if !tunCidr.Contains(ip) {
-			l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid")
+			return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
 		}
 		lighthouseHosts[i] = ip2int(ip)
 	}
@@ -232,13 +259,13 @@ func Main(configPath string, configTest bool, buildVersion string) {
 
 	remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false)
 	if err != nil {
-		l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list")
+		return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err)
 	}
 	lightHouse.SetRemoteAllowList(remoteAllowList)
 
 	localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true)
 	if err != nil {
-		l.WithError(err).Fatal("Invalid lighthouse.local_allow_list")
+		return NewContextualError("Invalid lighthouse.local_allow_list", nil, err)
 	}
 	lightHouse.SetLocalAllowList(localAllowList)
 
@@ -246,7 +273,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) {
 		vpnIp := net.ParseIP(fmt.Sprintf("%v", k))
 		if !tunCidr.Contains(vpnIp) {
-			l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid")
+			return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil)
 		}
 		vals, ok := v.([]interface{})
 		if ok {
@@ -257,7 +284,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 					ip := addr.IP
 					port, err := strconv.Atoi(parts[1])
 					if err != nil {
-						l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
+						return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 					}
 					lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
 				}
@@ -270,7 +297,7 @@ func Main(configPath string, configTest bool, buildVersion string) {
 				ip := addr.IP
 				port, err := strconv.Atoi(parts[1])
 				if err != nil {
-					l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v)
+					return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
 				lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true)
 			}
@@ -330,14 +357,14 @@ func Main(configPath string, configTest bool, buildVersion string) {
 	case "chachapoly":
 		noiseEndianness = binary.LittleEndian
 	default:
-		l.Fatalf("Unknown cipher: %v", ifConfig.Cipher)
+		return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
 	}
 
 	var ifce *Interface
 	if !configTest {
 		ifce, err = NewInterface(ifConfig)
 		if err != nil {
-			l.WithError(err).Fatal("Failed to initialize interface")
+			return fmt.Errorf("failed to initialize interface: %s", err)
 		}
 
 		ifce.RegisterConfigChangeCallbacks(config)
@@ -348,11 +375,11 @@ func Main(configPath string, configTest bool, buildVersion string) {
 
 	err = startStats(config, configTest)
 	if err != nil {
-		l.WithError(err).Fatal("Failed to start stats emitter")
+		return NewContextualError("Failed to start stats emitter", nil, err)
 	}
 
 	if configTest {
-		os.Exit(0)
+		return nil
 	}
 
 	//TODO: check if we _should_ be emitting stats
@@ -367,19 +394,33 @@ func Main(configPath string, configTest bool, buildVersion string) {
 		go dnsMain(hostMap, config)
 	}
 
-	// Just sit here and be friendly, main thread.
-	shutdownBlock(ifce)
+	if block {
+		// Just sit here and be friendly, main thread.
+		shutdownBlock(ifce, sigChan, killChan)
+	} else {
+		// Even though we aren't blocking we still want to shutdown gracefully
+		go shutdownBlock(ifce, sigChan, killChan)
+	}
+	return nil
 }
 
-func shutdownBlock(ifce *Interface) {
-	var sigChan = make(chan os.Signal)
+func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) {
+	var cmd CommandRequest
+	var sig string
+
 	signal.Notify(sigChan, syscall.SIGTERM)
 	signal.Notify(sigChan, syscall.SIGINT)
 
-	sig := <-sigChan
+	select {
+	case rawSig := <-sigChan:
+		sig = rawSig.String()
+	case cmd = <-killChan:
+		sig = "controlling app"
+	}
+
 	l.WithField("signal", sig).Info("Caught signal, shutting down")
 
-	//TODO: stop tun and udp routines, the lock on hostMap does effectively does that though
+	//TODO: stop tun and udp routines, the lock on hostMap effectively does that though
 	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
 	ifce.hostMap.Lock()
 	for _, h := range ifce.hostMap.Hosts {
@@ -392,5 +433,8 @@ func shutdownBlock(ifce *Interface) {
 	ifce.hostMap.Unlock()
 
 	l.WithField("signal", sig).Info("Goodbye")
-	os.Exit(0)
+	select {
+	case cmd.Callback <- nil:
+	default:
+	}
 }

+ 10 - 4
tun_darwin.go

@@ -1,12 +1,13 @@
+// +build !ios
+
 package nebula
 
 import (
 	"fmt"
+	"github.com/songgao/water"
 	"net"
 	"os/exec"
 	"strconv"
-
-	"github.com/songgao/water"
 )
 
 type Tun struct {
@@ -20,8 +21,9 @@ type Tun struct {
 
 func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	if len(routes) > 0 {
-		return nil, fmt.Errorf("Route MTU not supported in Darwin")
+		return nil, fmt.Errorf("route MTU not supported in Darwin")
 	}
+
 	// NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate()
 	return &Tun{
 		Cidr:         cidr,
@@ -30,13 +32,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
 	}, nil
 }
 
+func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
+}
+
 func (c *Tun) Activate() error {
 	var err error
 	c.Interface, err = water.New(water.Config{
 		DeviceType: water.TUN,
 	})
 	if err != nil {
-		return fmt.Errorf("Activate failed: %v", err)
+		return fmt.Errorf("activate failed: %v", err)
 	}
 
 	c.Device = c.Interface.Name()

+ 4 - 0
tun_freebsd.go

@@ -22,6 +22,10 @@ type Tun struct {
 	io.ReadWriteCloser
 }
 
+func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
+}
+
 func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("Route MTU not supported in FreeBSD")

+ 105 - 0
tun_ios.go

@@ -0,0 +1,105 @@
+// +build ios
+
+package nebula
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"sync"
+	"syscall"
+)
+
+type Tun struct {
+	io.ReadWriteCloser
+	Device string
+	Cidr   *net.IPNet
+}
+
+func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+	return nil, fmt.Errorf("newTun not supported in iOS")
+}
+
+func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+	if len(routes) > 0 {
+		return nil, fmt.Errorf("route MTU not supported in Darwin")
+	}
+
+	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
+	ifce = &Tun{
+		Cidr:            cidr,
+		ReadWriteCloser: &tunReadCloser{f: file},
+	}
+	return
+}
+
+func (c *Tun) Activate() error {
+	c.Device = "iOS"
+	return nil
+}
+
+func (c *Tun) WriteRaw(b []byte) error {
+	_, err := c.Write(b)
+	return err
+}
+
+// The following is hoisted up from water, we do this so we can inject our own fd on iOS
+type tunReadCloser struct {
+	f io.ReadWriteCloser
+
+	rMu  sync.Mutex
+	rBuf []byte
+
+	wMu  sync.Mutex
+	wBuf []byte
+}
+
+func (t *tunReadCloser) Read(to []byte) (int, error) {
+	t.rMu.Lock()
+	defer t.rMu.Unlock()
+
+	if cap(t.rBuf) < len(to)+4 {
+		t.rBuf = make([]byte, len(to)+4)
+	}
+	t.rBuf = t.rBuf[:len(to)+4]
+
+	n, err := t.f.Read(t.rBuf)
+	copy(to, t.rBuf[4:])
+	return n - 4, err
+}
+
+func (t *tunReadCloser) Write(from []byte) (int, error) {
+
+	if len(from) == 0 {
+		return 0, syscall.EIO
+	}
+
+	t.wMu.Lock()
+	defer t.wMu.Unlock()
+
+	if cap(t.wBuf) < len(from)+4 {
+		t.wBuf = make([]byte, len(from)+4)
+	}
+	t.wBuf = t.wBuf[:len(from)+4]
+
+	// Determine the IP Family for the NULL L2 Header
+	ipVer := from[0] >> 4
+	if ipVer == 4 {
+		t.wBuf[3] = syscall.AF_INET
+	} else if ipVer == 6 {
+		t.wBuf[3] = syscall.AF_INET6
+	} else {
+		return 0, errors.New("unable to determine IP version from packet")
+	}
+
+	copy(t.wBuf[4:], from)
+
+	n, err := t.f.Write(t.wBuf)
+	return n - 4, err
+}
+
+func (t *tunReadCloser) Close() error {
+	return t.f.Close()
+}

+ 17 - 0
tun_linux.go

@@ -75,6 +75,23 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
+func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+
+	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
+
+	ifce = &Tun{
+		ReadWriteCloser: file,
+		fd:              int(file.Fd()),
+		Device:          "tun0",
+		Cidr:            cidr,
+		DefaultMTU:      defaultMTU,
+		TXQueueLen:      txQueueLen,
+		Routes:          routes,
+		UnsafeRoutes:    unsafeRoutes,
+	}
+	return
+}
+
 func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {

+ 5 - 1
tun_windows.go

@@ -18,9 +18,13 @@ type Tun struct {
 	*water.Interface
 }
 
+func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
+	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
+}
+
 func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
 	if len(routes) > 0 {
-		return nil, fmt.Errorf("Route MTU not supported in Windows")
+		return nil, fmt.Errorf("route MTU not supported in Windows")
 	}
 
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()

+ 36 - 0
udp_android.go

@@ -0,0 +1,36 @@
+package nebula
+
+import (
+	"fmt"
+	"net"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+func NewListenConfig(multi bool) net.ListenConfig {
+	return net.ListenConfig{
+		Control: func(network, address string, c syscall.RawConn) error {
+			if multi {
+				var controlErr error
+				err := c.Control(func(fd uintptr) {
+					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
+						controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
+						return
+					}
+				})
+				if err != nil {
+					return err
+				}
+				if controlErr != nil {
+					return controlErr
+				}
+			}
+			return nil
+		},
+	}
+}
+
+func (u *udpConn) Rebind() {
+	return
+}

+ 9 - 0
udp_darwin.go

@@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig {
 		},
 	}
 }
+
+func (u *udpConn) Rebind() error {
+	file, err := u.File()
+	if err != nil {
+		return err
+	}
+
+	return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0)
+}

+ 4 - 0
udp_freebsd.go

@@ -32,3 +32,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
 		},
 	}
 }
+
+func (u *udpConn) Rebind() {
+	return
+}

+ 1 - 1
udp_generic.go

@@ -1,4 +1,4 @@
-// +build !linux
+// +build !linux android
 
 // udp_generic implements the nebula UDP interface in pure Go stdlib. This
 // means it can be used on platforms like Darwin and Windows.

+ 6 - 0
udp_linux.go

@@ -1,3 +1,5 @@
+// +build !android
+
 package nebula
 
 import (
@@ -85,6 +87,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
 	return &udpConn{sysFd: fd}, err
 }
 
+func (u *udpConn) Rebind() {
+	return
+}
+
 func (u *udpConn) SetRecvBuffer(n int) error {
 	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
 }

+ 1 - 0
udp_linux_32.go

@@ -1,5 +1,6 @@
 // +build linux
 // +build 386 amd64p32 arm mips mipsle
+// +build !android
 
 package nebula
 

+ 1 - 0
udp_linux_64.go

@@ -1,5 +1,6 @@
 // +build linux
 // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
+// +build !android
 
 package nebula
 

+ 4 - 0
udp_windows.go

@@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig {
 		},
 	}
 }
+
+func (u *udpConn) Rebind() {
+	return
+}