bind_std.go
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"runtime"
	"strconv"
	"sync"
	"syscall"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	blackhole4 bool
	blackhole6 bool

	q int // queue index forwarded to listenConfig when opening sockets; see listenNet
}
// NewStdNetBind creates a bind that listens on all interfaces.
func NewStdNetBind() *StdNetBind {
	return newStdNetBind().(*StdNetBind)
}

// NewStdNetBindForAddr creates a bind for a specific address. Binding to the
// given address (and reuse-port handling) is currently disabled, as shown by
// the commented-out block below; only the queue index q takes effect, so the
// returned bind still listens on all interfaces.
func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind {
	b := NewStdNetBind()
	b.q = q
	//if addr.IsValid() {
	//	if addr.IsUnspecified() {
	//		// keep dual-stack defaults with empty listen addresses
	//	} else if addr.Is4() {
	//		b.listenAddr4 = addr.Unmap().String()
	//		b.bindV4 = true
	//		b.bindV6 = false
	//	} else {
	//		b.listenAddr6 = addr.Unmap().String()
	//		b.bindV6 = true
	//		b.bindV4 = false
	//	}
	//}
	//b.reusePort = reusePort
	return b
}
func newStdNetBind() Bind {
	return &StdNetBind{
		udpAddrPool: sync.Pool{
			New: func() any {
				return &net.UDPAddr{
					IP: make([]byte, 16),
				}
			},
		},
		msgsPool: sync.Pool{
			New: func() any {
				// ipv6.Message and ipv4.Message are interchangeable as they are
				// both aliases for x/net/internal/socket.Message.
				msgs := make([]ipv6.Message, IdealBatchSize)
				for i := range msgs {
					msgs[i].Buffers = make(net.Buffers, 1)
					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
				}
				return &msgs
			},
		},
	}
}
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}

var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
	e, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return &StdNetEndpoint{
		AddrPort: e,
	}, nil
}

func (e *StdNetEndpoint) ClearSrc() {
	if e.src != nil {
		// Truncate src, no need to reallocate.
		e.src = e.src[:0]
	}
}

func (e *StdNetEndpoint) DstIP() netip.Addr {
	return e.AddrPort.Addr()
}

// See control_default.go, control_linux.go, etc. for implementations of SrcIP
// and SrcIfidx.

func (e *StdNetEndpoint) DstToBytes() []byte {
	b, _ := e.AddrPort.MarshalBinary()
	return b
}

func (e *StdNetEndpoint) DstToString() string {
	return e.AddrPort.String()
}
func listenNet(network string, port int, q int) (*net.UDPConn, int, error) {
	lc := listenConfig(q)
	conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}
	if q == 0 {
		if EvilFdZero == 0 {
			panic("EvilFdZero is unset")
		}
		err = reusePortHax(EvilFdZero)
		if err != nil {
			return nil, 0, fmt.Errorf("reusePortHax: %w", err)
		}
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		return nil, 0, err
	}
	return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err error
	var tries int

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
again:
	port := int(uport)
	var v4conn, v6conn *net.UDPConn
	var v4pc *ipv4.PacketConn
	var v6pc *ipv6.PacketConn

	v4conn, port, err = listenNet("udp4", port, s.q)
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		return nil, 0, err
	}

	// Listen on the same port as we're using for ipv4.
	v6conn, port, err = listenNet("udp6", port, s.q)
	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
		v4conn.Close()
		tries++
		goto again
	}
	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
		v4conn.Close()
		return nil, 0, err
	}
	var fns []ReceiveFunc
	if v4conn != nil {
		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v4pc = ipv4.NewPacketConn(v4conn)
			s.ipv4PC = v4pc
		}
		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
		s.ipv4 = v4conn
	}
	if v6conn != nil {
		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
			v6pc = ipv6.NewPacketConn(v6conn)
			s.ipv6PC = v6pc
		}
		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
		s.ipv6 = v6conn
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	return fns, uint16(port), nil
}
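
// Illustrative usage sketch (added; not part of the upstream API): open the
// bind on an ephemeral port and drain one batch per ReceiveFunc. The 1500-byte
// buffer size is an assumption for the example; real callers size buffers to
// the device MTU plus overhead. Note that each recv call blocks until at
// least one packet arrives.
func exampleOpenAndReceive(bind *StdNetBind) error {
	fns, port, err := bind.Open(0) // port 0 selects an ephemeral port
	if err != nil {
		return err
	}
	defer bind.Close()
	fmt.Println("listening on port", port)

	batch := bind.BatchSize()
	bufs := make([][]byte, batch)
	for i := range bufs {
		bufs[i] = make([]byte, 1500)
	}
	sizes := make([]int, batch)
	eps := make([]Endpoint, batch)
	for _, recv := range fns {
		n, err := recv(bufs, sizes, eps)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			fmt.Printf("%d bytes from %s\n", sizes[i], eps[i].DstToString())
		}
	}
	return nil
}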
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
	for i := range *msgs {
		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
	}
	s.msgsPool.Put(msgs)
}

func (s *StdNetBind) getMessages() *[]ipv6.Message {
	return s.msgsPool.Get().(*[]ipv6.Message)
}

var (
	// If compilation fails here these are no longer the same underlying type.
	_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
	ReadBatch([]ipv6.Message, int) (int, error)
}

type batchWriter interface {
	WriteBatch([]ipv6.Message, int) (int, error)
}
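
// Note (added): ipv4.PacketConn and ipv6.PacketConn both satisfy batchReader
// and batchWriter, since their Message types alias the same underlying
// x/net/internal/socket.Message (see the compile-time check above). That is
// what lets receiveIP and send below serve both address families.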
func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		if rxOffload {
			// Read into the tail of msgs so the leading slots remain free as
			// destinations when coalesced datagrams are split back apart.
			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
	}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		return IdealBatchSize
	}
	return 1
}
func (s *StdNetBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var err1, err2 error
	if s.ipv4 != nil {
		err1 = s.ipv4.Close()
		s.ipv4 = nil
		s.ipv4PC = nil
	}
	if s.ipv6 != nil {
		err2 = s.ipv6.Close()
		s.ipv6 = nil
		s.ipv6PC = nil
	}
	s.blackhole4 = false
	s.blackhole6 = false
	s.ipv4TxOffload = false
	s.ipv4RxOffload = false
	s.ipv6TxOffload = false
	s.ipv6RxOffload = false
	if err1 != nil {
		return err1
	}
	return err2
}
type ErrUDPGSODisabled struct {
	onLaddr  string
	RetryErr error
}

func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

func (e ErrUDPGSODisabled) Unwrap() error {
	return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
	s.mu.Lock()
	blackhole := s.blackhole4
	conn := s.ipv4
	offload := s.ipv4TxOffload
	br := batchWriter(s.ipv4PC)
	is6 := false
	if endpoint.DstIP().Is6() {
		blackhole = s.blackhole6
		conn = s.ipv6
		br = s.ipv6PC
		is6 = true
		offload = s.ipv6TxOffload
	}
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}

	msgs := s.getMessages()
	defer s.putMessages(msgs)
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)
	if is6 {
		as16 := endpoint.DstIP().As16()
		copy(ua.IP, as16[:])
		ua.IP = ua.IP[:16]
	} else {
		as4 := endpoint.DstIP().As4()
		copy(ua.IP, as4[:])
		ua.IP = ua.IP[:4]
	}
	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
	var (
		retried bool
		err     error
	)
retry:
	if offload {
		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
		err = s.send(conn, br, (*msgs)[:n])
		if err != nil && offload && errShouldDisableUDPGSO(err) {
			offload = false
			s.mu.Lock()
			if is6 {
				s.ipv6TxOffload = false
			} else {
				s.ipv4TxOffload = false
			}
			s.mu.Unlock()
			retried = true
			goto retry
		}
	} else {
		for i := range bufs {
			(*msgs)[i].Addr = ua
			(*msgs)[i].Buffers[0] = bufs[i]
			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
		}
		err = s.send(conn, br, (*msgs)[:len(bufs)])
	}
	if retried {
		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
	}
	return err
}
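
// Illustrative sketch (added): sending two datagrams through an open bind. The
// destination is a documentation-range placeholder; with TX offload available,
// Send coalesces the payloads into a single GSO message, otherwise it writes
// one message per buffer.
func exampleSend(bind *StdNetBind) error {
	ep, err := bind.ParseEndpoint("203.0.113.5:51820")
	if err != nil {
		return err
	}
	bufs := [][]byte{[]byte("packet-1"), []byte("packet-2")}
	return bind.Send(bufs, ep)
}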
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
	var (
		n     int
		err   error
		start int
	)
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		for {
			n, err = pc.WriteBatch(msgs[start:], 0)
			if err != nil || n == len(msgs[start:]) {
				break
			}
			start += n
		}
	} else {
		for _, msg := range msgs {
			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
			if err != nil {
				break
			}
		}
	}
	return err
}
const (
	// Exceeding these values results in EMSGSIZE. They account for layer 3 and
	// layer 4 headers. IPv6 does not need to account for its own header, as the
	// payload length field excludes it.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)
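
// Worked example (added note): the UDP length field caps a datagram at
// 65535 bytes, so the maximum IPv4 payload is 65535 - 20 (IPv4 header)
// - 8 (UDP header) = 65507 bytes. The IPv6 payload length field excludes
// the fixed IPv6 header, leaving 65535 - 8 = 65527 bytes.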
type setGSOFunc func(control *[]byte, gsoSize uint16)

func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}
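
// Illustrative sketch (added): three 1200-byte datagrams coalesce into a
// single GSO message because the first buffer has spare capacity. The no-op
// setGSO stub and the 4096-byte capacity are assumptions for the example,
// standing in for setGSOSize and real send buffers.
func exampleCoalesce() {
	ep := &StdNetEndpoint{AddrPort: netip.MustParseAddrPort("192.0.2.1:51820")}
	addr := &net.UDPAddr{IP: net.ParseIP("192.0.2.1"), Port: 51820}
	msgs := make([]ipv6.Message, 3)
	for i := range msgs {
		msgs[i].Buffers = make(net.Buffers, 1)
		msgs[i].OOB = make([]byte, 0, 64)
	}
	bufs := [][]byte{
		make([]byte, 1200, 4096), // spare capacity receives the next segments
		make([]byte, 1200),
		make([]byte, 1200),
	}
	noopGSO := func(control *[]byte, gsoSize uint16) {}
	n := coalesceMessages(addr, ep, bufs, msgs, noopGSO)
	fmt.Println(n) // 1: all three segments share msgs[0]
}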
type getGSOFunc func(control []byte) (int, error)

func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
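
// Illustrative sketch (added) of the inverse operation: one coalesced
// 3000-byte read split at a stubbed GSO size of 1200. The read lands at the
// tail of msgs (index 2 here), mirroring how receiveIP leaves the leading
// slots free as split destinations. stubGSO is an assumption standing in for
// getGSOSize, which parses real control messages.
func exampleSplit() error {
	msgs := make([]ipv6.Message, 3)
	for i := range msgs {
		msgs[i].Buffers = net.Buffers{make([]byte, 4096)}
	}
	msgs[2].N = 3000 // pretend the kernel delivered one coalesced message here
	msgs[2].Addr = &net.UDPAddr{IP: net.ParseIP("192.0.2.1"), Port: 51820}
	stubGSO := func(control []byte) (int, error) { return 1200, nil }
	n, err := splitCoalescedMessages(msgs, 2, stubGSO)
	if err != nil {
		return err
	}
	fmt.Println(n) // 3: segment sizes 1200, 1200, and 600
	return nil
}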