|
|
@@ -1,12 +1,14 @@
|
|
|
-// SPDX-License-Identifier: MIT
|
|
|
-//
|
|
|
-// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
|
+/* SPDX-License-Identifier: MIT
|
|
|
+ *
|
|
|
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
|
+ */
|
|
|
|
|
|
package conn
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
+ "fmt"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
"runtime"
|
|
|
@@ -16,7 +18,6 @@ import (
|
|
|
|
|
|
"golang.org/x/net/ipv4"
|
|
|
"golang.org/x/net/ipv6"
|
|
|
- "golang.org/x/sys/unix"
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
@@ -29,28 +30,53 @@ var (
|
|
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
|
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
|
|
type StdNetBind struct {
|
|
|
- mu sync.Mutex // protects all fields except as specified
|
|
|
- ipv4 *net.UDPConn
|
|
|
- ipv6 *net.UDPConn
|
|
|
- ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
|
|
- ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
|
|
-
|
|
|
- // these three fields are not guarded by mu
|
|
|
- udpAddrPool sync.Pool
|
|
|
- ipv4MsgsPool sync.Pool
|
|
|
- ipv6MsgsPool sync.Pool
|
|
|
+ mu sync.Mutex // protects all fields except as specified
|
|
|
+ ipv4 *net.UDPConn
|
|
|
+ ipv6 *net.UDPConn
|
|
|
+ ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
|
|
+ ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
|
|
+ ipv4TxOffload bool
|
|
|
+ ipv4RxOffload bool
|
|
|
+ ipv6TxOffload bool
|
|
|
+ ipv6RxOffload bool
|
|
|
+
|
|
|
+ // these two fields are not guarded by mu
|
|
|
+ udpAddrPool sync.Pool
|
|
|
+ msgsPool sync.Pool
|
|
|
|
|
|
blackhole4 bool
|
|
|
blackhole6 bool
|
|
|
+}
|
|
|
+
|
|
|
+// NewStdNetBind creates a bind that listens on all interfaces.
|
|
|
+func NewStdNetBind() *StdNetBind {
|
|
|
+ return newStdNetBind().(*StdNetBind)
|
|
|
+}
|
|
|
+
|
|
|
+// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
|
|
+// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
|
|
+// IPv6 socket will be created.
|
|
|
+func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
|
|
|
+ b := NewStdNetBind()
|
|
|
+ //if addr.IsValid() {
|
|
|
+ // if addr.IsUnspecified() {
|
|
|
+ // // keep dual-stack defaults with empty listen addresses
|
|
|
+ // } else if addr.Is4() {
|
|
|
+ // b.listenAddr4 = addr.Unmap().String()
|
|
|
+ // b.bindV4 = true
|
|
|
+ // b.bindV6 = false
|
|
|
+ // } else {
|
|
|
+ // b.listenAddr6 = addr.Unmap().String()
|
|
|
+ // b.bindV6 = true
|
|
|
+ // b.bindV4 = false
|
|
|
+ // }
|
|
|
+ //}
|
|
|
+ //b.reusePort = reusePort
|
|
|
|
|
|
- listenAddr4 string
|
|
|
- listenAddr6 string
|
|
|
- bindV4 bool
|
|
|
- bindV6 bool
|
|
|
- reusePort bool
|
|
|
+ return b
|
|
|
}
|
|
|
|
|
|
-func newStdNetBind() *StdNetBind {
|
|
|
+func newStdNetBind() Bind {
|
|
|
return &StdNetBind{
|
|
|
udpAddrPool: sync.Pool{
|
|
|
New: func() any {
|
|
|
@@ -60,68 +86,28 @@ func newStdNetBind() *StdNetBind {
|
|
|
},
|
|
|
},
|
|
|
|
|
|
- ipv4MsgsPool: sync.Pool{
|
|
|
- New: func() any {
|
|
|
- msgs := make([]ipv4.Message, IdealBatchSize)
|
|
|
- for i := range msgs {
|
|
|
- msgs[i].Buffers = make(net.Buffers, 1)
|
|
|
- msgs[i].OOB = make([]byte, srcControlSize)
|
|
|
- }
|
|
|
- return &msgs
|
|
|
- },
|
|
|
- },
|
|
|
-
|
|
|
- ipv6MsgsPool: sync.Pool{
|
|
|
+ msgsPool: sync.Pool{
|
|
|
New: func() any {
|
|
|
+ // ipv6.Message and ipv4.Message are interchangeable as they are
|
|
|
+ // both aliases for x/net/internal/socket.Message.
|
|
|
msgs := make([]ipv6.Message, IdealBatchSize)
|
|
|
for i := range msgs {
|
|
|
msgs[i].Buffers = make(net.Buffers, 1)
|
|
|
- msgs[i].OOB = make([]byte, srcControlSize)
|
|
|
+ msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
|
|
}
|
|
|
return &msgs
|
|
|
},
|
|
|
},
|
|
|
- bindV4: true,
|
|
|
- bindV6: true,
|
|
|
- reusePort: false,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// NewStdNetBind creates a bind that listens on all interfaces.
|
|
|
-func NewStdNetBind() *StdNetBind {
|
|
|
- return newStdNetBind()
|
|
|
-}
|
|
|
-
|
|
|
-// NewStdNetBindForAddr creates a bind that listens on a specific address.
|
|
|
-// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
|
|
|
-// IPv6 socket will be created.
|
|
|
-func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
|
|
|
- b := newStdNetBind()
|
|
|
- if addr.IsValid() {
|
|
|
- if addr.IsUnspecified() {
|
|
|
- // keep dual-stack defaults with empty listen addresses
|
|
|
- } else if addr.Is4() {
|
|
|
- b.listenAddr4 = addr.Unmap().String()
|
|
|
- b.bindV4 = true
|
|
|
- b.bindV6 = false
|
|
|
- } else {
|
|
|
- b.listenAddr6 = addr.Unmap().String()
|
|
|
- b.bindV6 = true
|
|
|
- b.bindV4 = false
|
|
|
- }
|
|
|
- }
|
|
|
- b.reusePort = reusePort
|
|
|
- return b
|
|
|
-}
|
|
|
-
|
|
|
type StdNetEndpoint struct {
|
|
|
// AddrPort is the endpoint destination.
|
|
|
netip.AddrPort
|
|
|
- // src is the current sticky source address and interface index, if supported.
|
|
|
- src struct {
|
|
|
- netip.Addr
|
|
|
- ifidx int32
|
|
|
- }
|
|
|
+ // src is the current sticky source address and interface index, if
|
|
|
+ // supported. Typically this is a PKTINFO structure from/for control
|
|
|
+ // messages, see unix.PKTINFO for an example.
|
|
|
+ src []byte
|
|
|
}
|
|
|
|
|
|
var (
|
|
|
@@ -140,21 +126,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
|
}
|
|
|
|
|
|
func (e *StdNetEndpoint) ClearSrc() {
|
|
|
- e.src.ifidx = 0
|
|
|
- e.src.Addr = netip.Addr{}
|
|
|
+ if e.src != nil {
|
|
|
+ // Truncate src, no need to reallocate.
|
|
|
+ e.src = e.src[:0]
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
|
|
return e.AddrPort.Addr()
|
|
|
}
|
|
|
|
|
|
-func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
|
|
- return e.src.Addr
|
|
|
-}
|
|
|
-
|
|
|
-func (e *StdNetEndpoint) SrcIfidx() int32 {
|
|
|
- return e.src.ifidx
|
|
|
-}
|
|
|
+// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
|
|
|
|
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
|
|
b, _ := e.AddrPort.MarshalBinary()
|
|
|
@@ -165,32 +147,8 @@ func (e *StdNetEndpoint) DstToString() string {
|
|
|
return e.AddrPort.String()
|
|
|
}
|
|
|
|
|
|
-func (e *StdNetEndpoint) SrcToString() string {
|
|
|
- return e.src.Addr.String()
|
|
|
-}
|
|
|
-
|
|
|
-func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
|
|
|
- lc := listenConfig()
|
|
|
- if s.reusePort {
|
|
|
- base := lc.Control
|
|
|
- lc.Control = func(network, address string, c syscall.RawConn) error {
|
|
|
- if base != nil {
|
|
|
- if err := base(network, address, c); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- }
|
|
|
- return c.Control(func(fd uintptr) {
|
|
|
- _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
|
|
- })
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- addr := ":" + strconv.Itoa(port)
|
|
|
- if host != "" {
|
|
|
- addr = net.JoinHostPort(host, strconv.Itoa(port))
|
|
|
- }
|
|
|
-
|
|
|
- conn, err := lc.ListenPacket(context.Background(), network, addr)
|
|
|
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|
|
+ conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
|
|
if err != nil {
|
|
|
return nil, 0, err
|
|
|
}
|
|
|
@@ -204,45 +162,8 @@ func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPC
|
|
|
if err != nil {
|
|
|
return nil, 0, err
|
|
|
}
|
|
|
- return conn.(*net.UDPConn), uaddr.Port, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
|
|
|
- if !s.bindV4 {
|
|
|
- return nil, nil, port, nil
|
|
|
- }
|
|
|
- host := s.listenAddr4
|
|
|
- conn, actualPort, err := s.listenNet("udp4", host, port)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
- return nil, nil, port, nil
|
|
|
- }
|
|
|
- return nil, nil, port, err
|
|
|
- }
|
|
|
- if runtime.GOOS != "linux" {
|
|
|
- return conn, nil, actualPort, nil
|
|
|
- }
|
|
|
- pc := ipv4.NewPacketConn(conn)
|
|
|
- return conn, pc, actualPort, nil
|
|
|
-}
|
|
|
|
|
|
-func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
|
|
|
- if !s.bindV6 {
|
|
|
- return nil, nil, port, nil
|
|
|
- }
|
|
|
- host := s.listenAddr6
|
|
|
- conn, actualPort, err := s.listenNet("udp6", host, port)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
- return nil, nil, port, nil
|
|
|
- }
|
|
|
- return nil, nil, port, err
|
|
|
- }
|
|
|
- if runtime.GOOS != "linux" {
|
|
|
- return conn, nil, actualPort, nil
|
|
|
- }
|
|
|
- pc := ipv6.NewPacketConn(conn)
|
|
|
- return conn, pc, actualPort, nil
|
|
|
+ return conn.(*net.UDPConn), uaddr.Port, nil
|
|
|
}
|
|
|
|
|
|
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|
|
@@ -260,46 +181,44 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|
|
// If uport is 0, we can retry on failure.
|
|
|
again:
|
|
|
port := int(uport)
|
|
|
- var v4conn *net.UDPConn
|
|
|
- var v6conn *net.UDPConn
|
|
|
+ var v4conn, v6conn *net.UDPConn
|
|
|
var v4pc *ipv4.PacketConn
|
|
|
var v6pc *ipv6.PacketConn
|
|
|
|
|
|
- v4conn, v4pc, port, err = s.openIPv4(port)
|
|
|
- if err != nil {
|
|
|
+ v4conn, port, err = listenNet("udp4", port)
|
|
|
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
return nil, 0, err
|
|
|
}
|
|
|
|
|
|
// Listen on the same port as we're using for ipv4.
|
|
|
- v6conn, v6pc, port, err = s.openIPv6(port)
|
|
|
+ v6conn, port, err = listenNet("udp6", port)
|
|
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
|
|
- if v4conn != nil {
|
|
|
- v4conn.Close()
|
|
|
- }
|
|
|
+ v4conn.Close()
|
|
|
tries++
|
|
|
goto again
|
|
|
}
|
|
|
- if err != nil {
|
|
|
- if v4conn != nil {
|
|
|
- v4conn.Close()
|
|
|
- }
|
|
|
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
+ v4conn.Close()
|
|
|
return nil, 0, err
|
|
|
}
|
|
|
-
|
|
|
var fns []ReceiveFunc
|
|
|
if v4conn != nil {
|
|
|
- s.ipv4 = v4conn
|
|
|
- if v4pc != nil {
|
|
|
+ s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
|
|
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
|
+ v4pc = ipv4.NewPacketConn(v4conn)
|
|
|
s.ipv4PC = v4pc
|
|
|
}
|
|
|
- fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
|
|
+ fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
|
|
+ s.ipv4 = v4conn
|
|
|
}
|
|
|
if v6conn != nil {
|
|
|
- s.ipv6 = v6conn
|
|
|
- if v6pc != nil {
|
|
|
+ s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
|
|
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
|
+ v6pc = ipv6.NewPacketConn(v6conn)
|
|
|
s.ipv6PC = v6pc
|
|
|
}
|
|
|
- fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
|
|
+ fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
|
|
+ s.ipv6 = v6conn
|
|
|
}
|
|
|
if len(fns) == 0 {
|
|
|
return nil, 0, syscall.EAFNOSUPPORT
|
|
|
@@ -308,76 +227,101 @@ again:
|
|
|
return fns, uint16(port), nil
|
|
|
}
|
|
|
|
|
|
-func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
|
|
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
|
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
|
- defer s.ipv4MsgsPool.Put(msgs)
|
|
|
- for i := range bufs {
|
|
|
- (*msgs)[i].Buffers[0] = bufs[i]
|
|
|
- }
|
|
|
- var numMsgs int
|
|
|
- if runtime.GOOS == "linux" && pc != nil {
|
|
|
- numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
|
+func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
|
|
+ for i := range *msgs {
|
|
|
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
|
|
+ (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
|
|
+ }
|
|
|
+ s.msgsPool.Put(msgs)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
|
|
+ return s.msgsPool.Get().(*[]ipv6.Message)
|
|
|
+}
|
|
|
+
|
|
|
+var (
|
|
|
+ // If compilation fails here these are no longer the same underlying type.
|
|
|
+ _ ipv6.Message = ipv4.Message{}
|
|
|
+)
|
|
|
+
|
|
|
+type batchReader interface {
|
|
|
+ ReadBatch([]ipv6.Message, int) (int, error)
|
|
|
+}
|
|
|
+
|
|
|
+type batchWriter interface {
|
|
|
+ WriteBatch([]ipv6.Message, int) (int, error)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *StdNetBind) receiveIP(
|
|
|
+ br batchReader,
|
|
|
+ conn *net.UDPConn,
|
|
|
+ rxOffload bool,
|
|
|
+ bufs [][]byte,
|
|
|
+ sizes []int,
|
|
|
+ eps []Endpoint,
|
|
|
+) (n int, err error) {
|
|
|
+ msgs := s.getMessages()
|
|
|
+ for i := range bufs {
|
|
|
+ (*msgs)[i].Buffers[0] = bufs[i]
|
|
|
+ (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
|
|
+ }
|
|
|
+ defer s.putMessages(msgs)
|
|
|
+ var numMsgs int
|
|
|
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
|
+ if rxOffload {
|
|
|
+ readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
|
|
+ numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
} else {
|
|
|
- msg := &(*msgs)[0]
|
|
|
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
|
+ numMsgs, err = br.ReadBatch(*msgs, 0)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
- numMsgs = 1
|
|
|
}
|
|
|
- for i := 0; i < numMsgs; i++ {
|
|
|
- msg := &(*msgs)[i]
|
|
|
- sizes[i] = msg.N
|
|
|
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
|
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
|
- getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
|
- eps[i] = ep
|
|
|
+ } else {
|
|
|
+ msg := &(*msgs)[0]
|
|
|
+ msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
}
|
|
|
- return numMsgs, nil
|
|
|
+ numMsgs = 1
|
|
|
}
|
|
|
+ for i := 0; i < numMsgs; i++ {
|
|
|
+ msg := &(*msgs)[i]
|
|
|
+ sizes[i] = msg.N
|
|
|
+ if sizes[i] == 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
|
+ ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
|
+ getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
|
+ eps[i] = ep
|
|
|
+ }
|
|
|
+ return numMsgs, nil
|
|
|
}
|
|
|
|
|
|
-func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
|
|
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
|
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
|
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
|
- defer s.ipv6MsgsPool.Put(msgs)
|
|
|
- for i := range bufs {
|
|
|
- (*msgs)[i].Buffers[0] = bufs[i]
|
|
|
- }
|
|
|
- var numMsgs int
|
|
|
- if runtime.GOOS == "linux" && pc != nil {
|
|
|
- numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- } else {
|
|
|
- msg := &(*msgs)[0]
|
|
|
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- numMsgs = 1
|
|
|
- }
|
|
|
- for i := 0; i < numMsgs; i++ {
|
|
|
- msg := &(*msgs)[i]
|
|
|
- sizes[i] = msg.N
|
|
|
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
|
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
|
- getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
|
- eps[i] = ep
|
|
|
- }
|
|
|
- return numMsgs, nil
|
|
|
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
|
|
+ return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
|
+ return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
|
|
// rename the IdealBatchSize constant to BatchSize.
|
|
|
func (s *StdNetBind) BatchSize() int {
|
|
|
- if runtime.GOOS == "linux" {
|
|
|
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
|
return IdealBatchSize
|
|
|
}
|
|
|
return 1
|
|
|
@@ -400,28 +344,42 @@ func (s *StdNetBind) Close() error {
|
|
|
}
|
|
|
s.blackhole4 = false
|
|
|
s.blackhole6 = false
|
|
|
+ s.ipv4TxOffload = false
|
|
|
+ s.ipv4RxOffload = false
|
|
|
+ s.ipv6TxOffload = false
|
|
|
+ s.ipv6RxOffload = false
|
|
|
if err1 != nil {
|
|
|
return err1
|
|
|
}
|
|
|
return err2
|
|
|
}
|
|
|
|
|
|
+type ErrUDPGSODisabled struct {
|
|
|
+ onLaddr string
|
|
|
+ RetryErr error
|
|
|
+}
|
|
|
+
|
|
|
+func (e ErrUDPGSODisabled) Error() string {
|
|
|
+ return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
|
|
|
+}
|
|
|
+
|
|
|
+func (e ErrUDPGSODisabled) Unwrap() error {
|
|
|
+ return e.RetryErr
|
|
|
+}
|
|
|
+
|
|
|
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|
|
s.mu.Lock()
|
|
|
blackhole := s.blackhole4
|
|
|
conn := s.ipv4
|
|
|
- var (
|
|
|
- pc4 *ipv4.PacketConn
|
|
|
- pc6 *ipv6.PacketConn
|
|
|
- )
|
|
|
+ offload := s.ipv4TxOffload
|
|
|
+ br := batchWriter(s.ipv4PC)
|
|
|
is6 := false
|
|
|
if endpoint.DstIP().Is6() {
|
|
|
blackhole = s.blackhole6
|
|
|
conn = s.ipv6
|
|
|
- pc6 = s.ipv6PC
|
|
|
+ br = s.ipv6PC
|
|
|
is6 = true
|
|
|
- } else {
|
|
|
- pc4 = s.ipv4PC
|
|
|
+ offload = s.ipv6TxOffload
|
|
|
}
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
@@ -431,109 +389,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|
|
if conn == nil {
|
|
|
return syscall.EAFNOSUPPORT
|
|
|
}
|
|
|
+
|
|
|
+ msgs := s.getMessages()
|
|
|
+ defer s.putMessages(msgs)
|
|
|
+ ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
|
+ defer s.udpAddrPool.Put(ua)
|
|
|
if is6 {
|
|
|
- return s.send6(conn, pc6, endpoint, bufs)
|
|
|
+ as16 := endpoint.DstIP().As16()
|
|
|
+ copy(ua.IP, as16[:])
|
|
|
+ ua.IP = ua.IP[:16]
|
|
|
} else {
|
|
|
- return s.send4(conn, pc4, endpoint, bufs)
|
|
|
+ as4 := endpoint.DstIP().As4()
|
|
|
+ copy(ua.IP, as4[:])
|
|
|
+ ua.IP = ua.IP[:4]
|
|
|
}
|
|
|
+ ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
|
|
+ var (
|
|
|
+ retried bool
|
|
|
+ err error
|
|
|
+ )
|
|
|
+retry:
|
|
|
+ if offload {
|
|
|
+ n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
|
|
+ err = s.send(conn, br, (*msgs)[:n])
|
|
|
+ if err != nil && offload && errShouldDisableUDPGSO(err) {
|
|
|
+ offload = false
|
|
|
+ s.mu.Lock()
|
|
|
+ if is6 {
|
|
|
+ s.ipv6TxOffload = false
|
|
|
+ } else {
|
|
|
+ s.ipv4TxOffload = false
|
|
|
+ }
|
|
|
+ s.mu.Unlock()
|
|
|
+ retried = true
|
|
|
+ goto retry
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for i := range bufs {
|
|
|
+ (*msgs)[i].Addr = ua
|
|
|
+ (*msgs)[i].Buffers[0] = bufs[i]
|
|
|
+ setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
|
|
+ }
|
|
|
+ err = s.send(conn, br, (*msgs)[:len(bufs)])
|
|
|
+ }
|
|
|
+ if retried {
|
|
|
+ return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
|
|
+ }
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
-func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
|
|
|
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
|
- as4 := ep.DstIP().As4()
|
|
|
- copy(ua.IP, as4[:])
|
|
|
- ua.IP = ua.IP[:4]
|
|
|
- ua.Port = int(ep.(*StdNetEndpoint).Port())
|
|
|
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
|
- for i, buf := range bufs {
|
|
|
- (*msgs)[i].Buffers[0] = buf
|
|
|
- (*msgs)[i].Addr = ua
|
|
|
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
|
- }
|
|
|
+func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
|
|
var (
|
|
|
n int
|
|
|
err error
|
|
|
start int
|
|
|
)
|
|
|
- if runtime.GOOS == "linux" && pc != nil {
|
|
|
+ if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
|
for {
|
|
|
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
- for j := start; j < len(bufs); j++ {
|
|
|
- _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
|
|
- if werr != nil {
|
|
|
- err = werr
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- break
|
|
|
- }
|
|
|
- if n == len((*msgs)[start:len(bufs)]) {
|
|
|
+ n, err = pc.WriteBatch(msgs[start:], 0)
|
|
|
+ if err != nil || n == len(msgs[start:]) {
|
|
|
break
|
|
|
}
|
|
|
start += n
|
|
|
}
|
|
|
} else {
|
|
|
- for i, buf := range bufs {
|
|
|
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
|
|
+ for _, msg := range msgs {
|
|
|
+ _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- s.udpAddrPool.Put(ua)
|
|
|
- s.ipv4MsgsPool.Put(msgs)
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
|
|
|
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
|
- as16 := ep.DstIP().As16()
|
|
|
- copy(ua.IP, as16[:])
|
|
|
- ua.IP = ua.IP[:16]
|
|
|
- ua.Port = int(ep.(*StdNetEndpoint).Port())
|
|
|
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
|
- for i, buf := range bufs {
|
|
|
- (*msgs)[i].Buffers[0] = buf
|
|
|
- (*msgs)[i].Addr = ua
|
|
|
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
|
- }
|
|
|
+const (
|
|
|
+ // Exceeding these values results in EMSGSIZE. They account for layer3 and
|
|
|
+ // layer4 headers. IPv6 does not need to account for itself as the payload
|
|
|
+ // length field is self excluding.
|
|
|
+ maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
|
|
+ maxIPv6PayloadLen = 1<<16 - 1 - 8
|
|
|
+
|
|
|
+ // This is a hard limit imposed by the kernel.
|
|
|
+ udpSegmentMaxDatagrams = 64
|
|
|
+)
|
|
|
+
|
|
|
+type setGSOFunc func(control *[]byte, gsoSize uint16)
|
|
|
+
|
|
|
+func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
|
|
var (
|
|
|
- n int
|
|
|
- err error
|
|
|
- start int
|
|
|
+ base = -1 // index of msg we are currently coalescing into
|
|
|
+ gsoSize int // segmentation size of msgs[base]
|
|
|
+ dgramCnt int // number of dgrams coalesced into msgs[base]
|
|
|
+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
|
|
)
|
|
|
- if runtime.GOOS == "linux" && pc != nil {
|
|
|
- for {
|
|
|
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
|
- for j := start; j < len(bufs); j++ {
|
|
|
- _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
|
|
|
- if werr != nil {
|
|
|
- err = werr
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
+ maxPayloadLen := maxIPv4PayloadLen
|
|
|
+ if ep.DstIP().Is6() {
|
|
|
+ maxPayloadLen = maxIPv6PayloadLen
|
|
|
+ }
|
|
|
+ for i, buf := range bufs {
|
|
|
+ if i > 0 {
|
|
|
+ msgLen := len(buf)
|
|
|
+ baseLenBefore := len(msgs[base].Buffers[0])
|
|
|
+ freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
|
|
+ if msgLen+baseLenBefore <= maxPayloadLen &&
|
|
|
+ msgLen <= gsoSize &&
|
|
|
+ msgLen <= freeBaseCap &&
|
|
|
+ dgramCnt < udpSegmentMaxDatagrams &&
|
|
|
+ !endBatch {
|
|
|
+ msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
|
|
+ if i == len(bufs)-1 {
|
|
|
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
|
|
|
}
|
|
|
- break
|
|
|
- }
|
|
|
- if n == len((*msgs)[start:len(bufs)]) {
|
|
|
- break
|
|
|
+ dgramCnt++
|
|
|
+ if msgLen < gsoSize {
|
|
|
+ // A smaller than gsoSize packet on the tail is legal, but
|
|
|
+ // it must end the batch.
|
|
|
+ endBatch = true
|
|
|
+ }
|
|
|
+ continue
|
|
|
}
|
|
|
- start += n
|
|
|
}
|
|
|
- } else {
|
|
|
- for i, buf := range bufs {
|
|
|
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
|
|
- if err != nil {
|
|
|
- break
|
|
|
+ if dgramCnt > 1 {
|
|
|
+ setGSO(&msgs[base].OOB, uint16(gsoSize))
|
|
|
+ }
|
|
|
+ // Reset prior to incrementing base since we are preparing to start a
|
|
|
+ // new potential batch.
|
|
|
+ endBatch = false
|
|
|
+ base++
|
|
|
+ gsoSize = len(buf)
|
|
|
+ setSrcControl(&msgs[base].OOB, ep)
|
|
|
+ msgs[base].Buffers[0] = buf
|
|
|
+ msgs[base].Addr = addr
|
|
|
+ dgramCnt = 1
|
|
|
+ }
|
|
|
+ return base + 1
|
|
|
+}
|
|
|
+
|
|
|
+type getGSOFunc func(control []byte) (int, error)
|
|
|
+
|
|
|
+func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
|
|
+ for i := firstMsgAt; i < len(msgs); i++ {
|
|
|
+ msg := &msgs[i]
|
|
|
+ if msg.N == 0 {
|
|
|
+ return n, err
|
|
|
+ }
|
|
|
+ var (
|
|
|
+ gsoSize int
|
|
|
+ start int
|
|
|
+ end = msg.N
|
|
|
+ numToSplit = 1
|
|
|
+ )
|
|
|
+ gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
|
|
+ if err != nil {
|
|
|
+ return n, err
|
|
|
+ }
|
|
|
+ if gsoSize > 0 {
|
|
|
+ numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
|
|
+ end = gsoSize
|
|
|
+ }
|
|
|
+ for j := 0; j < numToSplit; j++ {
|
|
|
+ if n > i {
|
|
|
+ return n, errors.New("splitting coalesced packet resulted in overflow")
|
|
|
+ }
|
|
|
+ copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
|
|
+ msgs[n].N = copied
|
|
|
+ msgs[n].Addr = msg.Addr
|
|
|
+ start = end
|
|
|
+ end += gsoSize
|
|
|
+ if end > msg.N {
|
|
|
+ end = msg.N
|
|
|
}
|
|
|
+ n++
|
|
|
+ }
|
|
|
+ if i != n-1 {
|
|
|
+ // It is legal for bytes to move within msg.Buffers[0] as a result
|
|
|
+ // of splitting, so we only zero the source msg len when it is not
|
|
|
+ // the destination of the last split operation above.
|
|
|
+ msg.N = 0
|
|
|
}
|
|
|
}
|
|
|
- s.udpAddrPool.Put(ua)
|
|
|
- s.ipv6MsgsPool.Put(msgs)
|
|
|
- return err
|
|
|
+ return n, nil
|
|
|
}
|