JackDoan 1 month ago
parent
commit
befba57366
3 changed files with 380 additions and 354 deletions
  1. 325 291
      wgstack/conn/bind_std.go
  2. 3 0
      wgstack/conn/controlfns_linux.go
  3. 52 63
      wgstack/conn/sticky_linux.go

+ 325 - 291
wgstack/conn/bind_std.go

@@ -1,12 +1,14 @@
-// SPDX-License-Identifier: MIT
-//
-// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
 
 package conn
 
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net"
 	"net/netip"
 	"runtime"
@@ -16,7 +18,6 @@ import (
 
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
-	"golang.org/x/sys/unix"
 )
 
 var (
@@ -29,28 +30,53 @@ var (
 // methods for sending and receiving multiple datagrams per-syscall. See the
 // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
 type StdNetBind struct {
-	mu     sync.Mutex // protects all fields except as specified
-	ipv4   *net.UDPConn
-	ipv6   *net.UDPConn
-	ipv4PC *ipv4.PacketConn // will be nil on non-Linux
-	ipv6PC *ipv6.PacketConn // will be nil on non-Linux
-
-	// these three fields are not guarded by mu
-	udpAddrPool  sync.Pool
-	ipv4MsgsPool sync.Pool
-	ipv6MsgsPool sync.Pool
+	mu            sync.Mutex // protects all fields except as specified
+	ipv4          *net.UDPConn
+	ipv6          *net.UDPConn
+	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
+	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
+	ipv4TxOffload bool
+	ipv4RxOffload bool
+	ipv6TxOffload bool
+	ipv6RxOffload bool
+
+	// these two fields are not guarded by mu
+	udpAddrPool sync.Pool
+	msgsPool    sync.Pool
 
 	blackhole4 bool
 	blackhole6 bool
+}
+
+// NewStdNetBind creates a bind that listens on all interfaces.
+func NewStdNetBind() *StdNetBind {
+	return newStdNetBind().(*StdNetBind)
+}
+
+// NewStdNetBindForAddr is meant to create a bind that listens on a specific
+// address, but that logic is currently commented out below: addr and reusePort
+// are ignored and the result is the same as NewStdNetBind.
+func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
+	b := NewStdNetBind()
+	//if addr.IsValid() {
+	//	if addr.IsUnspecified() {
+	//		// keep dual-stack defaults with empty listen addresses
+	//	} else if addr.Is4() {
+	//		b.listenAddr4 = addr.Unmap().String()
+	//		b.bindV4 = true
+	//		b.bindV6 = false
+	//	} else {
+	//		b.listenAddr6 = addr.Unmap().String()
+	//		b.bindV6 = true
+	//		b.bindV4 = false
+	//	}
+	//}
+	//b.reusePort = reusePort
 
-	listenAddr4 string
-	listenAddr6 string
-	bindV4      bool
-	bindV6      bool
-	reusePort   bool
+	return b
 }
 
-func newStdNetBind() *StdNetBind {
+func newStdNetBind() Bind {
 	return &StdNetBind{
 		udpAddrPool: sync.Pool{
 			New: func() any {
@@ -60,68 +86,28 @@ func newStdNetBind() *StdNetBind {
 			},
 		},
 
-		ipv4MsgsPool: sync.Pool{
-			New: func() any {
-				msgs := make([]ipv4.Message, IdealBatchSize)
-				for i := range msgs {
-					msgs[i].Buffers = make(net.Buffers, 1)
-					msgs[i].OOB = make([]byte, srcControlSize)
-				}
-				return &msgs
-			},
-		},
-
-		ipv6MsgsPool: sync.Pool{
+		msgsPool: sync.Pool{
 			New: func() any {
+				// ipv6.Message and ipv4.Message are interchangeable as they are
+				// both aliases for x/net/internal/socket.Message.
 				msgs := make([]ipv6.Message, IdealBatchSize)
 				for i := range msgs {
 					msgs[i].Buffers = make(net.Buffers, 1)
-					msgs[i].OOB = make([]byte, srcControlSize)
+					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
 				}
 				return &msgs
 			},
 		},
-		bindV4:    true,
-		bindV6:    true,
-		reusePort: false,
 	}
 }
 
-// NewStdNetBind creates a bind that listens on all interfaces.
-func NewStdNetBind() *StdNetBind {
-	return newStdNetBind()
-}
-
-// NewStdNetBindForAddr creates a bind that listens on a specific address.
-// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
-// IPv6 socket will be created.
-func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
-	b := newStdNetBind()
-	if addr.IsValid() {
-		if addr.IsUnspecified() {
-			// keep dual-stack defaults with empty listen addresses
-		} else if addr.Is4() {
-			b.listenAddr4 = addr.Unmap().String()
-			b.bindV4 = true
-			b.bindV6 = false
-		} else {
-			b.listenAddr6 = addr.Unmap().String()
-			b.bindV6 = true
-			b.bindV4 = false
-		}
-	}
-	b.reusePort = reusePort
-	return b
-}
-
 type StdNetEndpoint struct {
 	// AddrPort is the endpoint destination.
 	netip.AddrPort
-	// src is the current sticky source address and interface index, if supported.
-	src struct {
-		netip.Addr
-		ifidx int32
-	}
+	// src is the current sticky source address and interface index, if
+	// supported. Typically this is a PKTINFO structure from/for control
+	// messages, see unix.PKTINFO for an example.
+	src []byte
 }
 
 var (
@@ -140,21 +126,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
 }
 
 func (e *StdNetEndpoint) ClearSrc() {
-	e.src.ifidx = 0
-	e.src.Addr = netip.Addr{}
+	if e.src != nil {
+		// Truncate src, no need to reallocate.
+		e.src = e.src[:0]
+	}
 }
 
 func (e *StdNetEndpoint) DstIP() netip.Addr {
 	return e.AddrPort.Addr()
 }
 
-func (e *StdNetEndpoint) SrcIP() netip.Addr {
-	return e.src.Addr
-}
-
-func (e *StdNetEndpoint) SrcIfidx() int32 {
-	return e.src.ifidx
-}
+// See sticky_linux.go and its per-platform counterparts for SrcIP and SrcIfidx.
 
 func (e *StdNetEndpoint) DstToBytes() []byte {
 	b, _ := e.AddrPort.MarshalBinary()
@@ -165,32 +147,8 @@ func (e *StdNetEndpoint) DstToString() string {
 	return e.AddrPort.String()
 }
 
-func (e *StdNetEndpoint) SrcToString() string {
-	return e.src.Addr.String()
-}
-
-func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
-	lc := listenConfig()
-	if s.reusePort {
-		base := lc.Control
-		lc.Control = func(network, address string, c syscall.RawConn) error {
-			if base != nil {
-				if err := base(network, address, c); err != nil {
-					return err
-				}
-			}
-			return c.Control(func(fd uintptr) {
-				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
-			})
-		}
-	}
-
-	addr := ":" + strconv.Itoa(port)
-	if host != "" {
-		addr = net.JoinHostPort(host, strconv.Itoa(port))
-	}
-
-	conn, err := lc.ListenPacket(context.Background(), network, addr)
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
 	if err != nil {
 		return nil, 0, err
 	}
@@ -204,45 +162,8 @@ func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPC
 	if err != nil {
 		return nil, 0, err
 	}
-	return conn.(*net.UDPConn), uaddr.Port, nil
-}
-
-func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
-	if !s.bindV4 {
-		return nil, nil, port, nil
-	}
-	host := s.listenAddr4
-	conn, actualPort, err := s.listenNet("udp4", host, port)
-	if err != nil {
-		if errors.Is(err, syscall.EAFNOSUPPORT) {
-			return nil, nil, port, nil
-		}
-		return nil, nil, port, err
-	}
-	if runtime.GOOS != "linux" {
-		return conn, nil, actualPort, nil
-	}
-	pc := ipv4.NewPacketConn(conn)
-	return conn, pc, actualPort, nil
-}
 
-func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
-	if !s.bindV6 {
-		return nil, nil, port, nil
-	}
-	host := s.listenAddr6
-	conn, actualPort, err := s.listenNet("udp6", host, port)
-	if err != nil {
-		if errors.Is(err, syscall.EAFNOSUPPORT) {
-			return nil, nil, port, nil
-		}
-		return nil, nil, port, err
-	}
-	if runtime.GOOS != "linux" {
-		return conn, nil, actualPort, nil
-	}
-	pc := ipv6.NewPacketConn(conn)
-	return conn, pc, actualPort, nil
+	return conn.(*net.UDPConn), uaddr.Port, nil
 }
 
 func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
@@ -260,46 +181,44 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
 	// If uport is 0, we can retry on failure.
 again:
 	port := int(uport)
-	var v4conn *net.UDPConn
-	var v6conn *net.UDPConn
+	var v4conn, v6conn *net.UDPConn
 	var v4pc *ipv4.PacketConn
 	var v6pc *ipv6.PacketConn
 
-	v4conn, v4pc, port, err = s.openIPv4(port)
-	if err != nil {
+	v4conn, port, err = listenNet("udp4", port)
+	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
 		return nil, 0, err
 	}
 
 	// Listen on the same port as we're using for ipv4.
-	v6conn, v6pc, port, err = s.openIPv6(port)
+	v6conn, port, err = listenNet("udp6", port)
 	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
-		if v4conn != nil {
-			v4conn.Close()
-		}
+		v4conn.Close()
 		tries++
 		goto again
 	}
