Sfoglia il codice sorgente

Remove more os.Exit calls and give a more reliable wait for stop function

Nate Brown 3 mesi fa
parent
commit
c9ff55d586
6 ha cambiato i file con 109 aggiunte e 31 eliminazioni
  1. 10 2
      cmd/nebula-service/main.go
  2. 10 2
      cmd/nebula/main.go
  3. 50 6
      control.go
  4. 20 12
      interface.go
  5. 8 8
      main.go
  6. 11 1
      service/service.go

+ 10 - 2
cmd/nebula-service/main.go

@@ -65,8 +65,16 @@ func main() {
 	}
 
 	if !*configTest {
-		ctrl.Start()
-		ctrl.ShutdownBlock()
+		wait, err := ctrl.Start()
+		if err != nil {
+			util.LogWithContextIfNeeded("Error while running", err, l)
+			os.Exit(1)
+		}
+
+		go ctrl.ShutdownBlock()
+		wait()
+
+		l.Info("Goodbye")
 	}
 
 	os.Exit(0)

+ 10 - 2
cmd/nebula/main.go

@@ -59,9 +59,17 @@ func main() {
 	}
 
 	if !*configTest {
-		ctrl.Start()
+		wait, err := ctrl.Start()
+		if err != nil {
+			util.LogWithContextIfNeeded("Error while running", err, l)
+			os.Exit(1)
+		}
+
+		go ctrl.ShutdownBlock()
 		notifyReady(l)
-		ctrl.ShutdownBlock()
+		wait()
+
+		l.Info("Goodbye")
 	}
 
 	os.Exit(0)

+ 50 - 6
control.go

@@ -2,9 +2,11 @@ package nebula
 
 import (
 	"context"
+	"errors"
 	"net/netip"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
 	"github.com/slackhq/nebula/overlay"
 )
 
+type RunState int
+
+const (
+	Stopped  RunState = 0 // The control has yet to be started
+	Started  RunState = 1 // The control has been started
+	Stopping RunState = 2 // The control is stopping
+)
+
+var ErrAlreadyStarted = errors.New("nebula is already started")
+
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
 
