2
0

udp_linux.go 13 KB


  1. //go:build !android && !e2e_testing
  2. // +build !android,!e2e_testing
  3. package udp
  4. import (
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "net/netip"
  9. "syscall"
  10. "unsafe"
  11. "github.com/rcrowley/go-metrics"
  12. "github.com/sirupsen/logrus"
  13. "github.com/slackhq/nebula/config"
  14. "github.com/slackhq/nebula/packet"
  15. "golang.org/x/sys/unix"
  16. )
// iovMax is the number of iovec slots allocated for every outgoing message in
// NewListener, and the limit on how many payloads WriteBatch chains onto a
// single message.
//
// There is no unix package constant for the platform limit (it comes from
// limits.h and is 1024).
// TODO: I'd like this to be 1024 but we seem to hit errors around ~130?
// NOTE(review): the ~130 ceiling may come from a kernel per-send UDP segment
// limit rather than from the iovec limit — confirm before raising this.
const iovMax = 128
// StdConn is a UDP connection backed directly by a raw Linux socket file
// descriptor. All I/O goes through unix syscalls rather than the net package.
type StdConn struct {
	// sysFd is the socket file descriptor created by NewListener.
	sysFd int
	// isV4 is true when the socket was created for an IPv4 listen address; it
	// selects the sockaddr encoding used for every write.
	isV4 bool
	// l receives diagnostics and configuration results.
	l *logrus.Logger
	// batch is the message count passed to PrepareRawMessages by ListenOut;
	// a value of 1 makes ListenOut use ReadSingle instead of ReadMulti.
	batch int
	// enableGRO records whether UDP_GRO is currently enabled on the socket
	// (maintained by configureGRO).
	enableGRO bool
	// msgs is scratch space for the message headers built by WriteBatch. It is
	// truncated and refilled on every call.
	msgs []rawMessage
	// iovs holds one iovec array (iovMax entries) per potential outgoing
	// message; WriteBatch points each message header at the matching array.
	iovs [][]iovec
}
  28. func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
  29. af := unix.AF_INET6
  30. if ip.Is4() {
  31. af = unix.AF_INET
  32. }
  33. syscall.ForkLock.RLock()
  34. fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
  35. if err == nil {
  36. unix.CloseOnExec(fd)
  37. }
  38. syscall.ForkLock.RUnlock()
  39. if err != nil {
  40. unix.Close(fd)
  41. return nil, fmt.Errorf("unable to open socket: %s", err)
  42. }
  43. if multi {
  44. if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
  45. return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
  46. }
  47. }
  48. var sa unix.Sockaddr
  49. if ip.Is4() {
  50. sa4 := &unix.SockaddrInet4{Port: port}
  51. sa4.Addr = ip.As4()
  52. sa = sa4
  53. } else {
  54. sa6 := &unix.SockaddrInet6{Port: port}
  55. sa6.Addr = ip.As16()
  56. sa = sa6
  57. }
  58. if err = unix.Bind(fd, sa); err != nil {
  59. return nil, fmt.Errorf("unable to bind to socket: %s", err)
  60. }
  61. const batchSize = 8192
  62. msgs := make([]rawMessage, 0, batchSize) //todo configure
  63. iovs := make([][]iovec, batchSize)
  64. for i := range iovs {
  65. iovs[i] = make([]iovec, iovMax)
  66. }
  67. return &StdConn{
  68. sysFd: fd,
  69. isV4: ip.Is4(),
  70. l: l,
  71. batch: batch,
  72. msgs: msgs,
  73. iovs: iovs,
  74. }, err
  75. }
