Ryan 1 month ago
parent
commit
d18d1aea67
3 changed files with 403 additions and 3 deletions
  1. 367 3
      udp/udp_linux.go
  2. 18 0
      udp/udp_linux_32.go
  3. 18 0
      udp/udp_linux_64.go

+ 367 - 3
udp/udp_linux.go

@@ -5,9 +5,11 @@ package udp
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"net/netip"
+	"sync"
 	"syscall"
 	"time"
 	"unsafe"
@@ -20,11 +22,35 @@ import (
 
 var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
 
+const (
+	defaultGSOMaxSegments  = 8
+	defaultGSOFlushTimeout = 150 * time.Microsecond
+	maxGSOBatchBytes       = 0xFFFF
+)
+
+var (
+	errGSOFallback = errors.New("udp gso fallback")
+	errGSODisabled = errors.New("udp gso disabled")
+)
+
 type StdConn struct {
 	sysFd int
 	isV4  bool
 	l     *logrus.Logger
 	batch int
+
+	enableGRO bool
+	enableGSO bool
+
+	gsoMu           sync.Mutex
+	gsoBuf          []byte
+	gsoAddr         netip.AddrPort
+	gsoSegSize      int
+	gsoSegments     int
+	gsoMaxSegments  int
+	gsoMaxBytes     int
+	gsoFlushTimeout time.Duration
+	gsoTimer        *time.Timer
 }
 
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,7 +95,15 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 		return nil, fmt.Errorf("unable to bind to socket: %s", err)
 	}
 
-	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
+	return &StdConn{
+		sysFd:           fd,
+		isV4:            ip.Is4(),
+		l:               l,
+		batch:           batch,
+		gsoMaxSegments:  defaultGSOMaxSegments,
+		gsoMaxBytes:     MTU * defaultGSOMaxSegments,
+		gsoFlushTimeout: defaultGSOFlushTimeout,
+	}, err
 }
 
 func (u *StdConn) Rebind() error {
@@ -119,7 +153,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 }
 
 func (u *StdConn) ListenOut(r EncReader) error {
-	var ip netip.Addr
+	var (
+		ip       netip.Addr
+		controls [][]byte
+	)
 
 	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
@@ -128,6 +165,23 @@ func (u *StdConn) ListenOut(r EncReader) error {
 	}
 
 	for {
+		if u.enableGRO {
+			if controls == nil {
+				controls = make([][]byte, len(msgs))
+				for i := range controls {
+					controls[i] = make([]byte, unix.CmsgSpace(4))
+				}
+			}
+			for i := range msgs {
+				setRawMessageControl(&msgs[i], controls[i])
+			}
+		} else if controls != nil {
+			for i := range msgs {
+				setRawMessageControl(&msgs[i], nil)
+			}
+			controls = nil
+		}
+
 		n, err := read(msgs)
 		if err != nil {
 			return err
@@ -140,7 +194,23 @@ func (u *StdConn) ListenOut(r EncReader) error {
 			} else {
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
 			}
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
+			payload := buffers[i][:msgs[i].Len]
+
+			if controls != nil {
+				if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
+					if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segCount > 1 && segSize > 0 {
+						segSize = normalizeGROSegSize(segSize, segCount, len(payload))
+						if segSize > 0 && segSize < len(payload) {
+							if u.emitGROSegments(r, addr, payload, segSize) {
+								continue
+							}
+						}
+					}
+				}
+			}
+
+			r(addr, payload)
 		}
 	}
 }
@@ -193,6 +263,14 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
 }
 
 func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
+	if u.enableGSO && ip.IsValid() {
+		if err := u.queueGSOPacket(b, ip); err == nil {
+			return nil
+		} else if !errors.Is(err, errGSOFallback) {
+			return err
+		}
+	}
+
 	if u.isV4 {
 		return u.writeTo4(b, ip)
 	}
@@ -299,6 +377,72 @@ func (u *StdConn) ReloadConfig(c *config.C) {
 			u.l.WithError(err).Error("Failed to set listen.so_mark")
 		}
 	}
