فهرست منبع

Ensure the Nebula device exists before attempting to bind to the Nebula IP (#375)

brad-defined 4 سال پیش
والد
کامیت
17106f83a0
7فایلهای تغییر یافته به همراه97 افزوده شده و 44 حذف شده
  1. 20 2
      control.go
  2. 5 2
      dns_server.go
  3. 6 1
      interface.go
  4. 6 4
      main.go
  5. 22 10
      ssh.go
  6. 12 11
      sshd/server.go
  7. 26 14
      stats.go

+ 20 - 2
control.go

@@ -15,8 +15,11 @@ import (
 // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
 
 type Control struct {
-	f *Interface
-	l *logrus.Logger
+	f          *Interface
+	l          *logrus.Logger
+	sshStart   func()
+	statsStart func()
+	dnsStart   func()
 }
 
 type ControlHostInfo struct {
@@ -32,6 +35,21 @@ type ControlHostInfo struct {
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
 func (c *Control) Start() {
+	// Activate the interface
+	c.f.activate()
+
+	// Call all the delayed funcs that waited patiently for the interface to be created.
+	if c.sshStart != nil {
+		go c.sshStart()
+	}
+	if c.statsStart != nil {
+		go c.statsStart()
+	}
+	if c.dnsStart != nil {
+		go c.dnsStart()
+	}
+
+	// Start reading packets.
 	c.f.run()
 }
 

+ 5 - 2
dns_server.go

@@ -109,7 +109,7 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
+func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) func() {
 	dnsR = newDnsRecords(hostMap)
 
 	// attach request handler func
@@ -120,7 +120,10 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
 	c.RegisterReloadCallback(func(c *Config) {
 		reloadDns(l, c)
 	})
-	startDns(l, c)
+
+	return func() {
+		startDns(l, c)
+	}
 }
 
 func getDnsServerAddr(c *Config) string {

+ 6 - 1
interface.go

@@ -130,7 +130,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 	return ifce, nil
 }
 
-func (f *Interface) run() {
+// activate creates the interface on the host. After the interface is created, any
+// other services that want to bind listeners to its IP may do so successfully. However,
+// the interface isn't going to process anything until run() is called.
+func (f *Interface) activate() {
 	// actually turn on tun dev
 
 	addr, err := f.outside.LocalAddr()
@@ -159,7 +162,9 @@ func (f *Interface) run() {
 	if err := f.inside.Activate(); err != nil {
 		f.l.Fatal(err)
 	}
+}
 
+func (f *Interface) run() {
 	// Launch n queues to read packets from udp
 	for i := 0; i < f.routines; i++ {
 		go f.listenOut(i)

+ 6 - 4
main.go

@@ -75,8 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	wireSSHReload(l, ssh, config)
+	var sshStart func()
 	if config.GetBool("sshd.enabled", false) {
-		err = configSSH(l, ssh, config)
+		sshStart, err = configSSH(l, ssh, config)
 		if err != nil {
 			return nil, NewContextualError("Error while configuring the sshd", nil, err)
 		}
@@ -393,7 +394,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		go lightHouse.LhUpdateWorker(ifce)
 	}
 
-	err = startStats(l, config, buildVersion, configTest)
+	statsStart, err := startStats(l, config, buildVersion, configTest)
 	if err != nil {
 		return nil, NewContextualError("Failed to start stats emitter", nil, err)
 	}
@@ -408,10 +409,11 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
+	var dnsStart func()
 	if amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		go dnsMain(l, hostMap, config)
+		dnsStart = dnsMain(l, hostMap, config)
 	}
 
-	return &Control{ifce, l}, nil
+	return &Control{ifce, l, sshStart, statsStart, dnsStart}, nil
 }

+ 22 - 10
ssh.go

@@ -47,48 +47,55 @@ type sshCreateTunnelFlags struct {
 func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
 	c.RegisterReloadCallback(func(c *Config) {
 		if c.GetBool("sshd.enabled", false) {
-			err := configSSH(l, ssh, c)
+			sshRun, err := configSSH(l, ssh, c)
 			if err != nil {
 				l.WithError(err).Error("Failed to reconfigure the sshd")
 				ssh.Stop()
 			}
+			if sshRun != nil {
+				go sshRun()
+			}
 		} else {
 			ssh.Stop()
 		}
 	})
 }
 
-func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
+// configSSH reads the ssh info out of the passed-in Config and
+// updates the passed-in SSHServer. On success, it returns a function
+// that callers may invoke to run the configured ssh server. On
+// failure, it returns nil, error.
+func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) (func(), error) {
 	//TODO conntrack list
 	//TODO print firewall rules or hash?
 
 	listen := c.GetString("sshd.listen", "")
 	if listen == "" {
-		return fmt.Errorf("sshd.listen must be provided")
+		return nil, fmt.Errorf("sshd.listen must be provided")
 	}
 
 	_, port, err := net.SplitHostPort(listen)
 	if err != nil {
-		return fmt.Errorf("invalid sshd.listen address: %s", err)
+		return nil, fmt.Errorf("invalid sshd.listen address: %s", err)
 	}
 	if port == "22" {
-		return fmt.Errorf("sshd.listen can not use port 22")
+		return nil, fmt.Errorf("sshd.listen can not use port 22")
 	}
 
 	//TODO: no good way to reload this right now
 	hostKeyFile := c.GetString("sshd.host_key", "")
 	if hostKeyFile == "" {
-		return fmt.Errorf("sshd.host_key must be provided")
+		return nil, fmt.Errorf("sshd.host_key must be provided")
 	}
 
 	hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
 	if err != nil {
-		return fmt.Errorf("error while loading sshd.host_key file: %s", err)
+		return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
 	}
 
 	err = ssh.SetHostKey(hostKeyBytes)
 	if err != nil {
-		return fmt.Errorf("error while adding sshd.host_key: %s", err)
+		return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
 	}
 
 	rawKeys := c.Get("sshd.authorized_users")
@@ -139,14 +146,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
 		l.Info("no ssh users to authorize")
 	}
 
+	var runner func()
 	if c.GetBool("sshd.enabled", false) {
 		ssh.Stop()
-		go ssh.Run(listen)
+		runner = func() {
+			if err := ssh.Run(listen); err != nil {
+				l.WithField("err", err).Warn("Failed to run the SSH server")
+			}
+		}
 	} else {
 		ssh.Stop()
 	}
 
-	return nil
+	return runner, nil
 }
 
 func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {

+ 12 - 11
sshd/server.go

@@ -141,21 +141,22 @@ func (s *SSHServer) Run(addr string) error {
 }
 
 func (s *SSHServer) Stop() {
-	for _, c := range s.conns {
-		c.Close()
-	}
-
-	if s.listener == nil {
-		return
+	// Close the listener first, to prevent any new connections being accepted.
+	if s.listener != nil {
+		if err := s.listener.Close(); err != nil {
+			s.l.WithError(err).Warn("Failed to close the sshd listener")
+		} else {
+			s.l.Info("SSH server stopped listening")
+		}
 	}
 
-	err := s.listener.Close()
-	if err != nil {
-		s.l.WithError(err).Warn("Failed to close the sshd listener")
-		return
+	// Force close all existing connections.
+	// TODO I believe this has a slight race if the listener has just accepted
+	// a connection. Can fix by moving this to the goroutine that's accepting.
+	for _, c := range s.conns {
+		c.Close()
 	}
 
-	s.l.Info("SSH server stopped listening")
 	return
 }
 

+ 26 - 14
stats.go

@@ -17,24 +17,35 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) error {
+// startStats initializes stats from config. On success, if any futher work
+// is needed to serve stats, it returns a func to handle that work. If no
+// work is needed, it'll return nil. On failure, it returns nil, error.
+func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest bool) (func(), error) {
 	mType := c.GetString("stats.type", "")
 	if mType == "" || mType == "none" {
-		return nil
+		return nil, nil
 	}
 
 	interval := c.GetDuration("stats.interval", 0)
 	if interval == 0 {
-		return fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
+		return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", ""))
 	}
 
+	var startFn func()
 	switch mType {
 	case "graphite":
-		startGraphiteStats(l, interval, c, configTest)
+		err := startGraphiteStats(l, interval, c, configTest)
+		if err != nil {
+			return nil, err
+		}
 	case "prometheus":
-		startPrometheusStats(l, interval, c, buildVersion, configTest)
+		var err error
+		startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest)
+		if err != nil {
+			return nil, err
+		}
 	default:
-		return fmt.Errorf("stats.type was not understood: %s", mType)
+		return nil, fmt.Errorf("stats.type was not understood: %s", mType)
 	}
 
 	metrics.RegisterDebugGCStats(metrics.DefaultRegistry)
@@ -43,7 +54,7 @@ func startStats(l *logrus.Logger, c *Config, buildVersion string, configTest boo
 	go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval)
 	go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval)
 
-	return nil
+	return startFn, nil
 }
 
 func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
@@ -59,25 +70,25 @@ func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest
 		return fmt.Errorf("error while setting up graphite sink: %s", err)
 	}
 
-	l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
 	if !configTest {
+		l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr)
 		go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr)
 	}
 	return nil
 }
 
-func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) error {
+func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVersion string, configTest bool) (func(), error) {
 	namespace := c.GetString("stats.namespace", "")
 	subsystem := c.GetString("stats.subsystem", "")
 
 	listen := c.GetString("stats.listen", "")
 	if listen == "" {
-		return fmt.Errorf("stats.listen should not be empty")
+		return nil, fmt.Errorf("stats.listen should not be empty")
 	}
 
 	path := c.GetString("stats.path", "")
 	if path == "" {
-		return fmt.Errorf("stats.path should not be empty")
+		return nil, fmt.Errorf("stats.path should not be empty")
 	}
 
 	pr := prometheus.NewRegistry()
@@ -98,13 +109,14 @@ func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, buildVer
 	pr.MustRegister(g)
 	g.Set(1)
 
+	var startFn func()
 	if !configTest {
-		go func() {
+		startFn = func() {
 			l.Infof("Prometheus stats listening on %s at %s", listen, path)
 			http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l}))
 			log.Fatal(http.ListenAndServe(listen, nil))
-		}()
+		}
 	}
 
-	return nil
+	return startFn, nil
 }