2
0

udp_linux.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. //go:build !android && !e2e_testing
  2. // +build !android,!e2e_testing
  3. package udp
  4. import (
  5. "context"
  6. "encoding/binary"
  7. "fmt"
  8. "net"
  9. "net/netip"
  10. "strconv"
  11. "syscall"
  12. "unsafe"
  13. "github.com/rcrowley/go-metrics"
  14. "github.com/sirupsen/logrus"
  15. "github.com/slackhq/nebula/config"
  16. "golang.org/x/sys/unix"
  17. )
// StdConn is the Linux UDP Conn implementation returned by NewListener. It
// wraps a *net.UDPConn together with its raw descriptor handle so socket
// options can be read/written and packets can be received with
// recvmsg/recvmmsg.
type StdConn struct {
	// c is the underlying UDP socket; WriteTo, LocalAddr and Close use it.
	c *net.UDPConn
	// rc is the raw handle to c's descriptor, used for getsockopt/setsockopt
	// and by the ReadSingle/ReadMulti receive paths.
	rc syscall.RawConn
	// isV4 tells ListenOut how to decode each packet's source sockaddr:
	// true reads the address from bytes 4:8 (sockaddr_in layout), false
	// reads it from bytes 8:24 (sockaddr_in6 layout). It must match the
	// socket's actual address family.
	isV4 bool
	// l is the logger used for read errors and config-reload messages.
	l *logrus.Logger
	// batch is the number of messages ListenOut prepares per read; a value
	// of 1 selects the single-message ReadSingle path.
	batch int
}
  25. func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
  26. lc := net.ListenConfig{
  27. Control: func(network, address string, c syscall.RawConn) error {
  28. if multi {
  29. var err error
  30. oErr := c.Control(func(fd uintptr) {
  31. err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
  32. })
  33. if oErr != nil {
  34. return fmt.Errorf("error while setting SO_REUSEPORT: %w", oErr)
  35. }
  36. if err != nil {
  37. return fmt.Errorf("unable to set SO_REUSEPORT: %w", err)
  38. }
  39. }
  40. return nil
  41. },
  42. }
  43. c, err := lc.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), strconv.Itoa(port)))
  44. if err != nil {
  45. return nil, fmt.Errorf("unable to open socket: %w", err)
  46. }
  47. uc := c.(*net.UDPConn)
  48. rc, err := uc.SyscallConn()
  49. if err != nil {
  50. _ = c.Close()
  51. return nil, fmt.Errorf("unable to open sysfd: %w", err)
  52. }
  53. return &StdConn{c: uc, rc: rc, isV4: ip.Is4(), l: l, batch: batch}, err
  54. }
  55. func (u *StdConn) Rebind() error {
  56. return nil
  57. }
  58. func (u *StdConn) SetRecvBuffer(n int) error {
  59. var err error
  60. oErr := u.rc.Control(func(fd uintptr) {
  61. err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
  62. })
  63. if oErr != nil {
  64. return oErr
  65. }
  66. return err
  67. }
  68. func (u *StdConn) SetSendBuffer(n int) error {
  69. var err error
  70. oErr := u.rc.Control(func(fd uintptr) {
  71. err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
  72. })
  73. if oErr != nil {
  74. return oErr
  75. }
  76. return err
  77. }
  78. func (u *StdConn) SetSoMark(mark int) error {
  79. var err error
  80. oErr := u.rc.Control(func(fd uintptr) {
  81. err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark)
  82. })
  83. if oErr != nil {
  84. return oErr
  85. }
  86. return err
  87. }
  88. func (u *StdConn) GetRecvBuffer() (int, error) {
  89. var err error
  90. var n int
  91. oErr := u.rc.Control(func(fd uintptr) {
  92. n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
  93. })
  94. if oErr != nil {
  95. return n, oErr
  96. }
  97. return n, err
  98. }
  99. func (u *StdConn) GetSendBuffer() (int, error) {
  100. var err error
  101. var n int
  102. oErr := u.rc.Control(func(fd uintptr) {
  103. n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
  104. })
  105. if oErr != nil {
  106. return n, oErr
  107. }
  108. return n, err
  109. }
  110. func (u *StdConn) GetSoMark() (int, error) {
  111. var err error
  112. var n int
  113. oErr := u.rc.Control(func(fd uintptr) {
  114. n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK)
  115. })
  116. if oErr != nil {
  117. return n, oErr
  118. }
  119. return n, err
  120. }
  121. func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
  122. sa := u.c.LocalAddr()
  123. return netip.ParseAddrPort(sa.String())
  124. }
  125. func (u *StdConn) ListenOut(r EncReader) {
  126. var ip netip.Addr
  127. var n uintptr
  128. var err error
  129. msgs, buffers, names := u.PrepareRawMessages(u.batch)
  130. read := u.ReadMulti
  131. if u.batch == 1 {
  132. read = u.ReadSingle
  133. }
  134. for {
  135. read(msgs, &n, &err)
  136. if err != nil {
  137. u.l.WithError(err).Error("udp socket is closed, exiting read loop")
  138. return
  139. }
  140. for i := 0; i < int(n); i++ {
  141. // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
  142. if u.isV4 {
  143. ip, _ = netip.AddrFromSlice(names[i][4:8])
  144. } else {
  145. ip, _ = netip.AddrFromSlice(names[i][8:24])
  146. }
  147. //u.l.Error("GOT A PACKET", msgs[i].Len)
  148. r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
  149. }
  150. }
  151. }
