wireguard_conn_linux.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. //go:build linux && !android && !e2e_testing
  2. package udp
  3. import (
  4. "errors"
  5. "net"
  6. "net/netip"
  7. "sync"
  8. "sync/atomic"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/config"
  11. wgconn "github.com/slackhq/nebula/wgstack/conn"
  12. )
  13. // WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
  14. type WGConn struct {
  15. l *logrus.Logger
  16. bind *wgconn.StdNetBind
  17. recvers []wgconn.ReceiveFunc
  18. batch int
  19. reqBatch int
  20. localIP netip.Addr
  21. localPort uint16
  22. enableGSO bool
  23. enableGRO bool
  24. gsoMaxSeg int
  25. closed atomic.Bool
  26. q int
  27. closeOnce sync.Once
  28. }
  29. // NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
  30. func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
  31. bind := wgconn.NewStdNetBindForAddr(ip, multi, q)
  32. recvers, actualPort, err := bind.Open(uint16(port))
  33. if err != nil {
  34. return nil, err
  35. }
  36. if batch <= 0 {
  37. batch = bind.BatchSize()
  38. } else if batch > bind.BatchSize() {
  39. batch = bind.BatchSize()
  40. }
  41. return &WGConn{
  42. l: l,
  43. bind: bind,
  44. recvers: recvers,
  45. batch: batch,
  46. reqBatch: batch,
  47. localIP: ip,
  48. localPort: actualPort,
  49. q: q,
  50. }, nil
  51. }
  52. func (c *WGConn) Rebind() error {
  53. // WireGuard's bind does not support rebinding in place.
  54. return nil
  55. }
  56. func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
  57. if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
  58. // Fallback to wildcard IPv4 for display purposes.
  59. return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
  60. }
  61. return netip.AddrPortFrom(c.localIP, c.localPort), nil
  62. }
  63. func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
  64. batchSize := c.batch
  65. packets := make([][]byte, batchSize)
  66. for i := range packets {
  67. packets[i] = make([]byte, 0xffff)
  68. }
  69. sizes := make([]int, batchSize)
  70. endpoints := make([]wgconn.Endpoint, batchSize)
  71. for {
  72. if c.closed.Load() {
  73. return
  74. }
  75. n, err := fn(packets, sizes, endpoints)
  76. if err != nil {
  77. if errors.Is(err, net.ErrClosed) {
  78. return
  79. }
  80. if c.l != nil {
  81. c.l.WithError(err).Debug("wireguard UDP listener receive error")
  82. }
  83. continue
  84. }
  85. for i := 0; i < n; i++ {
  86. if sizes[i] == 0 {
  87. continue
  88. }
  89. stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
  90. if !ok {
  91. if c.l != nil {
  92. c.l.Warn("wireguard UDP listener received unexpected endpoint type")
  93. }
  94. continue
  95. }
  96. addr := stdEp.AddrPort
  97. r(addr, packets[i][:sizes[i]])
  98. endpoints[i] = nil
  99. }
  100. }
  101. }
  102. func (c *WGConn) ListenOut(r EncReader) {
  103. for _, fn := range c.recvers {
  104. go c.listen(fn, r)
  105. }
  106. }
  107. func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
  108. if len(b) == 0 {
  109. return nil
  110. }
  111. if c.closed.Load() {
  112. return net.ErrClosed
  113. }
  114. ep := &wgconn.StdNetEndpoint{AddrPort: addr}
  115. return c.bind.Send([][]byte{b}, ep)
  116. }
  117. func (c *WGConn) WriteBatch(datagrams []Datagram) error {
  118. if len(datagrams) == 0 {
  119. return nil
  120. }
  121. if c.closed.Load() {
  122. return net.ErrClosed
  123. }
  124. max := c.batch
  125. if max <= 0 {
  126. max = len(datagrams)
  127. if max == 0 {
  128. max = 1
  129. }
  130. }
  131. bufs := make([][]byte, 0, max)
  132. var (
  133. current netip.AddrPort
  134. endpoint *wgconn.StdNetEndpoint
  135. haveAddr bool
  136. )
  137. flush := func() error {
  138. if len(bufs) == 0 || endpoint == nil {
  139. bufs = bufs[:0]
  140. return nil
  141. }
  142. err := c.bind.Send(bufs, endpoint)
  143. bufs = bufs[:0]
  144. return err
  145. }
  146. for _, d := range datagrams {
  147. if len(d.Payload) == 0 || !d.Addr.IsValid() {
  148. continue
  149. }
  150. if !haveAddr || d.Addr != current {
  151. if err := flush(); err != nil {
  152. return err
  153. }
  154. current = d.Addr
  155. endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
  156. haveAddr = true
  157. }
  158. bufs = append(bufs, d.Payload)
  159. if len(bufs) >= max {
  160. if err := flush(); err != nil {
  161. return err
  162. }
  163. }
  164. }
  165. return flush()
  166. }
  167. func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
  168. c.enableGSO = enableGSO
  169. c.enableGRO = enableGRO
  170. if maxSegments <= 0 {
  171. maxSegments = 1
  172. } else if maxSegments > wgconn.IdealBatchSize {
  173. maxSegments = wgconn.IdealBatchSize
  174. }
  175. c.gsoMaxSeg = maxSegments
  176. effectiveBatch := c.reqBatch
  177. if enableGSO && c.bind != nil {
  178. bindBatch := c.bind.BatchSize()
  179. if effectiveBatch < bindBatch {
  180. if c.l != nil {
  181. c.l.WithFields(logrus.Fields{
  182. "requested": c.reqBatch,
  183. "effective": bindBatch,
  184. }).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
  185. }
  186. effectiveBatch = bindBatch
  187. }
  188. }
  189. c.batch = effectiveBatch
  190. if c.l != nil {
  191. c.l.WithFields(logrus.Fields{
  192. "enableGSO": enableGSO,
  193. "enableGRO": enableGRO,
  194. "gsoMaxSegments": maxSegments,
  195. }).Debug("configured wireguard UDP offload")
  196. }
  197. }
  198. func (c *WGConn) ReloadConfig(*config.C) {
  199. // WireGuard bind currently does not expose runtime configuration knobs.
  200. }
  201. func (c *WGConn) Close() error {
  202. var err error
  203. c.closeOnce.Do(func() {
  204. c.closed.Store(true)
  205. err = c.bind.Close()
  206. })
  207. return err
  208. }