Browse Source

works properly

Ryan 1 month ago
parent
commit
a0f8cb2098
3 changed files with 143 additions and 29 deletions
  1. 125 25
      udp/udp_linux.go
  2. 9 2
      udp/udp_linux_32.go
  3. 9 2
      udp/udp_linux_64.go

+ 125 - 25
udp/udp_linux.go

@@ -23,9 +23,10 @@ import (
 var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
 var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
 
 
 const (
 const (
-	defaultGSOMaxSegments  = 8
-	defaultGSOFlushTimeout = 150 * time.Microsecond
-	maxGSOBatchBytes       = 0xFFFF
+	defaultGSOMaxSegments    = 8
+	defaultGSOFlushTimeout   = 150 * time.Microsecond
+	defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
+	maxGSOBatchBytes         = 0xFFFF
 )
 )
 
 
 var (
 var (
@@ -51,6 +52,8 @@ type StdConn struct {
 	gsoMaxBytes     int
 	gsoMaxBytes     int
 	gsoFlushTimeout time.Duration
 	gsoFlushTimeout time.Duration
 	gsoTimer        *time.Timer
 	gsoTimer        *time.Timer
+
+	groBufSize int
 }
 }
 
 
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -103,6 +106,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 		gsoMaxSegments:  defaultGSOMaxSegments,
 		gsoMaxSegments:  defaultGSOMaxSegments,
 		gsoMaxBytes:     MTU * defaultGSOMaxSegments,
 		gsoMaxBytes:     MTU * defaultGSOMaxSegments,
 		gsoFlushTimeout: defaultGSOFlushTimeout,
 		gsoFlushTimeout: defaultGSOFlushTimeout,
+		groBufSize:      MTU,
 	}, err
 	}, err
 }
 }
 
 
@@ -158,13 +162,20 @@ func (u *StdConn) ListenOut(r EncReader) error {
 		controls [][]byte
 		controls [][]byte
 	)
 	)
 
 
