فهرست منبع

add batching of packets

Ryan 1 ماه پیش
والد
کامیت
ad37749c5e
3فایلهای تغییر یافته به همراه167 افزوده شده و 42 حذف شده
  1. 164 28
      interface.go
  2. 1 12
      service/service.go
  3. 2 2
      udp/conn.go

+ 164 - 28
interface.go

@@ -22,7 +22,14 @@ import (
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
 
 
-const mtu = 9001
+const (
+	mtu = 9001
+
+	inboundBatchSize      = 32
+	outboundBatchSize     = 32
+	batchFlushInterval    = 50 * time.Microsecond
+	maxOutstandingBatches = 1028
+)
 
 
 type InterfaceConfig struct {
 type InterfaceConfig struct {
 	HostMap            *HostMap
 	HostMap            *HostMap
@@ -97,10 +104,81 @@ type Interface struct {
 	l *logrus.Logger
 	l *logrus.Logger
 
 
 	inPool  sync.Pool
 	inPool  sync.Pool
-	inbound chan *packet.Packet
+	inbound []chan *packetBatch
 
 
 	outPool  sync.Pool
 	outPool  sync.Pool
-	outbound chan *[]byte
+	outbound []chan *outboundBatch
+
+	packetBatchPool   sync.Pool
+	outboundBatchPool sync.Pool
+}
+
+type packetBatch struct {
+	packets []*packet.Packet
+}
+
+func newPacketBatch() *packetBatch {
+	return &packetBatch{
+		packets: make([]*packet.Packet, 0, inboundBatchSize),
+	}
+}
+
+func (b *packetBatch) add(p *packet.Packet) {
+	b.packets = append(b.packets, p)
+}
+
+func (b *packetBatch) reset() {
+	for i := range b.packets {
+		b.packets[i] = nil
+	}
+	b.packets = b.packets[:0]
+}
+
+func (f *Interface) getPacketBatch() *packetBatch {
+	if v := f.packetBatchPool.Get(); v != nil {
+		b := v.(*packetBatch)
+		b.reset()
+		return b
+	}
+	return newPacketBatch()
+}
+
+func (f *Interface) releasePacketBatch(b *packetBatch) {
+	b.reset()
+	f.packetBatchPool.Put(b)
+}
+
+type outboundBatch struct {
+	payloads []*[]byte
+}
+
+func newOutboundBatch() *outboundBatch {
+	return &outboundBatch{payloads: make([]*[]byte, 0, outboundBatchSize)}
+}
+
+func (b *outboundBatch) add(buf *[]byte) {
+	b.payloads = append(b.payloads, buf)
+}
+
+func (b *outboundBatch) reset() {
+	for i := range b.payloads {
+		b.payloads[i] = nil
+	}
+	b.payloads = b.payloads[:0]
+}
+
+func (f *Interface) getOutboundBatch() *outboundBatch {
+	if v := f.outboundBatchPool.Get(); v != nil {
+		b := v.(*outboundBatch)
+		b.reset()
+		return b
+	}
+	return newOutboundBatch()
+}
+
+func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
+	b.reset()
+	f.outboundBatchPool.Put(b)
 }
 }
 
 
 type EncWriter interface {
 type EncWriter interface {
@@ -203,12 +281,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		},
 		},
 
 
 		//TODO: configurable size
 		//TODO: configurable size
-		inbound:  make(chan *packet.Packet, 1028),
-		outbound: make(chan *[]byte, 1028),
+		inbound:  make([]chan *packetBatch, c.routines),
+		outbound: make([]chan *outboundBatch, c.routines),
 
 
 		l: c.l,
 		l: c.l,
 	}
 	}
 
 
+	for i := 0; i < c.routines; i++ {
+		ifce.inbound[i] = make(chan *packetBatch, maxOutstandingBatches)
+		ifce.outbound[i] = make(chan *outboundBatch, maxOutstandingBatches)
+	}
+
 	ifce.inPool = sync.Pool{New: func() any {
 	ifce.inPool = sync.Pool{New: func() any {
 		return packet.New()
 		return packet.New()
 	}}
 	}}
@@ -218,6 +301,14 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return &t
 		return &t
 	}}
 	}}
 
 
+	ifce.packetBatchPool = sync.Pool{New: func() any {
+		return newPacketBatch()
+	}}
+
+	ifce.outboundBatchPool = sync.Pool{New: func() any {
+		return newOutboundBatch()
+	}}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -296,22 +387,41 @@ func (f *Interface) listenOut(i int) {
 		li = f.outside
 		li = f.outside
 	}
 	}
 
 
+	batch := f.getPacketBatch()
+	lastFlush := time.Now()
+
+	flush := func(force bool) {
+		if len(batch.packets) == 0 {
+			if force {
+				f.releasePacketBatch(batch)
+			}
+			return
+		}
+
+		f.inbound[i] <- batch
+		batch = f.getPacketBatch()
+		lastFlush = time.Now()
+	}
+
 	err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
 	err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
 		p := f.inPool.Get().(*packet.Packet)
 		p := f.inPool.Get().(*packet.Packet)