// SupportsMultipleReaders reports whether this connection supports multiple
// readers. For StdConn it always returns true.
func (u *StdConn) SupportsMultipleReaders() bool {
	return true
}
// Rebind is a no-op for StdConn and always returns nil.
func (u *StdConn) Rebind() error {
	return nil
}
  82. func (u *StdConn) SetRecvBuffer(n int) error {
  83. return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
  84. }
  85. func (u *StdConn) SetSendBuffer(n int) error {
  86. return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
  87. }
  88. func (u *StdConn) SetSoMark(mark int) error {
  89. return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark)
  90. }
  91. func (u *StdConn) GetRecvBuffer() (int, error) {
  92. return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
  93. }
  94. func (u *StdConn) GetSendBuffer() (int, error) {
  95. return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
  96. }
  97. func (u *StdConn) GetSoMark() (int, error) {
  98. return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK)
  99. }
  100. func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
  101. sa, err := unix.Getsockname(u.sysFd)
  102. if err != nil {
  103. return netip.AddrPort{}, err
  104. }
  105. switch sa := sa.(type) {
  106. case *unix.SockaddrInet4:
  107. return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
  108. case *unix.SockaddrInet6:
  109. return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
  110. default:
  111. return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
  112. }
  113. }
  114. func (u *StdConn) ListenOut(r EncReader) {
  115. msgs, packets := u.PrepareRawMessages(u.batch, u.isV4)
  116. read := u.ReadMulti
  117. if u.batch == 1 {
  118. read = u.ReadSingle
  119. }
  120. for {
  121. n, err := read(msgs)
  122. if err != nil {
  123. u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
  124. return
  125. }
  126. for i := 0; i < n; i++ {
  127. packets[i].Payload = packets[i].Payload[:msgs[i].Len]
  128. packets[i].Update(getRawMessageControlLen(&msgs[i]))
  129. }
  130. r(packets[:n])
  131. for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
  132. msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
  133. }
  134. }
  135. }
  136. func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
  137. for {
  138. n, _, err := unix.Syscall6(
  139. unix.SYS_RECVMSG,
  140. uintptr(u.sysFd),
  141. uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
  142. 0,
  143. 0,
  144. 0,
  145. 0,
  146. )
  147. if err != 0 {
  148. return 0, &net.OpError{Op: "recvmsg", Err: err}
  149. }
  150. msgs[0].Len = uint32(n)
  151. return 1, nil
  152. }
  153. }
  154. func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
  155. for {
  156. n, _, err := unix.Syscall6(
  157. unix.SYS_RECVMMSG,
  158. uintptr(u.sysFd),
  159. uintptr(unsafe.Pointer(&msgs[0])),
  160. uintptr(len(msgs)),
  161. unix.MSG_WAITFORONE,
  162. 0,
  163. 0,
  164. )
  165. if err != 0 {
  166. return 0, &net.OpError{Op: "recvmmsg", Err: err}
  167. }
  168. return int(n), nil
  169. }
  170. }
  171. func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
  172. if u.isV4 {
  173. return u.writeTo4(b, ip)
  174. }
  175. return u.writeTo6(b, ip)
  176. }
  177. func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
  178. if u.isV4 {
  179. return u.writeTo4(b, ip)
  180. }
  181. return u.writeTo6(b, ip)
  182. }
  183. func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
  184. nl, err := u.encodeSockaddr(pkt.Name, addr)
  185. if err != nil {
  186. return err
  187. }
  188. pkt.Name = pkt.Name[:nl]
  189. pkt.OutLen = len(pkt.Payload)
  190. return nil
  191. }
  192. func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
  193. if len(pkts) == 0 {
  194. return 0, nil
  195. }
  196. u.msgs = u.msgs[:0]
  197. //u.iovs = u.iovs[:0]
  198. sent := 0
  199. var mostRecentPkt *packet.Packet
  200. mostRecentPktSize := 0
  201. //segmenting := false
  202. idx := 0
  203. for _, pkt := range pkts {
  204. if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
  205. sent++
  206. continue
  207. }
  208. lastIdx := idx - 1
  209. if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax {
  210. u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
  211. u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
  212. u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0]
  213. u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload))
  214. u.msgs[lastIdx].Hdr.Iovlen++
  215. mostRecentPktSize += len(pkt.Payload)
  216. mostRecentPkt.SetSegSizeForTX()
  217. } else {
  218. u.msgs = append(u.msgs, rawMessage{})
  219. u.iovs[idx][0] = iovec{
  220. Base: &pkt.Payload[0],
  221. Len: uint64(len(pkt.Payload)),
  222. }
  223. msg := &u.msgs[idx]
  224. iov := &u.iovs[idx][0]
  225. idx++
  226. msg.Hdr.Iov = iov
  227. msg.Hdr.Iovlen = 1
  228. setRawMessageControl(msg, nil)
  229. msg.Hdr.Flags = 0
  230. msg.Hdr.Name = &pkt.Name[0]
  231. msg.Hdr.Namelen = uint32(len(pkt.Name))
  232. mostRecentPkt = pkt
  233. mostRecentPktSize = len(pkt.Payload)
  234. }
  235. }
  236. if len(u.msgs) == 0 {
  237. return sent, nil
  238. }
  239. offset := 0
  240. for offset < len(u.msgs) {
  241. n, _, errno := unix.Syscall6(
  242. unix.SYS_SENDMMSG,
  243. uintptr(u.sysFd),
  244. uintptr(unsafe.Pointer(&u.msgs[offset])),
  245. uintptr(len(u.msgs)-offset),
  246. 0,
  247. 0,
  248. 0,
  249. )
  250. if errno != 0 {
  251. if errno == unix.EINTR {
  252. continue
  253. }
  254. //for i := 0; i < len(u.msgs); i++ {
  255. // for j := 0; j < int(u.msgs[i].Hdr.Iovlen); j++ {
  256. // u.l.WithFields(logrus.Fields{
  257. // "msg_index": i,
  258. // "iov idx": j,
  259. // "iov": fmt.Sprintf("%+v", u.iovs[i][j]),
  260. // }).Warn("failed to send message")
  261. // }
  262. //
  263. //}
  264. u.l.WithFields(logrus.Fields{
  265. "errno": errno,
  266. "idx": idx,
  267. "len": len(u.msgs),
  268. "deets": fmt.Sprintf("%+v", u.msgs),
  269. "lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]),
  270. }).Error("failed to send message")
  271. return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
  272. }
  273. if n == 0 {
  274. break
  275. }
  276. offset += int(n)
  277. }
  278. return sent + len(u.msgs), nil
  279. }
  280. func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
  281. if u.isV4 {
  282. if !addr.Addr().Is4() {
  283. return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
  284. }
  285. var sa unix.RawSockaddrInet4
  286. sa.Family = unix.AF_INET
  287. sa.Addr = addr.Addr().As4()
  288. binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
  289. size := unix.SizeofSockaddrInet4
  290. copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
  291. return uint32(size), nil
  292. }
  293. var sa unix.RawSockaddrInet6
  294. sa.Family = unix.AF_INET6
  295. sa.Addr = addr.Addr().As16()
  296. binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
  297. size := unix.SizeofSockaddrInet6
  298. copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
  299. return uint32(size), nil
  300. }
  301. func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
  302. var rsa unix.RawSockaddrInet6
  303. rsa.Family = unix.AF_INET6
  304. rsa.Addr = ip.Addr().As16()
  305. binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
  306. for {
  307. _, _, err := unix.Syscall6(
  308. unix.SYS_SENDTO,
  309. uintptr(u.sysFd),
  310. uintptr(unsafe.Pointer(&b[0])),
  311. uintptr(len(b)),
  312. uintptr(0),
  313. uintptr(unsafe.Pointer(&rsa)),
  314. uintptr(unix.SizeofSockaddrInet6),
  315. )
  316. if err != 0 {
  317. return &net.OpError{Op: "sendto", Err: err}
  318. }
  319. return nil
  320. }
  321. }
  322. func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
  323. if !ip.Addr().Is4() {
  324. return ErrInvalidIPv6RemoteForSocket
  325. }
  326. var rsa unix.RawSockaddrInet4
  327. rsa.Family = unix.AF_INET
  328. rsa.Addr = ip.Addr().As4()
  329. binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
  330. for {
  331. _, _, err := unix.Syscall6(
  332. unix.SYS_SENDTO,
  333. uintptr(u.sysFd),
  334. uintptr(unsafe.Pointer(&b[0])),
  335. uintptr(len(b)),
  336. uintptr(0),
  337. uintptr(unsafe.Pointer(&rsa)),
  338. uintptr(unix.SizeofSockaddrInet4),
  339. )
  340. if err != 0 {
  341. return &net.OpError{Op: "sendto", Err: err}
  342. }
  343. return nil
  344. }
  345. }
  346. func (u *StdConn) ReloadConfig(c *config.C) {
  347. b := c.GetInt("listen.read_buffer", 0)
  348. if b > 0 {
  349. err := u.SetRecvBuffer(b)
  350. if err == nil {
  351. s, err := u.GetRecvBuffer()
  352. if err == nil {
  353. u.l.WithField("size", s).Info("listen.read_buffer was set")
  354. } else {
  355. u.l.WithError(err).Warn("Failed to get listen.read_buffer")
  356. }
  357. } else {
  358. u.l.WithError(err).Error("Failed to set listen.read_buffer")
  359. }
  360. }
  361. b = c.GetInt("listen.write_buffer", 0)
  362. if b > 0 {
  363. err := u.SetSendBuffer(b)
  364. if err == nil {
  365. s, err := u.GetSendBuffer()
  366. if err == nil {
  367. u.l.WithField("size", s).Info("listen.write_buffer was set")
  368. } else {
  369. u.l.WithError(err).Warn("Failed to get listen.write_buffer")
  370. }
  371. } else {
  372. u.l.WithError(err).Error("Failed to set listen.write_buffer")
  373. }
  374. }
  375. b = c.GetInt("listen.so_mark", 0)
  376. s, err := u.GetSoMark()
  377. if b > 0 || (err == nil && s != 0) {
  378. err := u.SetSoMark(b)
  379. if err == nil {
  380. s, err := u.GetSoMark()
  381. if err == nil {
  382. u.l.WithField("mark", s).Info("listen.so_mark was set")
  383. } else {
  384. u.l.WithError(err).Warn("Failed to get listen.so_mark")
  385. }
  386. } else {
  387. u.l.WithError(err).Error("Failed to set listen.so_mark")
  388. }
  389. }
  390. u.configureGRO(true)
  391. }
  392. func (u *StdConn) configureGRO(enable bool) {
  393. if enable == u.enableGRO {
  394. return
  395. }
  396. if enable {
  397. if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
  398. u.l.WithError(err).Warn("Failed to enable UDP GRO")
  399. return
  400. }
  401. u.enableGRO = true
  402. u.l.Info("UDP GRO enabled")
  403. } else {
  404. if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
  405. u.l.WithError(err).Warn("Failed to disable UDP GRO")
  406. }
  407. u.enableGRO = false
  408. }
  409. }
  410. func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
  411. var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
  412. _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
  413. if err != 0 {
  414. return err
  415. }
  416. return nil
  417. }
  418. func (u *StdConn) Close() error {
  419. return syscall.Close(u.sysFd)
  420. }
  421. func NewUDPStatsEmitter(udpConns []Conn) func() {
  422. // Check if our kernel supports SO_MEMINFO before registering the gauges
  423. var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
  424. var meminfo [unix.SK_MEMINFO_VARS]uint32
  425. if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
  426. udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
  427. for i := range udpConns {
  428. udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
  429. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
  430. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
  431. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
  432. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
  433. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
  434. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
  435. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
  436. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
  437. metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
  438. }
  439. }
  440. }
  441. return func() {
  442. for i, gauges := range udpGauges {
  443. if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
  444. for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
  445. gauges[j].Update(int64(meminfo[j]))
  446. }
  447. }
  448. }
  449. }
  450. }