// ReadSingle receives one datagram into msgs[0] with a single recvmsg(2)
// call and reports the outcome through the out-parameters:
//   - success: msgs[0].Len is set to the received byte count and *n to 1;
//   - recvmsg errno (other than EAGAIN/EINTR): the errno is logged, *err is
//     set to a *net.OpError and *n to 0;
//   - RawConn.Read failure (such as the socket being closed): *err is set to
//     that error and *n to 0, but only if *err was still nil.
//
// Only msgs[0] is used. The caller is expected to pass *err == nil.
func (u *StdConn) ReadSingle(msgs []rawMessage, n *uintptr, err *error) {
	oErr := u.rc.Read(func(fd uintptr) bool {
		// The uintptr(unsafe.Pointer(...)) conversion must stay inside the
		// Syscall6 argument list (unsafe.Pointer rule for syscall calls).
		in, _, nErr := unix.Syscall6(
			unix.SYS_RECVMSG,
			fd,
			uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
			0, 0, 0, 0,
		)
		if nErr == syscall.EAGAIN || nErr == syscall.EINTR {
			// Retry read: returning false makes rc.Read wait for readability
			// and call this function again.
			return false
		} else if nErr != 0 {
			u.l.Errorf("READING FROM UDP SINGLE had an errno %d", nErr)
			*err = &net.OpError{Op: "recvmsg", Err: nErr}
			*n = 0
			return true
		}
		msgs[0].Len = uint32(in)
		*n = 1
		return true
	})
	// Surface a failure of rc.Read itself unless the callback already
	// recorded a more specific error.
	if *err == nil && oErr != nil {
		*err = oErr
		*n = 0
		return
	}
}
// ReadMulti receives up to len(msgs) datagrams with one recvmmsg(2) call
// using MSG_WAITFORONE, and reports the outcome through the out-parameters:
//   - success: *n holds the number of messages received;
//   - recvmmsg errno (other than EAGAIN/EINTR): the errno is logged, *err is
//     set to a *net.OpError and *n to 0;
//   - RawConn.Read failure (such as the socket being closed): *err is set to
//     that error and *n to 0, but only if *err was still nil.
//
// *n is written with the raw syscall result before the errno is checked; the
// error paths above overwrite it.
//
// NOTE(review): ListenOut reads msgs[i].Len afterwards, so rawMessage is
// presumably laid out like struct mmsghdr (Hdr followed by the kernel-filled
// msg_len) — confirm against its definition.
func (u *StdConn) ReadMulti(msgs []rawMessage, n *uintptr, err *error) {
	oErr := u.rc.Read(func(fd uintptr) bool {
		var nErr syscall.Errno
		// The uintptr(unsafe.Pointer(...)) conversion must stay inside the
		// Syscall6 argument list (unsafe.Pointer rule for syscall calls).
		*n, _, nErr = unix.Syscall6(
			unix.SYS_RECVMMSG,
			fd,
			uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
			uintptr(len(msgs)),
			unix.MSG_WAITFORONE,
			0, 0,
		)
		if nErr == syscall.EAGAIN || nErr == syscall.EINTR {
			// Retry read: returning false makes rc.Read wait for readability
			// and call this function again.
			return false
		} else if nErr != 0 {
			u.l.Errorf("READING FROM UDP MULTI had an errno %d", nErr)
			*err = &net.OpError{Op: "recvmmsg", Err: nErr}
			*n = 0
			return true
		}
		return true
	})
	// Surface a failure of rc.Read itself unless the callback already
	// recorded a more specific error.
	if *err == nil && oErr != nil {
		*err = oErr
		*n = 0
		return
	}
}
  207. func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
  208. _, err := u.c.WriteToUDPAddrPort(b, ip)
  209. return err
  210. }
  211. func (u *StdConn) ReloadConfig(c *config.C) {
  212. b := c.GetInt("listen.read_buffer", 0)
  213. if b > 0 {
  214. err := u.SetRecvBuffer(b)
  215. if err == nil {
  216. s, err := u.GetRecvBuffer()
  217. if err == nil {
  218. u.l.WithField("size", s).Info("listen.read_buffer was set")
  219. } else {
  220. u.l.WithError(err).Warn("Failed to get listen.read_buffer")
  221. }
  222. } else {
  223. u.l.WithError(err).Error("Failed to set listen.read_buffer")
  224. }
  225. }
  226. b = c.GetInt("listen.write_buffer", 0)
  227. if b > 0 {
  228. err := u.SetSendBuffer(b)
  229. if err == nil {
  230. s, err := u.GetSendBuffer()
  231. if err == nil {
  232. u.l.WithField("size", s).Info("listen.write_buffer was set")
  233. } else {
  234. u.l.WithError(err).Warn("Failed to get listen.write_buffer")
  235. }
  236. } else {
  237. u.l.WithError(err).Error("Failed to set listen.write_buffer")
  238. }
  239. }
  240. b = c.GetInt("listen.so_mark", 0)
  241. s, err := u.GetSoMark()
  242. if b > 0 || (err == nil && s != 0) {
  243. err := u.SetSoMark(b)
  244. if err == nil {
  245. s, err := u.GetSoMark()
  246. if err == nil {
  247. u.l.WithField("mark", s).Info("listen.so_mark was set")
  248. } else {
  249. u.l.WithError(err).Warn("Failed to get listen.so_mark")
  250. }
  251. } else {
  252. u.l.WithError(err).Error("Failed to set listen.so_mark")
  253. }
  254. }
  255. }
  256. func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
  257. var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
  258. var err error
  259. oErr := u.rc.Control(func(fd uintptr) {
  260. _, _, err = unix.Syscall6(
  261. unix.SYS_GETSOCKOPT,
  262. fd,
  263. uintptr(unix.SOL_SOCKET),
  264. uintptr(unix.SO_MEMINFO),
  265. uintptr(unsafe.Pointer(meminfo)),
  266. uintptr(unsafe.Pointer(&vallen)),
  267. 0,
  268. )
  269. })
  270. if oErr != nil {
  271. return oErr
  272. }
  273. return err
  274. }
  275. func (u *StdConn) Close() error {
  276. err := u.c.Close()
  277. return err
  278. }
  279. func NewUDPStatsEmitter(udpConns []Conn) func() {
  280. // Check if our kernel supports SO_MEMINFO before registering the gauges
  281. var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
  282. var meminfo [unix.SK_MEMINFO_VARS]uint32
  283. if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
  284. udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
  285. for i := range udpConns {
  286. udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
  287. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
  288. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
  289. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
  290. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
  291. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
  292. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
  293. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
  294. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
  295. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
  296. }
  297. }
  298. }
  299. return func() {
  300. for i, gauges := range udpGauges {
  301. if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
  302. for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
  303. gauges[j].Update(int64(meminfo[j]))
  304. }
  305. }
  306. }
  307. }
  308. }