-		//TODO: have the listener store this in the msgs array after a read instead of doing a copy
-
 		p.Payload = p.Payload[:mtu]
 		p.Payload = p.Payload[:mtu]
 		copy(p.Payload, payload)
 		copy(p.Payload, payload)
 		p.Payload = p.Payload[:len(payload)]
 		p.Payload = p.Payload[:len(payload)]
 		p.Addr = fromUdpAddr
 		p.Addr = fromUdpAddr
-		f.inbound <- p
-		//select {
-		//case f.inbound <- p:
-		//default:
-		//	f.l.Error("Dropped packet from inbound channel")
-		//}
+		batch.add(p)
+
+		if len(batch.packets) >= inboundBatchSize || time.Since(lastFlush) >= batchFlushInterval {
+			flush(false)
+		}
 	})
 	})
 
 
+	if len(batch.packets) > 0 {
+		f.inbound[i] <- batch
+	} else {
+		f.releasePacketBatch(batch)
+	}
+
 	if err != nil && !f.closed.Load() {
 	if err != nil && !f.closed.Load() {
 		f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
 		f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
 		//TODO: Trigger Control to close
 		//TODO: Trigger Control to close
@@ -324,6 +434,22 @@ func (f *Interface) listenOut(i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()
 	runtime.LockOSThread()
 
 
+	batch := f.getOutboundBatch()
+	lastFlush := time.Now()
+
+	flush := func(force bool) {
+		if len(batch.payloads) == 0 {
+			if force {
+				f.releaseOutboundBatch(batch)
+			}
+			return
+		}
+
+		f.outbound[i] <- batch
+		batch = f.getOutboundBatch()
+		lastFlush = time.Now()
+	}
+
 	for {
 	for {
 		p := f.outPool.Get().(*[]byte)
 		p := f.outPool.Get().(*[]byte)
 		*p = (*p)[:mtu]
 		*p = (*p)[:mtu]
@@ -337,13 +463,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 		}
 		}
 
 
 		*p = (*p)[:n]
 		*p = (*p)[:n]
-		//TODO: nonblocking channel write
-		f.outbound <- p
-		//select {
-		//case f.outbound <- p:
-		//default:
-		//	f.l.Error("Dropped packet from outbound channel")
-		//}
+		batch.add(p)
+
+		if len(batch.payloads) >= outboundBatchSize || time.Since(lastFlush) >= batchFlushInterval {
+			flush(false)
+		}
+	}
+
+	if len(batch.payloads) > 0 {
+		f.outbound[i] <- batch
+	} else {
+		f.releaseOutboundBatch(batch)
 	}
 	}
 
 
 	f.l.Debugf("overlay reader %v is done", i)
 	f.l.Debugf("overlay reader %v is done", i)
@@ -360,10 +490,13 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
 
 
 	for {
 	for {
 		select {
 		select {
-		case p := <-f.inbound:
-			f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
-			p.Payload = p.Payload[:mtu]
-			f.inPool.Put(p)
+		case batch := <-f.inbound[i]:
+			for _, p := range batch.packets {
+				f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
+				p.Payload = p.Payload[:mtu]
+				f.inPool.Put(p)
+			}
+			f.releasePacketBatch(batch)
 		case <-ctx.Done():
 		case <-ctx.Done():
 			f.wg.Done()
 			f.wg.Done()
 			return
 			return
@@ -379,10 +512,13 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
 
 
 	for {
 	for {
 		select {
 		select {
-		case data := <-f.outbound:
-			f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
-			*data = (*data)[:mtu]
-			f.outPool.Put(data)
+		case batch := <-f.outbound[i]:
+			for _, data := range batch.payloads {
+				f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
+				*data = (*data)[:mtu]
+				f.outPool.Put(data)
+			}
+			f.releaseOutboundBatch(batch)
 		case <-ctx.Done():
 		case <-ctx.Done():
 			f.wg.Done()
 			f.wg.Done()
 			return
 			return

+ 1 - 12
service/service.go

@@ -9,13 +9,10 @@ import (
 	"math"
 	"math"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
-	"os"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
-	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
-	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
 	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/buffer"
@@ -46,15 +43,7 @@ type Service struct {
 	}
 	}
 }
 }
 
 
-func New(config *config.C) (*Service, error) {
-	logger := logrus.New()
-	logger.Out = os.Stdout
-
-	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
-	if err != nil {
-		return nil, err
-	}
-
+func New(control *nebula.Control) (*Service, error) {
 	wait, err := control.Start()
 	wait, err := control.Start()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 2 - 2
udp/conn.go

@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
 func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 	return netip.AddrPort{}, nil
 	return netip.AddrPort{}, nil
 }
 }
-func (NoopConn) ListenOut(_ EncReader) {
-	return
+func (NoopConn) ListenOut(_ EncReader) error {
+	return nil
 }
 }
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 	return nil