123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- //go:build !e2e_testing
- // +build !e2e_testing
- package udp
- import (
- "context"
- "encoding/binary"
- "errors"
- "fmt"
- "net"
- "net/netip"
- "syscall"
- "unsafe"
- "github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/config"
- "golang.org/x/sys/unix"
- )
- type StdConn struct {
- *net.UDPConn
- isV4 bool
- sysFd uintptr
- l *logrus.Logger
- }
- var _ Conn = &StdConn{}
- func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
- lc := NewListenConfig(multi)
- pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
- if err != nil {
- return nil, err
- }
- if uc, ok := pc.(*net.UDPConn); ok {
- c := &StdConn{UDPConn: uc, l: l}
- rc, err := uc.SyscallConn()
- if err != nil {
- return nil, fmt.Errorf("failed to open udp socket: %w", err)
- }
- err = rc.Control(func(fd uintptr) {
- c.sysFd = fd
- })
- if err != nil {
- return nil, fmt.Errorf("failed to get udp fd: %w", err)
- }
- la, err := c.LocalAddr()
- if err != nil {
- return nil, err
- }
- c.isV4 = la.Addr().Is4()
- return c, nil
- }
- return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc)
- }
- func NewListenConfig(multi bool) net.ListenConfig {
- return net.ListenConfig{
- Control: func(network, address string, c syscall.RawConn) error {
- if multi {
- var controlErr error
- err := c.Control(func(fd uintptr) {
- if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
- controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
- return
- }
- })
- if err != nil {
- return err
- }
- if controlErr != nil {
- return controlErr
- }
- }
- return nil
- },
- }
- }
- //go:linkname sendto golang.org/x/sys/unix.sendto
- //go:noescape
- func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error)
- func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
- var sa unsafe.Pointer
- var addrLen int32
- if u.isV4 {
- if ap.Addr().Is6() {
- return ErrInvalidIPv6RemoteForSocket
- }
- var rsa unix.RawSockaddrInet6
- rsa.Family = unix.AF_INET6
- rsa.Addr = ap.Addr().As16()
- binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
- sa = unsafe.Pointer(&rsa)
- addrLen = syscall.SizeofSockaddrInet4
- } else {
- var rsa unix.RawSockaddrInet6
- rsa.Family = unix.AF_INET6
- rsa.Addr = ap.Addr().As16()
- binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
- sa = unsafe.Pointer(&rsa)
- addrLen = syscall.SizeofSockaddrInet6
- }
- // Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves
- // See https://github.com/golang/go/issues/73919
- for {
- //_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen)
- err := sendto(int(u.sysFd), b, 0, sa, addrLen)
- if err == nil {
- // Written, get out before the error handling
- return nil
- }
- if errors.Is(err, syscall.EINTR) {
- // Write was interrupted, retry
- continue
- }
- if errors.Is(err, syscall.EAGAIN) {
- return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK}
- }
- if errors.Is(err, syscall.EBADF) {
- return net.ErrClosed
- }
- return &net.OpError{Op: "sendto", Err: err}
- }
- }
- func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
- a := u.UDPConn.LocalAddr()
- switch v := a.(type) {
- case *net.UDPAddr:
- addr, ok := netip.AddrFromSlice(v.IP)
- if !ok {
- return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
- }
- return netip.AddrPortFrom(addr, uint16(v.Port)), nil
- default:
- return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
- }
- }
- func (u *StdConn) ReloadConfig(c *config.C) {
- // TODO
- }
- func NewUDPStatsEmitter(udpConns []Conn) func() {
- // No UDP stats for non-linux
- return func() {}
- }
- func (u *StdConn) ListenOut(r EncReader) {
- buffer := make([]byte, MTU)
- for {
- // Just read one packet at a time
- n, rua, err := u.ReadFromUDPAddrPort(buffer)
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
- return
- }
- u.l.WithError(err).Error("unexpected udp socket receive error")
- }
- r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
- }
- }
- func (u *StdConn) Rebind() error {
- var err error
- if u.isV4 {
- err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0)
- } else {
- err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0)
- }
- if err != nil {
- u.l.WithError(err).Error("Failed to rebind udp socket")
- }
- return nil
- }
|