@@ -26,6 +38,9 @@ type controlHostLister interface {
 }
 
 type Control struct {
+	stateLock sync.Mutex
+	state     RunState
+
 	f               *Interface
 	l               *logrus.Logger
 	ctx             context.Context
@@ -48,10 +63,21 @@ type ControlHostInfo struct {
 	CurrentRelaysThroughMe []netip.Addr     `json:"currentRelaysThroughMe"`
 }
 
-// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
-func (c *Control) Start() {
+// Start actually runs nebula, this is a nonblocking call.
+// The returned function can be used to wait for nebula to fully stop.
+func (c *Control) Start() (func(), error) {
+	c.stateLock.Lock()
+	if c.state != Stopped {
+		c.stateLock.Unlock()
+		return nil, ErrAlreadyStarted
+	}
+
 	// Activate the interface
-	c.f.activate()
+	err := c.f.activate()
+	if err != nil {
+		c.stateLock.Unlock()
+		return nil, err
+	}
 
 	// Call all the delayed funcs that waited patiently for the interface to be created.
 	if c.sshStart != nil {
@@ -68,15 +94,33 @@ func (c *Control) Start() {
 	}
 
 	// Start reading packets.
-	c.f.run()
+	c.state = Started
+	c.stateLock.Unlock()
+	return c.f.run()
+}
+
+func (c *Control) State() RunState {
+	c.stateLock.Lock()
+	defer c.stateLock.Unlock()
+	return c.state
 }
 
 func (c *Control) Context() context.Context {
 	return c.ctx
 }
 
-// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
+// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
 func (c *Control) Stop() {
+	c.stateLock.Lock()
+	if c.state != Started {
+		c.stateLock.Unlock()
+		// We are stopping or stopped already
+		return
+	}
+
+	c.state = Stopping
+	c.stateLock.Unlock()
+
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
 	// being created while we're shutting them all down.
 	c.cancel()
@@ -85,7 +129,7 @@ func (c *Control) Stop() {
 	if err := c.f.Close(); err != nil {
 		c.l.WithError(err).Error("Close interface failed")
 	}
-	c.l.Info("Goodbye")
+	c.state = Stopped
 }
 
 // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

+ 20 - 12
interface.go

@@ -6,8 +6,8 @@ import (
 	"fmt"
 	"io"
 	"net/netip"
-	"os"
 	"runtime"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -87,6 +87,7 @@ type Interface struct {
 
 	writers []udp.Conn
 	readers []io.ReadWriteCloser
+	wg      sync.WaitGroup
 
 	metricHandshakes    metrics.Histogram
 	messageMetrics      *MessageMetrics
@@ -206,7 +207,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 // activate creates the interface on the host. After the interface is created, any
 // other services that want to bind listeners to its IP may do so successfully. However,
 // the interface isn't going to process anything until run() is called.
-func (f *Interface) activate() {
+func (f *Interface) activate() error {
 	// actually turn on tun dev
 
 	addr, err := f.outside.LocalAddr()
@@ -227,28 +228,34 @@ func (f *Interface) activate() {
 		if i > 0 {
 			reader, err = f.inside.NewMultiQueueReader()
 			if err != nil {
-				f.l.Fatal(err)
+				return err
 			}
 		}
 		f.readers[i] = reader
 	}
 
-	if err := f.inside.Activate(); err != nil {
+	if err = f.inside.Activate(); err != nil {
 		f.inside.Close()
-		f.l.Fatal(err)
+		return err
 	}
+
+	return nil
 }
 
-func (f *Interface) run() {
+func (f *Interface) run() (func(), error) {
 	// Launch n queues to read packets from udp
 	for i := 0; i < f.routines; i++ {
 		go f.listenOut(i)
+		f.wg.Add(1)
 	}
 
 	// Launch n queues to read packets from tun dev
 	for i := 0; i < f.routines; i++ {
 		go f.listenIn(f.readers[i], i)
+		f.wg.Add(1)
 	}
+
+	return f.wg.Wait, nil
 }
 
 func (f *Interface) listenOut(i int) {
@@ -271,6 +278,8 @@ func (f *Interface) listenOut(i int) {
 	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
 		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
 	})
+
+	f.wg.Done()
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -286,17 +295,16 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	for {
 		n, err := reader.Read(packet)
 		if err != nil {
-			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
-				return
+			if !f.closed.Load() {
+				f.l.WithError(err).Error("Error while reading outbound packet")
 			}
-
-			f.l.WithError(err).Error("Error while reading outbound packet")
-			// This only seems to happen when something fatal happens to the fd, so exit.
-			os.Exit(2)
+			break
 		}
 
 		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
 	}
+
+	f.wg.Done()
 }
 
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {

+ 8 - 8
main.go

@@ -288,13 +288,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 
 	return &Control{
-		ifce,
-		l,
-		ctx,
-		cancel,
-		sshStart,
-		statsStart,
-		dnsStart,
-		lightHouse.StartUpdateWorker,
+		f:               ifce,
+		l:               l,
+		ctx:             ctx,
+		cancel:          cancel,
+		sshStart:        sshStart,
+		statsStart:      statsStart,
+		dnsStart:        dnsStart,
+		lighthouseStart: lightHouse.StartUpdateWorker,
 	}, nil
 }

+ 11 - 1
service/service.go

@@ -54,7 +54,11 @@ func New(config *config.C) (*Service, error) {
 	if err != nil {
 		return nil, err
 	}
-	control.Start()
+
+	wait, err := control.Start()
+	if err != nil {
+		return nil, err
+	}
 
 	ctx := control.Context()
 	eg, ctx := errgroup.WithContext(ctx)
@@ -151,6 +155,12 @@ func New(config *config.C) (*Service, error) {
 		}
 	})
 
+	// Add the nebula wait function to the group
+	eg.Go(func() error {
+		wait()
+		return nil
+	})
+
 	return &s, nil
 }