-	if err != nil {
-		if v4conn != nil {
-			v4conn.Close()
-		}
+	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+		v4conn.Close()
 		return nil, 0, err
 	}
-
 	var fns []ReceiveFunc
 	if v4conn != nil {
-		s.ipv4 = v4conn
-		if v4pc != nil {
+		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
+		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+			v4pc = ipv4.NewPacketConn(v4conn)
 			s.ipv4PC = v4pc
 		}
-		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
+		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
+		s.ipv4 = v4conn
 	}
 	if v6conn != nil {
-		s.ipv6 = v6conn
-		if v6pc != nil {
+		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
+		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+			v6pc = ipv6.NewPacketConn(v6conn)
 			s.ipv6PC = v6pc
 		}
-		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
+		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
+		s.ipv6 = v6conn
 	}
 	if len(fns) == 0 {
 		return nil, 0, syscall.EAFNOSUPPORT
@@ -308,76 +227,101 @@ again:
 	return fns, uint16(port), nil
 }
 
-func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
-	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-		msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-		defer s.ipv4MsgsPool.Put(msgs)
-		for i := range bufs {
-			(*msgs)[i].Buffers[0] = bufs[i]
-		}
-		var numMsgs int
-		if runtime.GOOS == "linux" && pc != nil {
-			numMsgs, err = pc.ReadBatch(*msgs, 0)
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
+	for i := range *msgs {
+		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
+		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+	}
+	s.msgsPool.Put(msgs)
+}
+
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
+	return s.msgsPool.Get().(*[]ipv6.Message)
+}
+
+var (
+	// If compilation fails here these are no longer the same underlying type.
+	_ ipv6.Message = ipv4.Message{}
+)
+
+type batchReader interface {
+	ReadBatch([]ipv6.Message, int) (int, error)
+}
+
+type batchWriter interface {
+	WriteBatch([]ipv6.Message, int) (int, error)
+}
+
+func (s *StdNetBind) receiveIP(
+	br batchReader,
+	conn *net.UDPConn,
+	rxOffload bool,
+	bufs [][]byte,
+	sizes []int,
+	eps []Endpoint,
+) (n int, err error) {
+	msgs := s.getMessages()
+	for i := range bufs {
+		(*msgs)[i].Buffers[0] = bufs[i]
+		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+	}
+	defer s.putMessages(msgs)
+	var numMsgs int
+	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
+		if rxOffload {
+			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
+			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
 			if err != nil {
 				return 0, err
 			}
 		} else {
-			msg := &(*msgs)[0]
-			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			numMsgs, err = br.ReadBatch(*msgs, 0)
 			if err != nil {
 				return 0, err
 			}
-			numMsgs = 1
 		}
-		for i := 0; i < numMsgs; i++ {
-			msg := &(*msgs)[i]
-			sizes[i] = msg.N
-			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
-			getSrcFromControl(msg.OOB[:msg.NN], ep)
-			eps[i] = ep
+	} else {
+		msg := &(*msgs)[0]
+		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+		if err != nil {
+			return 0, err
 		}
-		return numMsgs, nil
+		numMsgs = 1
 	}
+	for i := 0; i < numMsgs; i++ {
+		msg := &(*msgs)[i]
+		sizes[i] = msg.N
+		if sizes[i] == 0 {
+			continue
+		}
+		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+		getSrcFromControl(msg.OOB[:msg.NN], ep)
+		eps[i] = ep
+	}
+	return numMsgs, nil
 }
 
