瀏覽代碼

Guard e2e udp and tun channels when closed (#934)

Nate Brown 2 年之前
父節點
當前提交
9c6592b159
共有 3 個文件被更改,包括 32 次插入4 次删除
  1. 2 0
      e2e/handshakes_test.go
  2. 14 1
      overlay/tun_tester.go
  3. 16 3
      udp/udp_tester.go

+ 2 - 0
e2e/handshakes_test.go

@@ -410,6 +410,8 @@ func TestStage1RaceRelays(t *testing.T) {
 	p := r.RouteForAllUntilTxTun(myControl)
 	_ = p
 
+	r.FlushAll()
+
 	myControl.Stop()
 	theirControl.Stop()
 	relayControl.Stop()

+ 14 - 1
overlay/tun_tester.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net"
 	"os"
+	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cidr"
@@ -21,6 +22,7 @@ type TestTun struct {
 	routeTree *cidr.Tree4
 	l         *logrus.Logger
 
+	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
@@ -50,6 +52,10 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
 // These are unencrypted ip layer frames destined for another nebula node.
 // packets should exit the udp side, capture them with udpConn.Get
 func (t *TestTun) Send(packet []byte) {
+	if t.closed.Load() {
+		return
+	}
+
 	if t.l.Level >= logrus.DebugLevel {
 		t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
 	}
@@ -98,6 +104,10 @@ func (t *TestTun) Name() string {
 }
 
 func (t *TestTun) Write(b []byte) (n int, err error) {
+	if t.closed.Load() {
+		return 0, io.ErrClosedPipe
+	}
+
 	packet := make([]byte, len(b), len(b))
 	copy(packet, b)
 	t.TxPackets <- packet
@@ -105,7 +115,10 @@ func (t *TestTun) Write(b []byte) (n int, err error) {
 }
 
 func (t *TestTun) Close() error {
-	close(t.rxPackets)
+	if t.closed.CompareAndSwap(false, true) {
+		close(t.rxPackets)
+		close(t.TxPackets)
+	}
 	return nil
 }
 

+ 16 - 3
udp/udp_tester.go

@@ -5,7 +5,9 @@ package udp
 
 import (
 	"fmt"
+	"io"
 	"net"
+	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -42,7 +44,8 @@ type TesterConn struct {
 	RxPackets chan *Packet // Packets to receive into nebula
 	TxPackets chan *Packet // Packets transmitted outside by nebula
 
-	l *logrus.Logger
+	closed atomic.Bool
+	l      *logrus.Logger
 }
 
 func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
@@ -58,6 +61,10 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, er
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
 func (u *TesterConn) Send(packet *Packet) {
+	if u.closed.Load() {
+		return
+	}
+
 	h := &header.H{}
 	if err := h.Parse(packet.Data); err != nil {
 		panic(err)
@@ -92,6 +99,10 @@ func (u *TesterConn) Get(block bool) *Packet {
 //********************************************************************************************************************//
 
 func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+	if u.closed.Load() {
+		return io.ErrClosedPipe
+	}
+
 	p := &Packet{
 		Data:     make([]byte, len(b), len(b)),
 		FromIp:   make([]byte, 16),
@@ -142,7 +153,9 @@ func (u *TesterConn) Rebind() error {
 }
 
 func (u *TesterConn) Close() error {
-	close(u.RxPackets)
-	close(u.TxPackets)
+	if u.closed.CompareAndSwap(false, true) {
+		close(u.RxPackets)
+		close(u.TxPackets)
+	}
 	return nil
 }