| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539 |
- // SPDX-License-Identifier: MIT
- //
- // Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- package conn
- import (
- "context"
- "errors"
- "net"
- "net/netip"
- "runtime"
- "strconv"
- "sync"
- "syscall"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- "golang.org/x/sys/unix"
- )
- var (
- _ Bind = (*StdNetBind)(nil)
- )
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
//
// On Linux each socket is additionally wrapped in an ipv4/ipv6.PacketConn so
// that up to IdealBatchSize datagrams can be read or written per call; on
// other platforms one datagram is handled per call.
//
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu     sync.Mutex       // protects all fields except as specified
	ipv4   *net.UDPConn     // IPv4 socket; nil while closed or when IPv4 was not opened
	ipv6   *net.UDPConn     // IPv6 socket; nil while closed or when IPv6 was not opened
	ipv4PC *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC *ipv6.PacketConn // will be nil on non-Linux

	// these three fields are not guarded by mu (sync.Pool is safe for
	// concurrent use). udpAddrPool holds scratch destination addresses for
	// the send path; ipv4MsgsPool and ipv6MsgsPool hold message batches that
	// are shared by the receive and send paths of their address family.
	udpAddrPool  sync.Pool
	ipv4MsgsPool sync.Pool
	ipv6MsgsPool sync.Pool

	// When true, Send drops packets for that family and reports success.
	// Close resets both to false.
	// NOTE(review): nothing in this file sets them to true — presumably tests
	// or another file do; confirm.
	blackhole4 bool
	blackhole6 bool

	// Hosts the IPv4/IPv6 sockets bind to; empty means all interfaces.
	// Set by NewStdNetBindForAddr.
	listenAddr4 string
	listenAddr6 string

	// Which sockets Open attempts to create. Both default to true;
	// NewStdNetBindForAddr may disable one of them.
	bindV4 bool
	bindV6 bool

	// Request SO_REUSEPORT on the sockets (best-effort; see listenNet).
	reusePort bool
}
- func newStdNetBind() *StdNetBind {
- return &StdNetBind{
- udpAddrPool: sync.Pool{
- New: func() any {
- return &net.UDPAddr{
- IP: make([]byte, 16),
- }
- },
- },
- ipv4MsgsPool: sync.Pool{
- New: func() any {
- msgs := make([]ipv4.Message, IdealBatchSize)
- for i := range msgs {
- msgs[i].Buffers = make(net.Buffers, 1)
- msgs[i].OOB = make([]byte, srcControlSize)
- }
- return &msgs
- },
- },
- ipv6MsgsPool: sync.Pool{
- New: func() any {
- msgs := make([]ipv6.Message, IdealBatchSize)
- for i := range msgs {
- msgs[i].Buffers = make(net.Buffers, 1)
- msgs[i].OOB = make([]byte, srcControlSize)
- }
- return &msgs
- },
- },
- bindV4: true,
- bindV6: true,
- reusePort: false,
- }
- }
- // NewStdNetBind creates a bind that listens on all interfaces.
- func NewStdNetBind() *StdNetBind {
- return newStdNetBind()
- }
- // NewStdNetBindForAddr creates a bind that listens on a specific address.
- // If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
- // IPv6 socket will be created.
- func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
- b := newStdNetBind()
- if addr.IsValid() {
- if addr.IsUnspecified() {
- // keep dual-stack defaults with empty listen addresses
- } else if addr.Is4() {
- b.listenAddr4 = addr.Unmap().String()
- b.bindV4 = true
- b.bindV6 = false
- } else {
- b.listenAddr6 = addr.Unmap().String()
- b.bindV6 = true
- b.bindV4 = false
- }
- }
- b.reusePort = reusePort
- return b
- }
// StdNetEndpoint is the Endpoint implementation used by StdNetBind: a
// destination address/port plus an optional sticky source.
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort

	// src is the current sticky source address and interface index, if supported.
	// Both are zero for endpoints from ParseEndpoint and after ClearSrc; the
	// receive path passes new endpoints to getSrcFromControl, which presumably
	// fills them from socket control data (defined elsewhere — confirm).
	src struct {
		netip.Addr
		ifidx int32
	}
}
- var (
- _ Bind = (*StdNetBind)(nil)
- _ Endpoint = &StdNetEndpoint{}
- )
- func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
- return &StdNetEndpoint{
- AddrPort: e,
- }, nil
- }
- func (e *StdNetEndpoint) ClearSrc() {
- e.src.ifidx = 0
- e.src.Addr = netip.Addr{}
- }
- func (e *StdNetEndpoint) DstIP() netip.Addr {
- return e.AddrPort.Addr()
- }
- func (e *StdNetEndpoint) SrcIP() netip.Addr {
- return e.src.Addr
- }
- func (e *StdNetEndpoint) SrcIfidx() int32 {
- return e.src.ifidx
- }
- func (e *StdNetEndpoint) DstToBytes() []byte {
- b, _ := e.AddrPort.MarshalBinary()
- return b
- }
- func (e *StdNetEndpoint) DstToString() string {
- return e.AddrPort.String()
- }
- func (e *StdNetEndpoint) SrcToString() string {
- return e.src.Addr.String()
- }
- func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
- lc := listenConfig()
- if s.reusePort {
- base := lc.Control
- lc.Control = func(network, address string, c syscall.RawConn) error {
- if base != nil {
- if err := base(network, address, c); err != nil {
- return err
- }
- }
- return c.Control(func(fd uintptr) {
- _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
- })
- }
- }
- addr := ":" + strconv.Itoa(port)
- if host != "" {
- addr = net.JoinHostPort(host, strconv.Itoa(port))
- }
- conn, err := lc.ListenPacket(context.Background(), network, addr)
- if err != nil {
- return nil, 0, err
- }
- // Retrieve port.
- laddr := conn.LocalAddr()
- uaddr, err := net.ResolveUDPAddr(
- laddr.Network(),
- laddr.String(),
- )
- if err != nil {
- return nil, 0, err
- }
- return conn.(*net.UDPConn), uaddr.Port, nil
- }
- func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
- if !s.bindV4 {
- return nil, nil, port, nil
- }
- host := s.listenAddr4
- conn, actualPort, err := s.listenNet("udp4", host, port)
- if err != nil {
- if errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, nil, port, nil
- }
- return nil, nil, port, err
- }
- if runtime.GOOS != "linux" {
- return conn, nil, actualPort, nil
- }
- pc := ipv4.NewPacketConn(conn)
- return conn, pc, actualPort, nil
- }
- func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
- if !s.bindV6 {
- return nil, nil, port, nil
- }
- host := s.listenAddr6
- conn, actualPort, err := s.listenNet("udp6", host, port)
- if err != nil {
- if errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, nil, port, nil
- }
- return nil, nil, port, err
- }
- if runtime.GOOS != "linux" {
- return conn, nil, actualPort, nil
- }
- pc := ipv6.NewPacketConn(conn)
- return conn, pc, actualPort, nil
- }
- func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- var err error
- var tries int
- if s.ipv4 != nil || s.ipv6 != nil {
- return nil, 0, ErrBindAlreadyOpen
- }
- // Attempt to open ipv4 and ipv6 listeners on the same port.
- // If uport is 0, we can retry on failure.
- again:
- port := int(uport)
- var v4conn *net.UDPConn
- var v6conn *net.UDPConn
- var v4pc *ipv4.PacketConn
- var v6pc *ipv6.PacketConn
- v4conn, v4pc, port, err = s.openIPv4(port)
- if err != nil {
- return nil, 0, err
- }
- // Listen on the same port as we're using for ipv4.
- v6conn, v6pc, port, err = s.openIPv6(port)
- if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- if v4conn != nil {
- v4conn.Close()
- }
- tries++
- goto again
- }
- if err != nil {
- if v4conn != nil {
- v4conn.Close()
- }
- return nil, 0, err
- }
- var fns []ReceiveFunc
- if v4conn != nil {
- s.ipv4 = v4conn
- if v4pc != nil {
- s.ipv4PC = v4pc
- }
- fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
- }
- if v6conn != nil {
- s.ipv6 = v6conn
- if v6pc != nil {
- s.ipv6PC = v6pc
- }
- fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, uint16(port), nil
- }
- func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
- defer s.ipv4MsgsPool.Put(msgs)
- for i := range bufs {
- (*msgs)[i].Buffers[0] = bufs[i]
- }
- var numMsgs int
- if runtime.GOOS == "linux" && pc != nil {
- numMsgs, err = pc.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- } else {
- msg := &(*msgs)[0]
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
- if err != nil {
- return 0, err
- }
- numMsgs = 1
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
- getSrcFromControl(msg.OOB[:msg.NN], ep)
- eps[i] = ep
- }
- return numMsgs, nil
- }
- }
- func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
- return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
- defer s.ipv6MsgsPool.Put(msgs)
- for i := range bufs {
- (*msgs)[i].Buffers[0] = bufs[i]
- }
- var numMsgs int
- if runtime.GOOS == "linux" && pc != nil {
- numMsgs, err = pc.ReadBatch(*msgs, 0)
- if err != nil {
- return 0, err
- }
- } else {
- msg := &(*msgs)[0]
- msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
- if err != nil {
- return 0, err
- }
- numMsgs = 1
- }
- for i := 0; i < numMsgs; i++ {
- msg := &(*msgs)[i]
- sizes[i] = msg.N
- addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
- ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
- getSrcFromControl(msg.OOB[:msg.NN], ep)
- eps[i] = ep
- }
- return numMsgs, nil
- }
- }
- // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
- // rename the IdealBatchSize constant to BatchSize.
- func (s *StdNetBind) BatchSize() int {
- if runtime.GOOS == "linux" {
- return IdealBatchSize
- }
- return 1
- }
- func (s *StdNetBind) Close() error {
- s.mu.Lock()
- defer s.mu.Unlock()
- var err1, err2 error
- if s.ipv4 != nil {
- err1 = s.ipv4.Close()
- s.ipv4 = nil
- s.ipv4PC = nil
- }
- if s.ipv6 != nil {
- err2 = s.ipv6.Close()
- s.ipv6 = nil
- s.ipv6PC = nil
- }
- s.blackhole4 = false
- s.blackhole6 = false
- if err1 != nil {
- return err1
- }
- return err2
- }
- func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
- s.mu.Lock()
- blackhole := s.blackhole4
- conn := s.ipv4
- var (
- pc4 *ipv4.PacketConn
- pc6 *ipv6.PacketConn
- )
- is6 := false
- if endpoint.DstIP().Is6() {
- blackhole = s.blackhole6
- conn = s.ipv6
- pc6 = s.ipv6PC
- is6 = true
- } else {
- pc4 = s.ipv4PC
- }
- s.mu.Unlock()
- if blackhole {
- return nil
- }
- if conn == nil {
- return syscall.EAFNOSUPPORT
- }
- if is6 {
- return s.send6(conn, pc6, endpoint, bufs)
- } else {
- return s.send4(conn, pc4, endpoint, bufs)
- }
- }
- func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
- as4 := ep.DstIP().As4()
- copy(ua.IP, as4[:])
- ua.IP = ua.IP[:4]
- ua.Port = int(ep.(*StdNetEndpoint).Port())
- msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
- for i, buf := range bufs {
- (*msgs)[i].Buffers[0] = buf
- (*msgs)[i].Addr = ua
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
- }
- var (
- n int
- err error
- start int
- )
- if runtime.GOOS == "linux" && pc != nil {
- for {
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
- if err != nil {
- if errors.Is(err, syscall.EAFNOSUPPORT) {
- for j := start; j < len(bufs); j++ {
- _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
- if werr != nil {
- err = werr
- break
- }
- }
- }
- break
- }
- if n == len((*msgs)[start:len(bufs)]) {
- break
- }
- start += n
- }
- } else {
- for i, buf := range bufs {
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
- if err != nil {
- break
- }
- }
- }
- s.udpAddrPool.Put(ua)
- s.ipv4MsgsPool.Put(msgs)
- return err
- }
- func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
- ua := s.udpAddrPool.Get().(*net.UDPAddr)
- as16 := ep.DstIP().As16()
- copy(ua.IP, as16[:])
- ua.IP = ua.IP[:16]
- ua.Port = int(ep.(*StdNetEndpoint).Port())
- msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
- for i, buf := range bufs {
- (*msgs)[i].Buffers[0] = buf
- (*msgs)[i].Addr = ua
- setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
- }
- var (
- n int
- err error
- start int
- )
- if runtime.GOOS == "linux" && pc != nil {
- for {
- n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
- if err != nil {
- if errors.Is(err, syscall.EAFNOSUPPORT) {
- for j := start; j < len(bufs); j++ {
- _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
- if werr != nil {
- err = werr
- break
- }
- }
- }
- break
- }
- if n == len((*msgs)[start:len(bufs)]) {
- break
- }
- start += n
- }
- } else {
- for i, buf := range bufs {
- _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
- if err != nil {
- break
- }
- }
- }
- s.udpAddrPool.Put(ua)
- s.ipv6MsgsPool.Put(msgs)
- return err
- }
|