bind_std.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. // SPDX-License-Identifier: MIT
  2. //
  3. // Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
  4. package conn
  5. import (
  6. "context"
  7. "errors"
  8. "net"
  9. "net/netip"
  10. "runtime"
  11. "strconv"
  12. "sync"
  13. "syscall"
  14. "golang.org/x/net/ipv4"
  15. "golang.org/x/net/ipv6"
  16. "golang.org/x/sys/unix"
  17. )
  18. var (
  19. _ Bind = (*StdNetBind)(nil)
  20. )
  21. // StdNetBind implements Bind for all platforms. While Windows has its own Bind
  22. // (see bind_windows.go), it may fall back to StdNetBind.
  23. // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
  24. // methods for sending and receiving multiple datagrams per-syscall. See the
  25. // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
  26. type StdNetBind struct {
  27. mu sync.Mutex // protects all fields except as specified
  28. ipv4 *net.UDPConn
  29. ipv6 *net.UDPConn
  30. ipv4PC *ipv4.PacketConn // will be nil on non-Linux
  31. ipv6PC *ipv6.PacketConn // will be nil on non-Linux
  32. // these three fields are not guarded by mu
  33. udpAddrPool sync.Pool
  34. ipv4MsgsPool sync.Pool
  35. ipv6MsgsPool sync.Pool
  36. blackhole4 bool
  37. blackhole6 bool
  38. listenAddr4 string
  39. listenAddr6 string
  40. bindV4 bool
  41. bindV6 bool
  42. reusePort bool
  43. }
  44. func newStdNetBind() *StdNetBind {
  45. return &StdNetBind{
  46. udpAddrPool: sync.Pool{
  47. New: func() any {
  48. return &net.UDPAddr{
  49. IP: make([]byte, 16),
  50. }
  51. },
  52. },
  53. ipv4MsgsPool: sync.Pool{
  54. New: func() any {
  55. msgs := make([]ipv4.Message, IdealBatchSize)
  56. for i := range msgs {
  57. msgs[i].Buffers = make(net.Buffers, 1)
  58. msgs[i].OOB = make([]byte, srcControlSize)
  59. }
  60. return &msgs
  61. },
  62. },
  63. ipv6MsgsPool: sync.Pool{
  64. New: func() any {
  65. msgs := make([]ipv6.Message, IdealBatchSize)
  66. for i := range msgs {
  67. msgs[i].Buffers = make(net.Buffers, 1)
  68. msgs[i].OOB = make([]byte, srcControlSize)
  69. }
  70. return &msgs
  71. },
  72. },
  73. bindV4: true,
  74. bindV6: true,
  75. reusePort: false,
  76. }
  77. }
  78. // NewStdNetBind creates a bind that listens on all interfaces.
  79. func NewStdNetBind() *StdNetBind {
  80. return newStdNetBind()
  81. }
  82. // NewStdNetBindForAddr creates a bind that listens on a specific address.
  83. // If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
  84. // IPv6 socket will be created.
  85. func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
  86. b := newStdNetBind()
  87. if addr.IsValid() {
  88. if addr.IsUnspecified() {
  89. // keep dual-stack defaults with empty listen addresses
  90. } else if addr.Is4() {
  91. b.listenAddr4 = addr.Unmap().String()
  92. b.bindV4 = true
  93. b.bindV6 = false
  94. } else {
  95. b.listenAddr6 = addr.Unmap().String()
  96. b.bindV6 = true
  97. b.bindV4 = false
  98. }
  99. }
  100. b.reusePort = reusePort
  101. return b
  102. }
  103. type StdNetEndpoint struct {
  104. // AddrPort is the endpoint destination.
  105. netip.AddrPort
  106. // src is the current sticky source address and interface index, if supported.
  107. src struct {
  108. netip.Addr
  109. ifidx int32
  110. }
  111. }
  112. var (
  113. _ Bind = (*StdNetBind)(nil)
  114. _ Endpoint = &StdNetEndpoint{}
  115. )
  116. func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
  117. e, err := netip.ParseAddrPort(s)
  118. if err != nil {
  119. return nil, err
  120. }
  121. return &StdNetEndpoint{
  122. AddrPort: e,
  123. }, nil
  124. }
  125. func (e *StdNetEndpoint) ClearSrc() {
  126. e.src.ifidx = 0
  127. e.src.Addr = netip.Addr{}
  128. }
  129. func (e *StdNetEndpoint) DstIP() netip.Addr {
  130. return e.AddrPort.Addr()
  131. }
  132. func (e *StdNetEndpoint) SrcIP() netip.Addr {
  133. return e.src.Addr
  134. }
  135. func (e *StdNetEndpoint) SrcIfidx() int32 {
  136. return e.src.ifidx
  137. }
  138. func (e *StdNetEndpoint) DstToBytes() []byte {
  139. b, _ := e.AddrPort.MarshalBinary()
  140. return b
  141. }
  142. func (e *StdNetEndpoint) DstToString() string {
  143. return e.AddrPort.String()
  144. }
  145. func (e *StdNetEndpoint) SrcToString() string {
  146. return e.src.Addr.String()
  147. }
  148. func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
  149. lc := listenConfig()
  150. if s.reusePort {
  151. base := lc.Control
  152. lc.Control = func(network, address string, c syscall.RawConn) error {
  153. if base != nil {
  154. if err := base(network, address, c); err != nil {
  155. return err
  156. }
  157. }
  158. return c.Control(func(fd uintptr) {
  159. _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
  160. })
  161. }
  162. }
  163. addr := ":" + strconv.Itoa(port)
  164. if host != "" {
  165. addr = net.JoinHostPort(host, strconv.Itoa(port))
  166. }
  167. conn, err := lc.ListenPacket(context.Background(), network, addr)
  168. if err != nil {
  169. return nil, 0, err
  170. }
  171. // Retrieve port.
  172. laddr := conn.LocalAddr()
  173. uaddr, err := net.ResolveUDPAddr(
  174. laddr.Network(),
  175. laddr.String(),
  176. )
  177. if err != nil {
  178. return nil, 0, err
  179. }
  180. return conn.(*net.UDPConn), uaddr.Port, nil
  181. }
  182. func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
  183. if !s.bindV4 {
  184. return nil, nil, port, nil
  185. }
  186. host := s.listenAddr4
  187. conn, actualPort, err := s.listenNet("udp4", host, port)
  188. if err != nil {
  189. if errors.Is(err, syscall.EAFNOSUPPORT) {
  190. return nil, nil, port, nil
  191. }
  192. return nil, nil, port, err
  193. }
  194. if runtime.GOOS != "linux" {
  195. return conn, nil, actualPort, nil
  196. }
  197. pc := ipv4.NewPacketConn(conn)
  198. return conn, pc, actualPort, nil
  199. }
  200. func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
  201. if !s.bindV6 {
  202. return nil, nil, port, nil
  203. }
  204. host := s.listenAddr6
  205. conn, actualPort, err := s.listenNet("udp6", host, port)
  206. if err != nil {
  207. if errors.Is(err, syscall.EAFNOSUPPORT) {
  208. return nil, nil, port, nil
  209. }
  210. return nil, nil, port, err
  211. }
  212. if runtime.GOOS != "linux" {
  213. return conn, nil, actualPort, nil
  214. }
  215. pc := ipv6.NewPacketConn(conn)
  216. return conn, pc, actualPort, nil
  217. }
  218. func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
  219. s.mu.Lock()
  220. defer s.mu.Unlock()
  221. var err error
  222. var tries int
  223. if s.ipv4 != nil || s.ipv6 != nil {
  224. return nil, 0, ErrBindAlreadyOpen
  225. }
  226. // Attempt to open ipv4 and ipv6 listeners on the same port.
  227. // If uport is 0, we can retry on failure.
  228. again:
  229. port := int(uport)
  230. var v4conn *net.UDPConn
  231. var v6conn *net.UDPConn
  232. var v4pc *ipv4.PacketConn
  233. var v6pc *ipv6.PacketConn
  234. v4conn, v4pc, port, err = s.openIPv4(port)
  235. if err != nil {
  236. return nil, 0, err
  237. }
  238. // Listen on the same port as we're using for ipv4.
  239. v6conn, v6pc, port, err = s.openIPv6(port)
  240. if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
  241. if v4conn != nil {
  242. v4conn.Close()
  243. }
  244. tries++
  245. goto again
  246. }
  247. if err != nil {
  248. if v4conn != nil {
  249. v4conn.Close()
  250. }
  251. return nil, 0, err
  252. }
  253. var fns []ReceiveFunc
  254. if v4conn != nil {
  255. s.ipv4 = v4conn
  256. if v4pc != nil {
  257. s.ipv4PC = v4pc
  258. }
  259. fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
  260. }
  261. if v6conn != nil {
  262. s.ipv6 = v6conn
  263. if v6pc != nil {
  264. s.ipv6PC = v6pc
  265. }
  266. fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
  267. }
  268. if len(fns) == 0 {
  269. return nil, 0, syscall.EAFNOSUPPORT
  270. }
  271. return fns, uint16(port), nil
  272. }
  273. func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
  274. return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
  275. msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
  276. defer s.ipv4MsgsPool.Put(msgs)
  277. for i := range bufs {
  278. (*msgs)[i].Buffers[0] = bufs[i]
  279. }
  280. var numMsgs int
  281. if runtime.GOOS == "linux" && pc != nil {
  282. numMsgs, err = pc.ReadBatch(*msgs, 0)
  283. if err != nil {
  284. return 0, err
  285. }
  286. } else {
  287. msg := &(*msgs)[0]
  288. msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
  289. if err != nil {
  290. return 0, err
  291. }
  292. numMsgs = 1
  293. }
  294. for i := 0; i < numMsgs; i++ {
  295. msg := &(*msgs)[i]
  296. sizes[i] = msg.N
  297. addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
  298. ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
  299. getSrcFromControl(msg.OOB[:msg.NN], ep)
  300. eps[i] = ep
  301. }
  302. return numMsgs, nil
  303. }
  304. }
  305. func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
  306. return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
  307. msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
  308. defer s.ipv6MsgsPool.Put(msgs)
  309. for i := range bufs {
  310. (*msgs)[i].Buffers[0] = bufs[i]
  311. }
  312. var numMsgs int
  313. if runtime.GOOS == "linux" && pc != nil {
  314. numMsgs, err = pc.ReadBatch(*msgs, 0)
  315. if err != nil {
  316. return 0, err
  317. }
  318. } else {
  319. msg := &(*msgs)[0]
  320. msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
  321. if err != nil {
  322. return 0, err
  323. }
  324. numMsgs = 1
  325. }
  326. for i := 0; i < numMsgs; i++ {
  327. msg := &(*msgs)[i]
  328. sizes[i] = msg.N
  329. addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
  330. ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
  331. getSrcFromControl(msg.OOB[:msg.NN], ep)
  332. eps[i] = ep
  333. }
  334. return numMsgs, nil
  335. }
  336. }
  337. // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
  338. // rename the IdealBatchSize constant to BatchSize.
  339. func (s *StdNetBind) BatchSize() int {
  340. if runtime.GOOS == "linux" {
  341. return IdealBatchSize
  342. }
  343. return 1
  344. }
  345. func (s *StdNetBind) Close() error {
  346. s.mu.Lock()
  347. defer s.mu.Unlock()
  348. var err1, err2 error
  349. if s.ipv4 != nil {
  350. err1 = s.ipv4.Close()
  351. s.ipv4 = nil
  352. s.ipv4PC = nil
  353. }
  354. if s.ipv6 != nil {
  355. err2 = s.ipv6.Close()
  356. s.ipv6 = nil
  357. s.ipv6PC = nil
  358. }
  359. s.blackhole4 = false
  360. s.blackhole6 = false
  361. if err1 != nil {
  362. return err1
  363. }
  364. return err2
  365. }
  366. func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
  367. s.mu.Lock()
  368. blackhole := s.blackhole4
  369. conn := s.ipv4
  370. var (
  371. pc4 *ipv4.PacketConn
  372. pc6 *ipv6.PacketConn
  373. )
  374. is6 := false
  375. if endpoint.DstIP().Is6() {
  376. blackhole = s.blackhole6
  377. conn = s.ipv6
  378. pc6 = s.ipv6PC
  379. is6 = true
  380. } else {
  381. pc4 = s.ipv4PC
  382. }
  383. s.mu.Unlock()
  384. if blackhole {
  385. return nil
  386. }
  387. if conn == nil {
  388. return syscall.EAFNOSUPPORT
  389. }
  390. if is6 {
  391. return s.send6(conn, pc6, endpoint, bufs)
  392. } else {
  393. return s.send4(conn, pc4, endpoint, bufs)
  394. }
  395. }
  396. func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
  397. ua := s.udpAddrPool.Get().(*net.UDPAddr)
  398. as4 := ep.DstIP().As4()
  399. copy(ua.IP, as4[:])
  400. ua.IP = ua.IP[:4]
  401. ua.Port = int(ep.(*StdNetEndpoint).Port())
  402. msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
  403. for i, buf := range bufs {
  404. (*msgs)[i].Buffers[0] = buf
  405. (*msgs)[i].Addr = ua
  406. setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
  407. }
  408. var (
  409. n int
  410. err error
  411. start int
  412. )
  413. if runtime.GOOS == "linux" && pc != nil {
  414. for {
  415. n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
  416. if err != nil {
  417. if errors.Is(err, syscall.EAFNOSUPPORT) {
  418. for j := start; j < len(bufs); j++ {
  419. _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
  420. if werr != nil {
  421. err = werr
  422. break
  423. }
  424. }
  425. }
  426. break
  427. }
  428. if n == len((*msgs)[start:len(bufs)]) {
  429. break
  430. }
  431. start += n
  432. }
  433. } else {
  434. for i, buf := range bufs {
  435. _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
  436. if err != nil {
  437. break
  438. }
  439. }
  440. }
  441. s.udpAddrPool.Put(ua)
  442. s.ipv4MsgsPool.Put(msgs)
  443. return err
  444. }
  445. func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
  446. ua := s.udpAddrPool.Get().(*net.UDPAddr)
  447. as16 := ep.DstIP().As16()
  448. copy(ua.IP, as16[:])
  449. ua.IP = ua.IP[:16]
  450. ua.Port = int(ep.(*StdNetEndpoint).Port())
  451. msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
  452. for i, buf := range bufs {
  453. (*msgs)[i].Buffers[0] = buf
  454. (*msgs)[i].Addr = ua
  455. setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
  456. }
  457. var (
  458. n int
  459. err error
  460. start int
  461. )
  462. if runtime.GOOS == "linux" && pc != nil {
  463. for {
  464. n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
  465. if err != nil {
  466. if errors.Is(err, syscall.EAFNOSUPPORT) {
  467. for j := start; j < len(bufs); j++ {
  468. _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua)
  469. if werr != nil {
  470. err = werr
  471. break
  472. }
  473. }
  474. }
  475. break
  476. }
  477. if n == len((*msgs)[start:len(bufs)]) {
  478. break
  479. }
  480. start += n
  481. }
  482. } else {
  483. for i, buf := range bufs {
  484. _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
  485. if err != nil {
  486. break
  487. }
  488. }
  489. }
  490. s.udpAddrPool.Put(ua)
  491. s.ipv6MsgsPool.Put(msgs)
  492. return err
  493. }