-	msgs, buffers, names := u.PrepareRawMessages(u.batch)
+	bufSize := u.readBufferSize()
+	msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize)
 	read := u.ReadMulti
 	read := u.ReadMulti
 	if u.batch == 1 {
 	if u.batch == 1 {
 		read = u.ReadSingle
 		read = u.ReadSingle
 	}
 	}
 
 
 	for {
 	for {
+		desired := u.readBufferSize()
+		if len(buffers) == 0 || cap(buffers[0]) < desired {
+			msgs, buffers, names = u.PrepareRawMessages(u.batch, desired)
+			controls = nil
+		}
+
 		if u.enableGRO {
 		if u.enableGRO {
 			if controls == nil {
 			if controls == nil {
 				controls = make([][]byte, len(msgs))
 				controls = make([][]byte, len(msgs))
@@ -197,9 +208,44 @@ func (u *StdConn) ListenOut(r EncReader) error {
 			addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
 			addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
 			payload := buffers[i][:msgs[i].Len]
 			payload := buffers[i][:msgs[i].Len]
 
 
+			if u.enableGRO && u.l.IsLevelEnabled(logrus.DebugLevel) {
+				ctrlLen := getRawMessageControlLen(&msgs[i])
+				msgFlags := getRawMessageFlags(&msgs[i])
+				u.l.WithFields(logrus.Fields{
+					"tag":         "gro-debug",
+					"stage":       "recv",
+					"payload_len": len(payload),
+					"ctrl_len":    ctrlLen,
+					"msg_flags":   msgFlags,
+				}).Debug("gro batch data")
+				if controls != nil && ctrlLen > 0 {
+					maxDump := ctrlLen
+					if maxDump > 16 {
+						maxDump = 16
+					}
+					u.l.WithFields(logrus.Fields{
+						"tag":         "gro-debug",
+						"stage":       "control-bytes",
+						"control_hex": fmt.Sprintf("%x", controls[i][:maxDump]),
+						"datalen":     ctrlLen,
+					}).Debug("gro control dump")
+				}
+			}
+
+			sawControl := false
 			if controls != nil {
 			if controls != nil {
 				if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
 				if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
-					if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segCount > 1 && segSize > 0 {
+					if segSize, segCount := parseGROControl(controls[i][:ctrlLen]); segSize > 0 {
+						sawControl = true
+						if u.l.IsLevelEnabled(logrus.DebugLevel) {
+							u.l.WithFields(logrus.Fields{
+								"tag":        "gro-debug",
+								"stage":      "control",
+								"seg_size":   segSize,
+								"seg_count":  segCount,
+								"payloadLen": len(payload),
+							}).Debug("gro control parsed")
+						}
 						segSize = normalizeGROSegSize(segSize, segCount, len(payload))
 						segSize = normalizeGROSegSize(segSize, segCount, len(payload))
 						if segSize > 0 && segSize < len(payload) {
 						if segSize > 0 && segSize < len(payload) {
 							if u.emitGROSegments(r, addr, payload, segSize) {
 							if u.emitGROSegments(r, addr, payload, segSize) {
@@ -210,11 +256,31 @@ func (u *StdConn) ListenOut(r EncReader) error {
 				}
 				}
 			}
 			}
 
 
+			if u.enableGRO && len(payload) > MTU {
+				if !sawControl && u.l.IsLevelEnabled(logrus.DebugLevel) {
+					u.l.WithFields(logrus.Fields{
+						"tag":         "gro-debug",
+						"stage":       "fallback",
+						"payload_len": len(payload),
+					}).Debug("gro control missing; splitting payload by MTU")
+				}
+				if u.emitGROSegments(r, addr, payload, MTU) {
+					continue
+				}
+			}
+
 			r(addr, payload)
 			r(addr, payload)
 		}
 		}
 	}
 	}
 }
 }
 
 
+func (u *StdConn) readBufferSize() int {
+	if u.enableGRO && u.groBufSize > MTU {
+		return u.groBufSize
+	}
+	return MTU
+}
+
 func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
 func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
 	for {
 	for {
 		n, _, err := unix.Syscall6(
 		n, _, err := unix.Syscall6(
@@ -378,12 +444,22 @@ func (u *StdConn) ReloadConfig(c *config.C) {
 		}
 		}
 	}
 	}
 
 
-	u.configureGRO(c.GetBool("listen.enable_gro", false))
+	u.configureGRO(c)
 	u.configureGSO(c)
 	u.configureGSO(c)
 }
 }
 
 
-func (u *StdConn) configureGRO(enable bool) {
+func (u *StdConn) configureGRO(c *config.C) {
+	if c == nil {
+		return
+	}
+
+	enable := c.GetBool("listen.enable_gro", false)
 	if enable == u.enableGRO {
 	if enable == u.enableGRO {
+		if enable {
+			if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
+				u.setGROBufferSize(size)
+			}
+		}
 		return
 		return
 	}
 	}
 
 
@@ -393,7 +469,8 @@ func (u *StdConn) configureGRO(enable bool) {
 			return
 			return
 		}
 		}
 		u.enableGRO = true
 		u.enableGRO = true
-		u.l.Info("UDP GRO enabled")
+		u.setGROBufferSize(c.GetInt("listen.gro_read_buffer", defaultGROReadBufferSize))
+		u.l.WithField("buffer_size", u.groBufSize).Info("UDP GRO enabled")
 		return
 		return
 	}
 	}
 
 
@@ -401,6 +478,7 @@ func (u *StdConn) configureGRO(enable bool) {
 		u.l.WithError(err).Warn("Failed to disable UDP GRO")
 		u.l.WithError(err).Warn("Failed to disable UDP GRO")
 	}
 	}
 	u.enableGRO = false
 	u.enableGRO = false
+	u.groBufSize = MTU
 }
 }
 
 
 func (u *StdConn) configureGSO(c *config.C) {
 func (u *StdConn) configureGSO(c *config.C) {
@@ -434,6 +512,16 @@ func (u *StdConn) configureGSO(c *config.C) {
 	u.gsoFlushTimeout = timeout
 	u.gsoFlushTimeout = timeout
 }
 }
 
 
+func (u *StdConn) setGROBufferSize(size int) {
+	if size < MTU {
+		size = defaultGROReadBufferSize
+	}
+	if size > maxGSOBatchBytes {
+		size = maxGSOBatchBytes
+	}
+	u.groBufSize = size
+}
+
 func (u *StdConn) disableGSO() {
 func (u *StdConn) disableGSO() {
 	u.gsoMu.Lock()
 	u.gsoMu.Lock()
 	defer u.gsoMu.Unlock()
 	defer u.gsoMu.Unlock()
@@ -547,7 +635,7 @@ func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int
 	hdr.Level = unix.SOL_UDP
 	hdr.Level = unix.SOL_UDP
 	hdr.Type = unix.UDP_SEGMENT
 	hdr.Type = unix.UDP_SEGMENT
 	setCmsgLen(hdr, unix.CmsgLen(2))
 	setCmsgLen(hdr, unix.CmsgLen(2))
-	binary.LittleEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
+	binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
 
 
 	var sa unix.Sockaddr
 	var sa unix.Sockaddr
 	if addr.Addr().Is4() {
 	if addr.Addr().Is4() {
@@ -627,10 +715,10 @@ func parseGROControl(control []byte) (int, int) {
 
 
 	for _, c := range cmsgs {
 	for _, c := range cmsgs {
 		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
 		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
-			segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
+			segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
 			segCount := 0
 			segCount := 0
 			if len(c.Data) >= 4 {
 			if len(c.Data) >= 4 {
-				segCount = int(binary.LittleEndian.Uint16(c.Data[2:4]))
+				segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
 			}
 			}
 			return segSize, segCount
 			return segSize, segCount
 		}
 		}
@@ -640,7 +728,7 @@ func parseGROControl(control []byte) (int, int) {
 }
 }
 
 
 func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
 func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
-	if segSize <= 0 || segSize >= len(payload) {
+	if segSize <= 0 {
 		return false
 		return false
 	}
 	}
 
 
@@ -649,27 +737,39 @@ func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []by
 		if end > len(payload) {
 		if end > len(payload) {
 			end = len(payload)
 			end = len(payload)
 		}
 		}
-		r(addr, payload[offset:end])
+		segment := make([]byte, end-offset)
+		copy(segment, payload[offset:end])
+		r(addr, segment)
 	}
 	}
 	return true
 	return true
 }
 }
 
 
 func normalizeGROSegSize(segSize, segCount, total int) int {
 func normalizeGROSegSize(segSize, segCount, total int) int {
-	if segCount > 1 && total > 0 {
-		avg := total / segCount
-		if avg > 0 {
-			if segSize > avg {
-				if segSize-8 == avg {
-					segSize = avg
-				} else if segSize > total {
-					segSize = avg
-				}
-			}
+	if segSize <= 0 || total <= 0 {
+		return segSize
+	}
+
+	if segSize > total && segCount > 0 {
+		segSize = total / segCount
+		if segSize == 0 {
+			segSize = total
 		}
 		}
 	}
 	}
-	if segSize > total {
-		segSize = total
+
+	if segCount <= 1 && segSize > 0 && total > segSize {
+		calculated := total / segSize
+		if calculated <= 1 {
+			calculated = (total + segSize - 1) / segSize
+		}
+		if calculated > 1 {
+			segCount = calculated
+		}
+	}
+
+	if segSize > MTU {
+		return MTU
 	}
 	}
+
 	return segSize
 	return segSize
 }
 }
 
 

+ 9 - 2
udp/udp_linux_32.go

@@ -30,13 +30,16 @@ type rawMessage struct {
 	Len uint32
 	Len uint32
 }
 }
 
 