-func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
 	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-		msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-		defer s.ipv6MsgsPool.Put(msgs)
-		for i := range bufs {
-			(*msgs)[i].Buffers[0] = bufs[i]
-		}
-		var numMsgs int
-		if runtime.GOOS == "linux" && pc != nil {
-			numMsgs, err = pc.ReadBatch(*msgs, 0)
-			if err != nil {
-				return 0, err
-			}
-		} else {
-			msg := &(*msgs)[0]
-			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
-			if err != nil {
-				return 0, err
-			}
-			numMsgs = 1
-		}
-		for i := 0; i < numMsgs; i++ {
-			msg := &(*msgs)[i]
-			sizes[i] = msg.N
-			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
-			getSrcFromControl(msg.OOB[:msg.NN], ep)
-			eps[i] = ep
-		}
-		return numMsgs, nil
+		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
+	}
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
+	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
 	}
 }
 
 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
 // rename the IdealBatchSize constant to BatchSize.
 func (s *StdNetBind) BatchSize() int {
-	if runtime.GOOS == "linux" {
+	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 		return IdealBatchSize
 	}
 	return 1
@@ -400,28 +344,42 @@ func (s *StdNetBind) Close() error {
 	}
 	s.blackhole4 = false
 	s.blackhole6 = false
+	s.ipv4TxOffload = false
+	s.ipv4RxOffload = false
+	s.ipv6TxOffload = false
+	s.ipv6RxOffload = false
 	if err1 != nil {
 		return err1
 	}
 	return err2
 }
 
