- /* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
- package conn
- import (
- "context"
- "errors"
- "fmt"
- "net"
- "net/netip"
- "runtime"
- "strconv"
- "sync"
- "syscall"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- )
- var (
- _ Bind = (*StdNetBind)(nil)
- )
- // StdNetBind implements Bind for all platforms. While Windows has its own Bind
- // (see bind_windows.go), it may fall back to StdNetBind.
- // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
- // methods for sending and receiving multiple datagrams per-syscall. See the
- // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
- type StdNetBind struct {
- mu sync.Mutex // protects all fields except as specified
- ipv4 *net.UDPConn
- ipv6 *net.UDPConn
- ipv4PC *ipv4.PacketConn // will be nil on non-Linux
- ipv6PC *ipv6.PacketConn // will be nil on non-Linux
- ipv4TxOffload bool
- ipv4RxOffload bool
- ipv6TxOffload bool
- ipv6RxOffload bool
- // udpAddrPool and msgsPool are not guarded by mu
- udpAddrPool sync.Pool
- msgsPool sync.Pool
- blackhole4 bool
- blackhole6 bool
- q int // passed through to listenConfig in listenNet; when 0, reusePortHax is applied to EvilFdZero
- }
- // NewStdNetBind creates a bind that listens on all interfaces.
- func NewStdNetBind() *StdNetBind {
- return newStdNetBind().(*StdNetBind)
- }
- // NewStdNetBindForAddr creates a bind intended to listen on a specific address:
- // an IPv4 addr would create only the IPv4 socket and an IPv6 addr only the
- // IPv6 socket. Note that the address-specific and reusePort handling below is
- // currently commented out, so addr and reusePort are ignored and only q takes effect.
- func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind {
- b := NewStdNetBind()
- b.q = q
- //if addr.IsValid() {
- // if addr.IsUnspecified() {
- // // keep dual-stack defaults with empty listen addresses
- // } else if addr.Is4() {
- // b.listenAddr4 = addr.Unmap().String()
- // b.bindV4 = true
- // b.bindV6 = false
- // } else {
- // b.listenAddr6 = addr.Unmap().String()
- // b.bindV6 = true
- // b.bindV4 = false
- // }
- //}
- //b.reusePort = reusePort
- return b
- }
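- // newStdNetBind constructs a StdNetBind with its UDPAddr and message pools initialized.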
- func newStdNetBind() Bind {
- return &StdNetBind{
- udpAddrPool: sync.Pool{
- New: func() any {
- return &net.UDPAddr{
- IP: make([]byte, 16),
- }
- },
- },
- msgsPool: sync.Pool{
- New: func() any {
- // ipv6.Message and ipv4.Message are interchangeable as they are
- // both aliases for x/net/internal/socket.Message.
- msgs := make([]ipv6.Message, IdealBatchSize)
- for i := range msgs {
- msgs[i].Buffers = make(net.Buffers, 1)
- msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
- }
- return &msgs
- },
- },
- }
- }
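- // StdNetEndpoint is the Endpoint implementation used by StdNetBind. It carries
- // the destination address and port and, where supported, a sticky source address.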
- type StdNetEndpoint struct {
- // AddrPort is the endpoint destination.
- netip.AddrPort
- // src is the current sticky source address and interface index, if
- // supported. Typically this is a PKTINFO structure from/for control
- // messages, see unix.PKTINFO for an example.
- src []byte
- }
- var (
- _ Bind = (*StdNetBind)(nil)
- _ Endpoint = &StdNetEndpoint{}
- )
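- // ParseEndpoint parses an "ip:port" string into a StdNetEndpoint.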
- func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
- return &StdNetEndpoint{
- AddrPort: e,
- }, nil
- }
- func (e *StdNetEndpoint) ClearSrc() {
- if e.src != nil {
- // Truncate src, no need to reallocate.
- e.src = e.src[:0]
- }
- }
- func (e *StdNetEndpoint) DstIP() netip.Addr {
- return e.AddrPort.Addr()
- }
- // See control_default.go, control_linux.go, etc. for implementations of SrcIP and SrcIfidx.
- func (e *StdNetEndpoint) DstToBytes() []byte {
- b, _ := e.AddrPort.MarshalBinary()
- return b
- }
- func (e *StdNetEndpoint) DstToString() string {
- return e.AddrPort.String()
- }
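- // listenNet opens a UDP listener for the given network ("udp4" or "udp6") on
- // port, returning the connection and the port actually bound (useful when
- // port is 0 and the kernel picks one).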
- func listenNet(network string, port int, q int) (*net.UDPConn, int, error) {
- lc := listenConfig(q)
- conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
- if err != nil {
- return nil, 0, err
- }
- if q == 0 {
- if EvilFdZero == 0 {
- panic("fuck")
- }
- err = reusePortHax(EvilFdZero)
- if err != nil {
- return nil, 0, fmt.Errorf("reuse port hax: %v", err)
- }
- }
- // Retrieve port.
- laddr := conn.LocalAddr()
- uaddr, err := net.ResolveUDPAddr(
- laddr.Network(),
- laddr.String(),
- )
- if err != nil {
- return nil, 0, err
- }
- return conn.(*net.UDPConn), uaddr.Port, nil
- }
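- // Open binds IPv4 and IPv6 sockets on uport (retrying with random ports when
- // uport is 0 and the chosen port is taken), records UDP offload support, and
- // returns one ReceiveFunc per address family along with the selected port.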
- func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- var err error
- var tries int
- if s.ipv4 != nil || s.ipv6 != nil {
- return nil, 0, ErrBindAlreadyOpen
- }
- // Attempt to open ipv4 and ipv6 listeners on the same port.
- // If uport is 0, we can retry on failure.
- again:
- port := int(uport)
- var v4conn, v6conn *net.UDPConn
- var v4pc *ipv4.PacketConn
- var v6pc *ipv6.PacketConn
- v4conn, port, err = listenNet("udp4", port, s.q)
- if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- // Listen on the same port as we're using for ipv4.
- v6conn, port, err = listenNet("udp6", port, s.q)
- if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- v4conn.Close()
- tries++
- goto again
- }
- if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- v4conn.Close()
- return nil, 0, err
- }
- var fns []ReceiveFunc
- if v4conn != nil {
- s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
- if runtime.GOOS == "linux" || runtime.GOOS == "android" {
- v4pc = ipv4.NewPacketConn(v4conn)
- s.ipv4PC = v4pc
- }
- fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
- s.ipv4 = v4conn
- }
- if v6conn != nil {
- s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
- if runtime.GOOS == "linux" || runtime.GOOS == "android" {
- v6pc = ipv6.NewPacketConn(v6conn)
- s.ipv6PC = v6pc
- }
- fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
- s.ipv6 = v6conn
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, uint16(port), nil
- }
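- // putMessages resets each message and returns the slice to the pool.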
- func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
- for i := range *msgs {
- (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
- (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
- }
- s.msgsPool.Put(msgs)
- }
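- // getMessages retrieves a reusable message slice from the pool.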
- func (s *StdNetBind) getMessages() *[]ipv6.Message {
- return s.msgsPool.Get().(*[]ipv6.Message)
- }
- var (
- // If compilation fails here these are no longer the same underlying type.
- _ ipv6.Message = ipv4.Message{}
- )
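- // batchReader and batchWriter abstract the ReadBatch/WriteBatch methods shared
- // by ipv4.PacketConn and ipv6.PacketConn.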
- type batchReader interface {
- ReadBatch([]ipv6.Message, int) (int, error)
- }
- type batchWriter interface {
- WriteBatch([]ipv6.Message, int) (int, error)
- }
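- // receiveIP reads a batch of datagrams into bufs. On Linux/Android it uses
- // ReadBatch via br and, when rxOffload is set, splits GRO-coalesced messages
- // back into individual datagrams; elsewhere it reads a single datagram from
- // conn. sizes and eps are filled with the datagram lengths and source endpoints.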
- func (s *StdNetBind) receiveIP(
- br batchReader,
- conn *net.UDPConn,
- rxOffload bool,
- bufs [][]byte,
- sizes []int,
- eps []Endpoint,
- ) (n int, err error) {
- msgs := s.getMessages()
- for i := range bufs {
- (*msgs)[i].Buffers[0] = bufs[i]
- (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
- }
- defer s.putMessages(msgs)
- var numMsgs int
- if runtime.GOOS == "linux" || runtime.GOOS == "android" {
- if rxOffload {
- readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
- numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
- if err != nil {
- return 0, err
- }
- numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
- if err != nil {
- return 0, err
- }
- } else {
- numMsgs, err = br.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- }
- } else {
- msg := &(*msgs)[0]
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
- if err != nil {
- return 0, err
- }
- numMsgs = 1
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- if sizes[i] == 0 {
- continue
- }
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
- getSrcFromControl(msg.OOB[:msg.NN], ep)
- eps[i] = ep
- }
- return numMsgs, nil
- }
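- // makeReceiveIPv4 and makeReceiveIPv6 wrap receiveIP as ReceiveFuncs for each address family.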
- func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
- }
- }
- func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
- }
- }
- // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
- // rename the IdealBatchSize constant to BatchSize.
- func (s *StdNetBind) BatchSize() int {
- if runtime.GOOS == "linux" || runtime.GOOS == "android" {
- return IdealBatchSize
- }
- return 1
- }
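- // Close closes both sockets, if open, and resets blackhole and offload state.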
- func (s *StdNetBind) Close() error {
- s.mu.Lock()
- defer s.mu.Unlock()
- var err1, err2 error
- if s.ipv4 != nil {
- err1 = s.ipv4.Close()
- s.ipv4 = nil
- s.ipv4PC = nil
- }
- if s.ipv6 != nil {
- err2 = s.ipv6.Close()
- s.ipv6 = nil
- s.ipv6PC = nil
- }
- s.blackhole4 = false
- s.blackhole6 = false
- s.ipv4TxOffload = false
- s.ipv4RxOffload = false
- s.ipv6TxOffload = false
- s.ipv6RxOffload = false
- if err1 != nil {
- return err1
- }
- return err2
- }
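- // ErrUDPGSODisabled is returned by Send when a write failed with UDP GSO
- // enabled and the send was retried with GSO disabled. RetryErr holds the
- // result of the retried send, which may be nil.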
- type ErrUDPGSODisabled struct {
- onLaddr string
- RetryErr error
- }
- func (e ErrUDPGSODisabled) Error() string {
- return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
- }
- func (e ErrUDPGSODisabled) Unwrap() error {
- return e.RetryErr
- }
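- // Send transmits bufs to endpoint. When TX offload is available the datagrams
- // are coalesced into fewer messages using UDP GSO; if the kernel rejects GSO,
- // offload is disabled for the socket and the send is retried without it.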
- func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
- s.mu.Lock()
- blackhole := s.blackhole4
- conn := s.ipv4
- offload := s.ipv4TxOffload
- br := batchWriter(s.ipv4PC)
- is6 := false
- if endpoint.DstIP().Is6() {
- blackhole = s.blackhole6
- conn = s.ipv6
- br = s.ipv6PC
- is6 = true
- offload = s.ipv6TxOffload
- }
- s.mu.Unlock()
- if blackhole {
- return nil
- }
- if conn == nil {
- return syscall.EAFNOSUPPORT
- }
- msgs := s.getMessages()
- defer s.putMessages(msgs)
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
- defer s.udpAddrPool.Put(ua)
- if is6 {
- as16 := endpoint.DstIP().As16()
- copy(ua.IP, as16[:])
- ua.IP = ua.IP[:16]
- } else {
- as4 := endpoint.DstIP().As4()
- copy(ua.IP, as4[:])
- ua.IP = ua.IP[:4]
- }
- ua.Port = int(endpoint.(*StdNetEndpoint).Port())
- var (
- retried bool
- err error
- )
- retry:
- if offload {
- n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
- err = s.send(conn, br, (*msgs)[:n])
- if err != nil && offload && errShouldDisableUDPGSO(err) {
- offload = false
- s.mu.Lock()
- if is6 {
- s.ipv6TxOffload = false
- } else {
- s.ipv4TxOffload = false
- }
- s.mu.Unlock()
- retried = true
- goto retry
- }
- } else {
- for i := range bufs {
- (*msgs)[i].Addr = ua
- (*msgs)[i].Buffers[0] = bufs[i]
- setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
- }
- err = s.send(conn, br, (*msgs)[:len(bufs)])
- }
- if retried {
- return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
- }
- return err
- }
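- // send writes msgs to the socket, using WriteBatch on Linux/Android (looping
- // until the whole batch is written) and WriteMsgUDP per message elsewhere.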
- func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
- var (
- n int
- err error
- start int
- )
- if runtime.GOOS == "linux" || runtime.GOOS == "android" {
- for {
- n, err = pc.WriteBatch(msgs[start:], 0)
- if err != nil || n == len(msgs[start:]) {
- break
- }
- start += n
- }
- } else {
- for _, msg := range msgs {
- _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
- if err != nil {
- break
- }
- }
- }
- return err
- }
- const (
- // Exceeding these values results in EMSGSIZE. They account for layer3 and
- // layer4 headers. IPv6 does not need to account for itself as the payload
- // length field is self excluding.
- maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
- maxIPv6PayloadLen = 1<<16 - 1 - 8
- // This is a hard limit imposed by the kernel.
- udpSegmentMaxDatagrams = 64
- )
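- // setGSOFunc encodes a GSO segment size into a control message buffer.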
- type setGSOFunc func(control *[]byte, gsoSize uint16)
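- // coalesceMessages packs bufs destined for the same endpoint into as few
- // messages as possible and returns the number of messages produced. Datagrams
- // no larger than the first in a batch are appended to the current base message
- // until a payload or segment-count limit is reached or a shorter datagram ends
- // the batch; the segment size is then written to the base message's control
- // data via setGSO. For example, four 1200-byte bufs (capacity permitting)
- // coalesce into a single 4800-byte message with a GSO segment size of 1200.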
- func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
- var (
- base = -1 // index of msg we are currently coalescing into
- gsoSize int // segmentation size of msgs[base]
- dgramCnt int // number of dgrams coalesced into msgs[base]
- endBatch bool // tracking flag to start a new batch on next iteration of bufs
- )
- maxPayloadLen := maxIPv4PayloadLen
- if ep.DstIP().Is6() {
- maxPayloadLen = maxIPv6PayloadLen
- }
- for i, buf := range bufs {
- if i > 0 {
- msgLen := len(buf)
- baseLenBefore := len(msgs[base].Buffers[0])
- freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
- if msgLen+baseLenBefore <= maxPayloadLen &&
- msgLen <= gsoSize &&
- msgLen <= freeBaseCap &&
- dgramCnt < udpSegmentMaxDatagrams &&
- !endBatch {
- msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
- if i == len(bufs)-1 {
- setGSO(&msgs[base].OOB, uint16(gsoSize))
- }
- dgramCnt++
- if msgLen < gsoSize {
- // A smaller than gsoSize packet on the tail is legal, but
- // it must end the batch.
- endBatch = true
- }
- continue
- }
- }
- if dgramCnt > 1 {
- setGSO(&msgs[base].OOB, uint16(gsoSize))
- }
- // Reset prior to incrementing base since we are preparing to start a
- // new potential batch.
- endBatch = false
- base++
- gsoSize = len(buf)
- setSrcControl(&msgs[base].OOB, ep)
- msgs[base].Buffers[0] = buf
- msgs[base].Addr = addr
- dgramCnt = 1
- }
- return base + 1
- }
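- // getGSOFunc reads the GRO segment size from a received control message buffer.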
- type getGSOFunc func(control []byte) (int, error)
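- // splitCoalescedMessages splits GRO-coalesced messages, read into the tail of
- // msgs starting at firstMsgAt, back into individual datagrams at the head of
- // msgs, returning how many datagrams resulted. For example, a 4200-byte read
- // with a segment size of 1400 becomes three 1400-byte datagrams.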
- func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
- for i := firstMsgAt; i < len(msgs); i++ {
- msg := &msgs[i]
- if msg.N == 0 {
- return n, err
- }
- var (
- gsoSize int
- start int
- end = msg.N
- numToSplit = 1
- )
- gsoSize, err = getGSO(msg.OOB[:msg.NN])
- if err != nil {
- return n, err
- }
- if gsoSize > 0 {
- numToSplit = (msg.N + gsoSize - 1) / gsoSize
- end = gsoSize
- }
- for j := 0; j < numToSplit; j++ {
- if n > i {
- return n, errors.New("splitting coalesced packet resulted in overflow")
- }
- copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
- msgs[n].N = copied
- msgs[n].Addr = msg.Addr
- start = end
- end += gsoSize
- if end > msg.N {
- end = msg.N
- }
- n++
- }
- if i != n-1 {
- // It is legal for bytes to move within msg.Buffers[0] as a result
- // of splitting, so we only zero the source msg len when it is not
- // the destination of the last split operation above.
- msg.N = 0
- }
- }
- return n, nil
- }