-func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
+	if bufSize <= 0 {
+		bufSize = MTU
+	}
 	msgs := make([]rawMessage, n)
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 	names := make([][]byte, n)
 
 
 	for i := range msgs {
 	for i := range msgs {
-		buffers[i] = make([]byte, MTU)
+		buffers[i] = make([]byte, bufSize)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 
 		vs := []iovec{
 		vs := []iovec{
@@ -67,6 +70,10 @@ func getRawMessageControlLen(msg *rawMessage) int {
 	return int(msg.Hdr.Controllen)
 	return int(msg.Hdr.Controllen)
 }
 }
 
 
+func getRawMessageFlags(msg *rawMessage) int {
+	return int(msg.Hdr.Flags)
+}
+
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint32(l)
 	h.Len = uint32(l)
 }
 }

+ 9 - 2
udp/udp_linux_64.go

@@ -33,13 +33,16 @@ type rawMessage struct {
 	Pad0 [4]byte
 	Pad0 [4]byte
 }
 }
 
 
-func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n int, bufSize int) ([]rawMessage, [][]byte, [][]byte) {
+	if bufSize <= 0 {
+		bufSize = MTU
+	}
 	msgs := make([]rawMessage, n)
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 	names := make([][]byte, n)
 
 
 	for i := range msgs {
 	for i := range msgs {
-		buffers[i] = make([]byte, MTU)
+		buffers[i] = make([]byte, bufSize)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 
 		vs := []iovec{
 		vs := []iovec{
@@ -70,6 +73,10 @@ func getRawMessageControlLen(msg *rawMessage) int {
 	return int(msg.Hdr.Controllen)
 	return int(msg.Hdr.Controllen)
 }
 }
 
 
+func getRawMessageFlags(msg *rawMessage) int {
+	return int(msg.Hdr.Flags)
+}
+
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint64(l)
 	h.Len = uint64(l)
 }
 }