+type ErrUDPGSODisabled struct {
+	onLaddr  string
+	RetryErr error
+}
+
+func (e ErrUDPGSODisabled) Error() string {
+	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
+}
+
+func (e ErrUDPGSODisabled) Unwrap() error {
+	return e.RetryErr
+}
+
 func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
 	s.mu.Lock()
 	blackhole := s.blackhole4
 	conn := s.ipv4
-	var (
-		pc4 *ipv4.PacketConn
-		pc6 *ipv6.PacketConn
-	)
+	offload := s.ipv4TxOffload
+	br := batchWriter(s.ipv4PC)
 	is6 := false
 	if endpoint.DstIP().Is6() {
 		blackhole = s.blackhole6
 		conn = s.ipv6
-		pc6 = s.ipv6PC
+		br = s.ipv6PC
 		is6 = true
-	} else {
-		pc4 = s.ipv4PC
+		offload = s.ipv6TxOffload
 	}
 	s.mu.Unlock()
 
@@ -431,109 +389,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
 	if conn == nil {
 		return syscall.EAFNOSUPPORT
 	}
+
+	msgs := s.getMessages()
+	defer s.putMessages(msgs)
+	ua := s.udpAddrPool.Get().(*net.UDPAddr)
+	defer s.udpAddrPool.Put(ua)
 	if is6 {
-		return s.send6(conn, pc6, endpoint, bufs)
+		as16 := endpoint.DstIP().As16()
+		copy(ua.IP, as16[:])
+		ua.IP = ua.IP[:16]
 	} else {
-		return s.send4(conn, pc4, endpoint, bufs)
+		as4 := endpoint.DstIP().As4()
+		copy(ua.IP, as4[:])
+		ua.IP = ua.IP[:4]
 	}
+	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
+	var (
+		retried bool
+		err     error
+	)
+retry:
+	if offload {
+		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
+		err = s.send(conn, br, (*msgs)[:n])
+		if err != nil && offload && errShouldDisableUDPGSO(err) {
+			offload = false
+			s.mu.Lock()
+			if is6 {
+				s.ipv6TxOffload = false
+			} else {
+				s.ipv4TxOffload = false
+			}
+			s.mu.Unlock()
+			retried = true
+			goto retry
+		}
+	} else {
+		for i := range bufs {
+			(*msgs)[i].Addr = ua
+			(*msgs)[i].Buffers[0] = bufs[i]
+			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
+		}
+		err = s.send(conn, br, (*msgs)[:len(bufs)])
+	}
+	if retried {
+		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
+	}
+	return err
 }
 
-func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
-	ua := s.udpAddrPool.Get().(*net.UDPAddr)
-	as4 := ep.DstIP().As4()
-	copy(ua.IP, as4[:])
-	ua.IP = ua.IP[:4]
-	ua.Port = int(ep.(*StdNetEndpoint).Port())
-	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-	for i, buf := range bufs {
-		(*msgs)[i].Buffers[0] = buf
-		(*msgs)[i].Addr = ua
-		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
-	}
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
 	var (
 		n     int
 		err   error
 		start int
 	)
