Nate Brown 5 months ago
parent
commit
2ea8a72d5c
4 changed files with 57 additions and 28 deletions
  1. 2 1
      bits.go
  2. 2 0
      connection_state.go
  3. 51 26
      interface.go
  4. 2 1
      outside.go

+ 2 - 1
bits.go

@@ -5,6 +5,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
+// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
 type Bits struct {
 type Bits struct {
 	length             uint64
 	length             uint64
 	current            uint64
 	current            uint64
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
 	}
 	}
 
 
 	// Not within the window
 	// Not within the window
-	l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	l.Error("rejected a packet (top) %d %d\n", b.current, i)
 	return false
 	return false
 }
 }
 
 

+ 2 - 0
connection_state.go

@@ -13,6 +13,8 @@ import (
 	"github.com/slackhq/nebula/noiseutil"
 	"github.com/slackhq/nebula/noiseutil"
 )
 )
 
 
+// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
+// 4092 should be sufficient for 5Gbps
 const ReplayWindow = 1024
 const ReplayWindow = 1024
 
 
 type ConnectionState struct {
 type ConnectionState struct {

+ 51 - 26
interface.go

@@ -202,9 +202,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 			dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
 			dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
 		},
 		},
 
 
-		inbound:  make(chan *packet.Packet, 1024),
-		outbound: make(chan *[]byte, 1024),
-		l:        c.l,
+		//TODO: configurable size
+		inbound:  make(chan *packet.Packet, 1028),
+		outbound: make(chan *[]byte, 1028),
+
+		l: c.l,
 	}
 	}
 
 
 	ifce.inPool = sync.Pool{New: func() any {
 	ifce.inPool = sync.Pool{New: func() any {
@@ -264,22 +266,22 @@ func (f *Interface) activate() error {
 }
 }
 
 
 func (f *Interface) run(c context.Context) (func(), error) {
 func (f *Interface) run(c context.Context) (func(), error) {
-	// Launch n queues to read packets from udp
 	for i := 0; i < f.routines; i++ {
 	for i := 0; i < f.routines; i++ {
+		// Launch n queues to read packets from udp
 		f.wg.Add(1)
 		f.wg.Add(1)
 		go f.listenOut(i)
 		go f.listenOut(i)
-	}
 
 
-	// Launch n queues to read packets from tun dev
-	for i := 0; i < f.routines; i++ {
+		// Launch n queues to read packets from tun dev
 		f.wg.Add(1)
 		f.wg.Add(1)
 		go f.listenIn(f.readers[i], i)
 		go f.listenIn(f.readers[i], i)
-	}
 
 
-	// Launch n queues to read packets from tun dev
-	for i := 0; i < f.routines; i++ {
+		// Launch n queues to read packets from tun dev
+		f.wg.Add(1)
+		go f.workerIn(i, c)
+
+		// Launch n queues to read packets from tun dev
 		f.wg.Add(1)
 		f.wg.Add(1)
-		go f.worker(i, c)
+		go f.workerOut(i, c)
 	}
 	}
 
 
 	return f.wg.Wait, nil
 	return f.wg.Wait, nil
@@ -298,12 +300,16 @@ func (f *Interface) listenOut(i int) {
 		p := f.inPool.Get().(*packet.Packet)
 		p := f.inPool.Get().(*packet.Packet)
 		//TODO: have the listener store this in the msgs array after a read instead of doing a copy
 		//TODO: have the listener store this in the msgs array after a read instead of doing a copy
 
 
+		p.Payload = p.Payload[:mtu]
 		copy(p.Payload, payload)
 		copy(p.Payload, payload)
 		p.Payload = p.Payload[:len(payload)]
 		p.Payload = p.Payload[:len(payload)]
 		p.Addr = fromUdpAddr
 		p.Addr = fromUdpAddr
-		select {
-		case f.inbound <- p:
-		}
+		f.inbound <- p
+		//select {
+		//case f.inbound <- p:
+		//default:
+		//	f.l.Error("Dropped packet from inbound channel")
+		//}
 	})
 	})
 
 
 	if err != nil && !f.closed.Load() {
 	if err != nil && !f.closed.Load() {
@@ -320,6 +326,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 
 
 	for {
 	for {
 		p := f.outPool.Get().(*[]byte)
 		p := f.outPool.Get().(*[]byte)
+		*p = (*p)[:mtu]
 		n, err := reader.Read(*p)
 		n, err := reader.Read(*p)
 		if err != nil {
 		if err != nil {
 			if !f.closed.Load() {
 			if !f.closed.Load() {
@@ -331,31 +338,30 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 
 
 		*p = (*p)[:n]
 		*p = (*p)[:n]
 		//TODO: nonblocking channel write
 		//TODO: nonblocking channel write
-		select {
-		case f.outbound <- p:
-		}
+		f.outbound <- p
+		//select {
+		//case f.outbound <- p:
+		//default:
+		//	f.l.Error("Dropped packet from outbound channel")
+		//}
 	}
 	}
 
 
 	f.l.Debugf("overlay reader %v is done", i)
 	f.l.Debugf("overlay reader %v is done", i)
 	f.wg.Done()
 	f.wg.Done()
 }
 }
 
 
-func (f *Interface) worker(i int, ctx context.Context) {
+func (f *Interface) workerIn(i int, ctx context.Context) {
 	lhh := f.lightHouse.NewRequestHandler()
 	lhh := f.lightHouse.NewRequestHandler()
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
-	result := make([]byte, mtu)
+	fwPacket2 := &firewall.Packet{}
+	nb2 := make([]byte, 12, 12)
+	result2 := make([]byte, mtu)
 	h := &header.H{}
 	h := &header.H{}
 
 
 	for {
 	for {
 		select {
 		select {
-		case data := <-f.outbound:
-			f.consumeInsidePacket(*data, fwPacket, nb, result, i, conntrackCache.Get(f.l))
-			*data = (*data)[:mtu]
-			f.outPool.Put(data)
 		case p := <-f.inbound:
 		case p := <-f.inbound:
-			f.readOutsidePackets(p.Addr, nil, result[:0], p.Payload, h, fwPacket, lhh, nb, i, conntrackCache.Get(f.l))
+			f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
 			p.Payload = p.Payload[:mtu]
 			p.Payload = p.Payload[:mtu]
 			f.inPool.Put(p)
 			f.inPool.Put(p)
 		case <-ctx.Done():
 		case <-ctx.Done():
@@ -365,6 +371,25 @@ func (f *Interface) worker(i int, ctx context.Context) {
 	}
 	}
 }
 }
 
 
+func (f *Interface) workerOut(i int, ctx context.Context) {
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	fwPacket1 := &firewall.Packet{}
+	nb1 := make([]byte, 12, 12)
+	result1 := make([]byte, mtu)
+
+	for {
+		select {
+		case data := <-f.outbound:
+			f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
+			*data = (*data)[:mtu]
+			f.outPool.Put(data)
+		case <-ctx.Done():
+			f.wg.Done()
+			return
+		}
+	}
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
 	c.RegisterReloadCallback(f.reloadSendRecvError)

+ 2 - 1
outside.go

@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
 			return
 			return
 		}
 		}
 
 
+		//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
 		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 			Info("Host roamed to new udp ip/port.")
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoam = time.Now()
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 	if err != nil {
-		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
+		hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
 		return false
 		return false
 	}
 	}