+
+	u.configureGRO(c.GetBool("listen.enable_gro", false))
+	u.configureGSO(c)
+}
+
+func (u *StdConn) configureGRO(enable bool) {
+	if enable == u.enableGRO {
+		return
+	}
+
+	if enable {
+		if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
+			u.l.WithError(err).Warn("Failed to enable UDP GRO")
+			return
+		}
+		u.enableGRO = true
+		u.l.Info("UDP GRO enabled")
+		return
+	}
+
+	if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
+		u.l.WithError(err).Warn("Failed to disable UDP GRO")
+	}
+	u.enableGRO = false
+}
+
+func (u *StdConn) configureGSO(c *config.C) {
+	enable := c.GetBool("listen.enable_gso", false)
+	if !enable {
+		u.disableGSO()
+	} else {
+		u.enableGSO = true
+	}
+
+	segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
+	if segments < 1 {
+		segments = 1
+	}
+	u.gsoMaxSegments = segments
+
+	maxBytes := c.GetInt("listen.gso_max_bytes", 0)
+	if maxBytes <= 0 {
+		maxBytes = MTU * segments
+	}
+	if maxBytes > maxGSOBatchBytes {
+		u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
+		maxBytes = maxGSOBatchBytes
+	}
+	u.gsoMaxBytes = maxBytes
+
+	timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
+	if timeout < 0 {
+		timeout = 0
+	}
+	u.gsoFlushTimeout = timeout
+}
+
+func (u *StdConn) disableGSO() {
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+	u.enableGSO = false
+	_ = u.flushGSOlocked()
+	u.gsoBuf = nil
+	u.gsoSegments = 0
+	u.gsoSegSize = 0
+	u.stopGSOTimerLocked()
 }
 
 func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
@@ -310,7 +454,227 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
 	return nil
 }
 
