瀏覽代碼

fix BatchRead interface & make batch size configurable

Jay Wren 3 天之前
父節點
當前提交
8281b1699f
共有 8 個文件被更改,包括 30 次插入61 次删除
  1. 21 11
      interface.go
  2. 1 0
      main.go
  3. 0 1
      overlay/tun_darwin.go
  4. 0 1
      overlay/tun_freebsd.go
  5. 0 1
      overlay/tun_linux.go
  6. 0 1
      overlay/tun_openbsd.go
  7. 8 45
      overlay/tun_wg.go
  8. 0 1
      overlay/tun_windows.go

+ 21 - 11
interface.go

@@ -47,6 +47,7 @@ type InterfaceConfig struct {
 	reQueryWait     time.Duration
 
 	ConntrackCacheTimeout time.Duration
+	batchSize             int
 	l                     *logrus.Logger
 }
 
@@ -84,6 +85,7 @@ type Interface struct {
 	version     string
 
 	conntrackCacheTimeout time.Duration
+	batchSize             int
 
 	writers []udp.Conn
 	readers []io.ReadWriteCloser
@@ -112,7 +114,7 @@ type EncWriter interface {
 
 // BatchReader is an interface for readers that support vectorized packet reading
 type BatchReader interface {
-	BatchRead() ([][]byte, []int, error)
+	BatchRead(buffers [][]byte, sizes []int) (int, error)
 }
 
 // BatchWriter is an interface for writers that support vectorized packet writing
@@ -196,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		relayManager:          c.relayManager,
 		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+		batchSize:             c.batchSize,
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
@@ -323,21 +326,28 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 
 // listenInBatch handles vectorized packet reading for improved performance
 func (f *Interface) listenInBatch(reader BatchReader, i int) error {
-	// Allocate per-packet state
-	fwPackets := make([]*firewall.Packet, 64) // Match batch size
-	outBuffers := make([][]byte, 64)
-	nbBuffers := make([][]byte, 64)
-
-	for j := 0; j < 64; j++ {
+	// Allocate per-packet state and buffers for batch reading
+	batchSize := f.batchSize
+	if batchSize <= 0 {
+		batchSize = 64 // Fallback to default if not configured
+	}
+	fwPackets := make([]*firewall.Packet, batchSize)
+	outBuffers := make([][]byte, batchSize)
+	nbBuffers := make([][]byte, batchSize)
+	packets := make([][]byte, batchSize)
+	sizes := make([]int, batchSize)
+
+	for j := 0; j < batchSize; j++ {
 		fwPackets[j] = &firewall.Packet{}
 		outBuffers[j] = make([]byte, mtu)
-		nbBuffers[j] = make([]byte, 12, 12)
+		nbBuffers[j] = make([]byte, 12)
+		packets[j] = make([]byte, mtu)
 	}
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
 	for {
-		packets, sizes, err := reader.BatchRead()
+		n, err := reader.BatchRead(packets, sizes)
 		if err != nil {
 			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
 				return nil
@@ -348,8 +358,8 @@ func (f *Interface) listenInBatch(reader BatchReader, i int) error {
 
 		// Process each packet in the batch
 		cache := conntrackCache.Get(f.l)
-		for idx := 0; idx < len(packets); idx++ {
-			if idx < len(sizes) && sizes[idx] > 0 {
+		for idx := 0; idx < n; idx++ {
+			if sizes[idx] > 0 {
 				// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
 				stateIdx := idx % len(fwPackets)
 				f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)

+ 1 - 0
main.go

@@ -242,6 +242,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		batchSize:             c.GetInt("tun.batch_size", 64),
 		l:                     l,
 	}
 

+ 0 - 1
overlay/tun_darwin.go

@@ -258,7 +258,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

+ 0 - 1
overlay/tun_freebsd.go

@@ -214,7 +214,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

+ 0 - 1
overlay/tun_linux.go

@@ -323,7 +323,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64, // Default batch size
 		offset:    0,
 		l:         t.l,
 	}, nil

+ 0 - 1
overlay/tun_openbsd.go

@@ -196,7 +196,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil

+ 8 - 45
overlay/tun_wg.go

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"io"
 	"net/netip"
-	"sync"
 	"sync/atomic"
 
 	"github.com/gaissmai/bart"
@@ -37,7 +36,7 @@ type wgTun struct {
 
 // BatchReader interface for readers that support vectorized I/O
 type BatchReader interface {
-	BatchRead() ([][]byte, []int, error)
+	BatchRead(buffers [][]byte, sizes []int) (int, error)
 }
 
 // BatchWriter interface for writers that support vectorized I/O
@@ -47,24 +46,12 @@ type BatchWriter interface {
 
 // wgTunReader wraps a single TUN queue for multi-queue support
 type wgTunReader struct {
-	parent     *wgTun
-	tunDevice  wgtun.Device
-	buffers    [][]byte
-	sizes      []int
-	offset     int
-	batchSize  int
-	l          *logrus.Logger
+	parent    *wgTun
+	tunDevice wgtun.Device
+	offset    int
+	l         *logrus.Logger
 }
 
-var (
-	bufferPool = sync.Pool{
-		New: func() interface{} {
-			buf := make([]byte, 9001) // MTU size
-			return &buf
-		},
-	}
-)
-
 func (t *wgTun) Networks() []netip.Prefix {
 	return t.vpnNetworks
 }
@@ -210,23 +197,9 @@ func (t *wgTun) reload(c *config.C, initial bool) error {
 }
 
 // BatchRead reads multiple packets from the TUN device using vectorized I/O
-func (r *wgTunReader) BatchRead() ([][]byte, []int, error) {
-	// Reuse buffers from pool
-	if len(r.buffers) == 0 {
-		r.buffers = make([][]byte, r.batchSize)
-		r.sizes = make([]int, r.batchSize)
-		for i := 0; i < r.batchSize; i++ {
-			buf := bufferPool.Get().(*[]byte)
-			r.buffers[i] = (*buf)[:cap(*buf)]
-		}
-	}
-
-	n, err := r.tunDevice.Read(r.buffers, r.sizes, r.offset)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	return r.buffers[:n], r.sizes[:n], nil
+// The caller provides buffers and sizes slices, and this function returns the number of packets read.
+func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
+	return r.tunDevice.Read(buffers, sizes, r.offset)
 }
 
 // Read implements io.Reader for wgTunReader (single packet for compatibility)
@@ -262,16 +235,6 @@ func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
 }
 
 func (r *wgTunReader) Close() error {
-	// Return buffers to pool
-	for i := range r.buffers {
-		if r.buffers[i] != nil {
-			bufferPool.Put(&r.buffers[i])
-			r.buffers[i] = nil
-		}
-	}
-	r.buffers = nil
-	r.sizes = nil
-
 	if r.tunDevice != nil {
 		return r.tunDevice.Close()
 	}

+ 0 - 1
overlay/tun_windows.go

@@ -175,7 +175,6 @@ func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
 	return &wgTunReader{
 		parent:    t,
 		tunDevice: t.tunDevice,
-		batchSize: 64,
 		offset:    0,
 		l:         t.l,
 	}, nil