Browse Source

Testing the concept

Nate Brown 3 months ago
parent
commit
2b5d51bcac
3 changed files with 80 additions and 19 deletions
  1. 1 1
      control.go
  2. 67 18
      interface.go
  3. 12 0
      packet/packet.go

+ 1 - 1
control.go

@@ -96,7 +96,7 @@ func (c *Control) Start() (func(), error) {
 	// Start reading packets.
 	c.state = Started
 	c.stateLock.Unlock()
-	return c.f.run()
+	return c.f.run(c.ctx)
 }
 
 func (c *Control) State() RunState {

+ 67 - 18
interface.go

@@ -18,6 +18,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -94,6 +95,12 @@ type Interface struct {
 	cachedPacketMetrics *cachedPacketMetrics
 
 	l *logrus.Logger
+
+	inPool  sync.Pool
+	inbound chan *packet.Packet
+
+	outPool  sync.Pool
+	outbound chan *[]byte
 }
 
 type EncWriter interface {
@@ -192,9 +199,20 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 			dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
 		},
 
-		l: c.l,
+		inbound:  make(chan *packet.Packet, 1024),
+		outbound: make(chan *[]byte, 1024),
+		l:        c.l,
 	}
 
+	ifce.inPool = sync.Pool{New: func() any {
+		return packet.New()
+	}}
+
+	ifce.outPool = sync.Pool{New: func() any {
+		t := make([]byte, mtu)
+		return &t
+	}}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -242,7 +260,7 @@ func (f *Interface) activate() error {
 	return nil
 }
 
-func (f *Interface) run() (func(), error) {
+func (f *Interface) run(c context.Context) (func(), error) {
 	// Launch n queues to read packets from udp
 	for i := 0; i < f.routines; i++ {
 		f.wg.Add(1)
@@ -255,6 +273,12 @@ func (f *Interface) run() (func(), error) {
 		go f.listenIn(f.readers[i], i)
 	}
 
+	// Launch n queues to read packets from tun dev
+	for i := 0; i < f.routines; i++ {
+		f.wg.Add(1)
+		go f.worker(i, c)
+	}
+
 	return f.wg.Wait, nil
 }
 
@@ -267,15 +291,16 @@ func (f *Interface) listenOut(i int) {
 		li = f.outside
 	}
 
-	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	lhh := f.lightHouse.NewRequestHandler()
-	plaintext := make([]byte, udp.MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
-
 	err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+		p := f.inPool.Get().(*packet.Packet)
+		//TODO: have the listener store this in the msgs array after a read instead of doing a copy
+
+		copy(p.Payload, payload)
+		p.Payload = p.Payload[:len(payload)]
+		p.Addr = fromUdpAddr
+		select {
+		case f.inbound <- p:
+		}
 	})
 
 	if err != nil && !f.closed.Load() {
@@ -289,15 +314,10 @@ func (f *Interface) listenOut(i int) {
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()
-	packet := make([]byte, mtu)
-	out := make([]byte, mtu)
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
-
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
 	for {
-		n, err := reader.Read(packet)
+		p := f.outPool.Get().(*[]byte)
+		n, err := reader.Read(*p)
 		if err != nil {
 			if !f.closed.Load() {
 				f.l.WithError(err).Error("Error while reading outbound packet, closing")
@@ -306,13 +326,42 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 			break
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		*p = (*p)[:n]
+		//TODO: nonblocking channel write
+		select {
+		case f.outbound <- p:
+		}
 	}
 
 	f.l.Debugf("overlay reader %v is done", i)
 	f.wg.Done()
 }
 
+func (f *Interface) worker(i int, ctx context.Context) {
+	lhh := f.lightHouse.NewRequestHandler()
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+	result := make([]byte, mtu)
+	h := &header.H{}
+
+	for {
+		select {
+		case data := <-f.outbound:
+			f.consumeInsidePacket(*data, fwPacket, nb, result, i, conntrackCache.Get(f.l))
+			*data = (*data)[:mtu]
+			f.outPool.Put(data)
+		case p := <-f.inbound:
+			f.readOutsidePackets(p.Addr, nil, result[:0], p.Payload, h, fwPacket, lhh, nb, i, conntrackCache.Get(f.l))
+			p.Payload = p.Payload[:mtu]
+			f.inPool.Put(p)
+		case <-ctx.Done():
+			f.wg.Done()
+			return
+		}
+	}
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)

+ 12 - 0
packet/packet.go

@@ -0,0 +1,12 @@
+package packet
+
+import "net/netip"
+
+type Packet struct {
+	Payload []byte
+	Addr    netip.AddrPort
+}
+
+func New() *Packet {
+	return &Packet{Payload: make([]byte, 9001)}
+}