+func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
+	if len(b) == 0 {
+		return nil
+	}
+
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+
+	if !u.enableGSO || !addr.IsValid() || len(b) > u.gsoMaxBytes {
+		if err := u.flushGSOlocked(); err != nil {
+			return err
+		}
+		return errGSOFallback
+	}
+
+	if u.gsoSegments == 0 {
+		if cap(u.gsoBuf) < u.gsoMaxBytes {
+			u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
+		}
+		u.gsoAddr = addr
+		u.gsoSegSize = len(b)
+	} else if addr != u.gsoAddr || len(b) != u.gsoSegSize {
+		if err := u.flushGSOlocked(); err != nil {
+			return err
+		}
+		if cap(u.gsoBuf) < u.gsoMaxBytes {
+			u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
+		}
+		u.gsoAddr = addr
+		u.gsoSegSize = len(b)
+	}
+
+	if len(u.gsoBuf)+len(b) > u.gsoMaxBytes {
+		if err := u.flushGSOlocked(); err != nil {
+			return err
+		}
+		if cap(u.gsoBuf) < u.gsoMaxBytes {
+			u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
+		}
+		u.gsoAddr = addr
+		u.gsoSegSize = len(b)
+	}
+
+	u.gsoBuf = append(u.gsoBuf, b...)
+	u.gsoSegments++
+
+	if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
+		return u.flushGSOlocked()
+	}
+
+	u.scheduleGSOFlushLocked()
+	return nil
+}
+
+func (u *StdConn) flushGSOlocked() error {
+	if u.gsoSegments == 0 {
+		u.stopGSOTimerLocked()
+		return nil
+	}
+
+	payload := append([]byte(nil), u.gsoBuf...)
+	addr := u.gsoAddr
+	segSize := u.gsoSegSize
+
+	u.gsoBuf = u.gsoBuf[:0]
+	u.gsoSegments = 0
+	u.gsoSegSize = 0
+	u.stopGSOTimerLocked()
+
+	if segSize <= 0 {
+		return errGSOFallback
+	}
+
+	err := u.sendSegmented(payload, addr, segSize)
+	if errors.Is(err, errGSODisabled) {
+		u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
+		u.enableGSO = false
+		return u.sendSegmentsIndividually(payload, addr, segSize)
+	}
+
+	return err
+}
+
+func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
+	if len(payload) == 0 {
+		return nil
+	}
+
+	control := make([]byte, unix.CmsgSpace(2))
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+	hdr.Level = unix.SOL_UDP
+	hdr.Type = unix.UDP_SEGMENT
+	setCmsgLen(hdr, unix.CmsgLen(2))
+	binary.LittleEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
+
+	var sa unix.Sockaddr
+	if addr.Addr().Is4() {
+		var sa4 unix.SockaddrInet4
+		sa4.Port = int(addr.Port())
+		sa4.Addr = addr.Addr().As4()
+		sa = &sa4
+	} else {
+		var sa6 unix.SockaddrInet6
+		sa6.Port = int(addr.Port())
+		sa6.Addr = addr.Addr().As16()
+		sa = &sa6
+	}
+
+	if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
+		if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
+			return errGSODisabled
+		}
+		return &net.OpError{Op: "sendmsg", Err: err}
+	}
+	return nil
+}
+
+func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
+	if segSize <= 0 {
+		return errGSOFallback
+	}
+
+	for offset := 0; offset < len(buf); offset += segSize {
+		end := offset + segSize
+		if end > len(buf) {
+			end = len(buf)
+		}
+		var err error
+		if u.isV4 {
+			err = u.writeTo4(buf[offset:end], addr)
+		} else {
+			err = u.writeTo6(buf[offset:end], addr)
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (u *StdConn) scheduleGSOFlushLocked() {
+	if u.gsoTimer == nil {
+		u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
+		return
+	}
+	u.gsoTimer.Reset(u.gsoFlushTimeout)
+}
+
+func (u *StdConn) stopGSOTimerLocked() {
+	if u.gsoTimer != nil {
+		u.gsoTimer.Stop()
+		u.gsoTimer = nil
+	}
+}
+
+func (u *StdConn) gsoFlushTimer() {
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+	_ = u.flushGSOlocked()
+}
+
+func parseGROControl(control []byte) (int, int) {
+	if len(control) == 0 {
+		return 0, 0
+	}
+
+	cmsgs, err := unix.ParseSocketControlMessage(control)
+	if err != nil {
+		return 0, 0
+	}
+
+	for _, c := range cmsgs {
+		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
+			segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
+			segCount := 0
+			if len(c.Data) >= 4 {
+				segCount = int(binary.LittleEndian.Uint16(c.Data[2:4]))
+			}
+			return segSize, segCount
+		}
+	}
+
+	return 0, 0
+}
+
+func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
+	if segSize <= 0 || segSize >= len(payload) {
+		return false
+	}
+
+	for offset := 0; offset < len(payload); offset += segSize {
+		end := offset + segSize
+		if end > len(payload) {
+			end = len(payload)
+		}
+		r(addr, payload[offset:end])
+	}
+	return true
+}
+
+func normalizeGROSegSize(segSize, segCount, total int) int {
+	if segCount > 1 && total > 0 {
+		avg := total / segCount
+		if avg > 0 {
+			if segSize > avg {
+				if segSize-8 == avg {
+					segSize = avg
+				} else if segSize > total {
+					segSize = avg
+				}
+			}
+		}
+	}
+	if segSize > total {
+		segSize = total
+	}
+	return segSize
+}
+
 func (u *StdConn) Close() error {
+	u.disableGSO()
 	return syscall.Close(u.sysFd)
 }
 

+ 18 - 0
udp/udp_linux_32.go

@@ -52,3 +52,21 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 
 	return msgs, buffers, names
 }
+
+func setRawMessageControl(msg *rawMessage, buf []byte) {
+	if len(buf) == 0 {
+		msg.Hdr.Control = nil
+		msg.Hdr.Controllen = 0
+		return
+	}
+	msg.Hdr.Control = &buf[0]
+	msg.Hdr.Controllen = uint32(len(buf))
+}
+
+func getRawMessageControlLen(msg *rawMessage) int {
+	return int(msg.Hdr.Controllen)
+}
+
+func setCmsgLen(h *unix.Cmsghdr, l int) {
+	h.Len = uint32(l)
+}

+ 18 - 0
udp/udp_linux_64.go

@@ -55,3 +55,21 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 
 	return msgs, buffers, names
 }
+
+func setRawMessageControl(msg *rawMessage, buf []byte) {
+	if len(buf) == 0 {
+		msg.Hdr.Control = nil
+		msg.Hdr.Controllen = 0
+		return
+	}
+	msg.Hdr.Control = &buf[0]
+	msg.Hdr.Controllen = uint64(len(buf))
+}
+
+func getRawMessageControlLen(msg *rawMessage) int {
+	return int(msg.Hdr.Controllen)
+}
+
+func setCmsgLen(h *unix.Cmsghdr, l int) {
+	h.Len = uint64(l)
+}