| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 | //go:build !e2e_testing// +build !e2e_testingpackage udpimport (	"context"	"encoding/binary"	"errors"	"fmt"	"net"	"net/netip"	"syscall"	"unsafe"	"github.com/sirupsen/logrus"	"github.com/slackhq/nebula/config"	"golang.org/x/sys/unix")type StdConn struct {	*net.UDPConn	isV4  bool	sysFd uintptr	l     *logrus.Logger}var _ Conn = &StdConn{}func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {	lc := NewListenConfig(multi)	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))	if err != nil {		return nil, err	}	if uc, ok := pc.(*net.UDPConn); ok {		c := &StdConn{UDPConn: uc, l: l}		rc, err := uc.SyscallConn()		if err != nil {			return nil, fmt.Errorf("failed to open udp socket: %w", err)		}		err = rc.Control(func(fd uintptr) {			c.sysFd = fd		})		if err != nil {			return nil, fmt.Errorf("failed to get udp fd: %w", err)		}		la, err := c.LocalAddr()		if err != nil {			return nil, err		}		c.isV4 = la.Addr().Is4()		return c, nil	}	return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)}func NewListenConfig(multi bool) net.ListenConfig {	return net.ListenConfig{		Control: func(network, address string, c syscall.RawConn) error {			if multi {				var controlErr error				err := c.Control(func(fd uintptr) {					if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {						controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)						return					}				})				if err != nil {					return err				}				if controlErr != nil {					return controlErr				}			}			return nil		},	}}//go:linkname sendto golang.org/x/sys/unix.sendto//go:noescapefunc sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {	var sa unsafe.Pointer	var addrLen int32	if u.isV4 {		if ap.Addr().Is6() {			return ErrInvalidIPv6RemoteForSocket		}		var rsa unix.RawSockaddrInet6		rsa.Family = unix.AF_INET6		rsa.Addr = ap.Addr().As16()		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())		sa = unsafe.Pointer(&rsa)		addrLen = syscall.SizeofSockaddrInet4	} else {		var rsa unix.RawSockaddrInet6		rsa.Family = unix.AF_INET6		rsa.Addr = ap.Addr().As16()		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())		sa = unsafe.Pointer(&rsa)		addrLen = syscall.SizeofSockaddrInet6	}	// Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves	// See https://github.com/golang/go/issues/73919	for {		//_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)		err := sendto(int(u.sysFd), b, 0, sa, addrLen)		if err == nil {			// Written, get out before the error handling			return nil		}		if errors.Is(err, syscall.EINTR) {			// Write was interrupted, retry			continue		}		if errors.Is(err, syscall.EAGAIN) {			return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}		}		if errors.Is(err, syscall.EBADF) {			return net.ErrClosed		}		return &net.OpError{Op: "sendto", Err: err}	}}func (u *StdConn) LocalAddr() (netip.AddrPort, error) {	a := u.UDPConn.LocalAddr()	switch v := a.(type) {	case *net.UDPAddr:		addr, ok := netip.AddrFromSlice(v.IP)		if !ok {			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)		}		return netip.AddrPortFrom(addr, uint16(v.Port)), nil	default:		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)	}}func (u *StdConn) ReloadConfig(c *config.C) {	// TODO}func NewUDPStatsEmitter(udpConns []Conn) func() {	// No UDP stats for non-linux	return func() {}}func (u *StdConn) ListenOut(r EncReader) {	buffer := make([]byte, MTU)	for {		// Just read one packet at a time		n, rua, err := u.ReadFromUDPAddrPort(buffer)		if err != nil {			if errors.Is(err, net.ErrClosed) {				u.l.WithError(err).Debug("udp socket is closed, exiting read loop")				return			}			u.l.WithError(err).Error("unexpected udp socket receive error")		}		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])	}}func (u *StdConn) Rebind() error {	var err error	if u.isV4 {		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)	} else {		err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)	}	if err != nil {		u.l.WithError(err).Error("Failed to rebind udp socket")	}	return nil}
 |