-	if runtime.GOOS == "linux" && pc != nil {
+	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 		for {
-			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
-			if err != nil {
-				if errors.Is(err, syscall.EAFNOSUPPORT) {
-					for j := start; j < len(bufs); j++ {
-						_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
-						if werr != nil {
-							err = werr
-							break
-						}
-					}
-				}
-				break
-			}
-			if n == len((*msgs)[start:len(bufs)]) {
+			n, err = pc.WriteBatch(msgs[start:], 0)
+			if err != nil || n == len(msgs[start:]) {
 				break
 			}
 			start += n
 		}
 	} else {
-		for i, buf := range bufs {
-			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
+		for _, msg := range msgs {
+			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
 			if err != nil {
 				break
 			}
 		}
 	}
-	s.udpAddrPool.Put(ua)
-	s.ipv4MsgsPool.Put(msgs)
 	return err
 }
 
-func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
-	ua := s.udpAddrPool.Get().(*net.UDPAddr)
-	as16 := ep.DstIP().As16()
-	copy(ua.IP, as16[:])
-	ua.IP = ua.IP[:16]
-	ua.Port = int(ep.(*StdNetEndpoint).Port())
-	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-	for i, buf := range bufs {
-		(*msgs)[i].Buffers[0] = buf
-		(*msgs)[i].Addr = ua
-		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
-	}
+const (
+	// Exceeding these values results in EMSGSIZE. They account for layer 3 and
+	// layer 4 headers. The IPv6 limit does not subtract the IPv6 header because
+	// the IPv6 payload length field already excludes it.
+	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
+	maxIPv6PayloadLen = 1<<16 - 1 - 8
+
+	// This is a hard limit imposed by the kernel.
+	udpSegmentMaxDatagrams = 64
+)
+
+type setGSOFunc func(control *[]byte, gsoSize uint16)
+
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
 	var (
-		n     int
-		err   error
-		start int
+		base     = -1 // index of msg we are currently coalescing into
+		gsoSize  int  // segmentation size of msgs[base]
+		dgramCnt int  // number of dgrams coalesced into msgs[base]
+		endBatch bool // tracking flag to start a new batch on next iteration of bufs
 	)
-	if runtime.GOOS == "linux" && pc != nil {
-		for {
-			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
-			if err != nil {
-				if errors.Is(err, syscall.EAFNOSUPPORT) {
-					for j := start; j < len(bufs); j++ {
-						_, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
-						if werr != nil {
-							err = werr
-							break
-						}
-					}
+	maxPayloadLen := maxIPv4PayloadLen
+	if ep.DstIP().Is6() {
+		maxPayloadLen = maxIPv6PayloadLen
+	}
+	for i, buf := range bufs {
+		if i > 0 {
+			msgLen := len(buf)
+			baseLenBefore := len(msgs[base].Buffers[0])
+			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
+			if msgLen+baseLenBefore <= maxPayloadLen &&
+				msgLen <= gsoSize &&
+				msgLen <= freeBaseCap &&
+				dgramCnt < udpSegmentMaxDatagrams &&
+				!endBatch {
+				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
+				if i == len(bufs)-1 {
+					setGSO(&msgs[base].OOB, uint16(gsoSize))
 				}
-				break
-			}
-			if n == len((*msgs)[start:len(bufs)]) {
-				break
+				dgramCnt++
+				if msgLen < gsoSize {
+					// A smaller than gsoSize packet on the tail is legal, but
+					// it must end the batch.
+					endBatch = true
+				}
+				continue
 			}
-			start += n
 		}
-	} else {
-		for i, buf := range bufs {
-			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
-			if err != nil {
-				break
+		if dgramCnt > 1 {
+			setGSO(&msgs[base].OOB, uint16(gsoSize))
+		}
+		// Reset prior to incrementing base since we are preparing to start a
+		// new potential batch.
+		endBatch = false
+		base++
+		gsoSize = len(buf)
+		setSrcControl(&msgs[base].OOB, ep)
+		msgs[base].Buffers[0] = buf
+		msgs[base].Addr = addr
+		dgramCnt = 1
+	}
+	return base + 1
+}
+
+type getGSOFunc func(control []byte) (int, error)
+
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
+	for i := firstMsgAt; i < len(msgs); i++ {
+		msg := &msgs[i]
+		if msg.N == 0 {
+			return n, err
+		}
+		var (
+			gsoSize    int
+			start      int
+			end        = msg.N
+			numToSplit = 1
+		)
+		gsoSize, err = getGSO(msg.OOB[:msg.NN])
+		if err != nil {
+			return n, err
+		}
+		if gsoSize > 0 {
+			numToSplit = (msg.N + gsoSize - 1) / gsoSize
+			end = gsoSize
+		}
+		for j := 0; j < numToSplit; j++ {
+			if n > i {
+				return n, errors.New("splitting coalesced packet resulted in overflow")
+			}
+			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
+			msgs[n].N = copied
+			msgs[n].Addr = msg.Addr
+			start = end
+			end += gsoSize
+			if end > msg.N {
+				end = msg.N
 			}
+			n++
+		}
+		if i != n-1 {
+			// It is legal for bytes to move within msg.Buffers[0] as a result
+			// of splitting, so we only zero the source msg len when it is not
+			// the destination of the last split operation above.
+			msg.N = 0
 		}
 	}
-	s.udpAddrPool.Put(ua)
-	s.ipv6MsgsPool.Put(msgs)
-	return err
+	return n, nil
 }

+ 3 - 0
wgstack/conn/controlfns_linux.go

@@ -29,6 +29,9 @@ func init() {
 				// Set beyond *mem_max if CAP_NET_ADMIN
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)  //todo!!!
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
+				//print(err.Error())
 			})
 		},
 

+ 52 - 63
wgstack/conn/sticky_linux.go

@@ -1,9 +1,3 @@
-//go:build linux && !android
-
-// SPDX-License-Identifier: MIT
-//
-// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
-
 package conn
 
 import (
@@ -13,6 +7,37 @@ import (
 	"golang.org/x/sys/unix"
 )
 
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+	switch len(e.src) {
+	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+		return netip.AddrFrom4(info.Spec_dst)
+	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+		// TODO: set zone. in order to do so we need to check if the address is
+		// link local, and if it is perform a syscall to turn the ifindex into a
+		// zone string because netip uses string zones.
+		return netip.AddrFrom16(info.Addr)
+	}
+	return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+	switch len(e.src) {
+	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+		return info.Ifindex
+	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+		return int32(info.Ifindex)
+	}
+	return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+	return e.SrcIP().String()
+}
+
 // getSrcFromControl parses the control for PKTINFO and if found updates ep with
 // the source information found.
 func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -34,83 +59,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
 		if hdr.Level == unix.IPPROTO_IP &&
 			hdr.Type == unix.IP_PKTINFO {
 
-			info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
-			ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
-			ep.src.ifidx = info.Ifindex
+			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+			}
+			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
 
+			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+			copy(ep.src, hdrBuf)
+			copy(ep.src[unix.CmsgLen(0):], data)
 			return
 		}
 
 		if hdr.Level == unix.IPPROTO_IPV6 &&
 			hdr.Type == unix.IPV6_PKTINFO {
 
-			info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
-			ep.src.Addr = netip.AddrFrom16(info.Addr)
-			ep.src.ifidx = int32(info.Ifindex)
+			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+			}
 
+			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+
+			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+			copy(ep.src, hdrBuf)
+			copy(ep.src[unix.CmsgLen(0):], data)
 			return
 		}
 	}
 }
 
-// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
-// panics if buf is of insufficient size.
-func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
-	size := int(unsafe.Sizeof(t))
-	if len(buf) < size {
-		panic("pktInfoFromBuf: buffer too small")
-	}
-	copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
-	return t
-}
-
 // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
 // and source ifindex found in ep. control's len will be set to 0 in the event
 // that ep is a default value.
 func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
-	*control = (*control)[:cap(*control)]
-	if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
-		*control = (*control)[:0]
-		return
-	}
-
-	if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
-		*control = (*control)[:0]
-		return
-	}
-
-	if len(*control) < srcControlSize {
-		*control = (*control)[:0]
+	if cap(*control) < len(ep.src) {
 		return
 	}
-
-	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
-	if ep.SrcIP().Is4() {
-		hdr.Level = unix.IPPROTO_IP
-		hdr.Type = unix.IP_PKTINFO
-		hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
-
-		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
-		info.Ifindex = ep.src.ifidx
-		if ep.SrcIP().IsValid() {
-			info.Spec_dst = ep.SrcIP().As4()
-		}
-		*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
-	} else {
-		hdr.Level = unix.IPPROTO_IPV6
-		hdr.Type = unix.IPV6_PKTINFO
-		hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
-
-		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
-		info.Ifindex = uint32(ep.src.ifidx)
-		if ep.SrcIP().IsValid() {
-			info.Addr = ep.SrcIP().As16()
-		}
-		*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
-	}
-
+	*control = (*control)[:0]
+	*control = append(*control, ep.src...)
 }
 
-var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+// stickyControlSize is the recommended buffer size for pooling sticky source
+// control data; it is the cmsg space needed for an IPv6 PKTINFO structure.
+var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
 
 const StdNetSupportsStickySockets = true