wireguard_conn_linux.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. //go:build linux && !android && !e2e_testing
  2. package udp
  3. import (
  4. "errors"
  5. "net"
  6. "net/netip"
  7. "sync"
  8. "sync/atomic"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/config"
  11. wgconn "github.com/slackhq/nebula/wgstack/conn"
  12. )
  13. // WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
  14. type WGConn struct {
  15. l *logrus.Logger
  16. bind *wgconn.StdNetBind
  17. recvers []wgconn.ReceiveFunc
  18. batch int
  19. reqBatch int
  20. localIP netip.Addr
  21. localPort uint16
  22. enableGSO bool
  23. enableGRO bool
  24. gsoMaxSeg int
  25. closed atomic.Bool
  26. closeOnce sync.Once
  27. }
  28. // NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
  29. func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
  30. bind := wgconn.NewStdNetBindForAddr(ip, multi)
  31. recvers, actualPort, err := bind.Open(uint16(port))
  32. if err != nil {
  33. return nil, err
  34. }
  35. if batch <= 0 {
  36. batch = bind.BatchSize()
  37. } else if batch > bind.BatchSize() {
  38. batch = bind.BatchSize()
  39. }
  40. return &WGConn{
  41. l: l,
  42. bind: bind,
  43. recvers: recvers,
  44. batch: batch,
  45. reqBatch: batch,
  46. localIP: ip,
  47. localPort: actualPort,
  48. }, nil
  49. }
  50. func (c *WGConn) Rebind() error {
  51. // WireGuard's bind does not support rebinding in place.
  52. return nil
  53. }
  54. func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
  55. if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
  56. // Fallback to wildcard IPv4 for display purposes.
  57. return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
  58. }
  59. return netip.AddrPortFrom(c.localIP, c.localPort), nil
  60. }
  61. func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
  62. batchSize := c.batch
  63. packets := make([][]byte, batchSize)
  64. for i := range packets {
  65. packets[i] = make([]byte, MTU)
  66. }
  67. sizes := make([]int, batchSize)
  68. endpoints := make([]wgconn.Endpoint, batchSize)
  69. for {
  70. if c.closed.Load() {
  71. return
  72. }
  73. n, err := fn(packets, sizes, endpoints)
  74. if err != nil {
  75. if errors.Is(err, net.ErrClosed) {
  76. return
  77. }
  78. if c.l != nil {
  79. c.l.WithError(err).Debug("wireguard UDP listener receive error")
  80. }
  81. continue
  82. }
  83. for i := 0; i < n; i++ {
  84. if sizes[i] == 0 {
  85. continue
  86. }
  87. stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
  88. if !ok {
  89. if c.l != nil {
  90. c.l.Warn("wireguard UDP listener received unexpected endpoint type")
  91. }
  92. continue
  93. }
  94. addr := stdEp.AddrPort
  95. r(addr, packets[i][:sizes[i]])
  96. endpoints[i] = nil
  97. }
  98. }
  99. }
  100. func (c *WGConn) ListenOut(r EncReader) {
  101. for _, fn := range c.recvers {
  102. go c.listen(fn, r)
  103. }
  104. }
  105. func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
  106. if len(b) == 0 {
  107. return nil
  108. }
  109. if c.closed.Load() {
  110. return net.ErrClosed
  111. }
  112. ep := &wgconn.StdNetEndpoint{AddrPort: addr}
  113. return c.bind.Send([][]byte{b}, ep)
  114. }
  115. func (c *WGConn) WriteBatch(datagrams []Datagram) error {
  116. if len(datagrams) == 0 {
  117. return nil
  118. }
  119. if c.closed.Load() {
  120. return net.ErrClosed
  121. }
  122. max := c.batch
  123. if max <= 0 {
  124. max = len(datagrams)
  125. if max == 0 {
  126. max = 1
  127. }
  128. }
  129. bufs := make([][]byte, 0, max)
  130. var (
  131. current netip.AddrPort
  132. endpoint *wgconn.StdNetEndpoint
  133. haveAddr bool
  134. )
  135. flush := func() error {
  136. if len(bufs) == 0 || endpoint == nil {
  137. bufs = bufs[:0]
  138. return nil
  139. }
  140. err := c.bind.Send(bufs, endpoint)
  141. bufs = bufs[:0]
  142. return err
  143. }
  144. for _, d := range datagrams {
  145. if len(d.Payload) == 0 || !d.Addr.IsValid() {
  146. continue
  147. }
  148. if !haveAddr || d.Addr != current {
  149. if err := flush(); err != nil {
  150. return err
  151. }
  152. current = d.Addr
  153. endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
  154. haveAddr = true
  155. }
  156. bufs = append(bufs, d.Payload)
  157. if len(bufs) >= max {
  158. if err := flush(); err != nil {
  159. return err
  160. }
  161. }
  162. }
  163. return flush()
  164. }
  165. func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
  166. c.enableGSO = enableGSO
  167. c.enableGRO = enableGRO
  168. if maxSegments <= 0 {
  169. maxSegments = 1
  170. } else if maxSegments > wgconn.IdealBatchSize {
  171. maxSegments = wgconn.IdealBatchSize
  172. }
  173. c.gsoMaxSeg = maxSegments
  174. effectiveBatch := c.reqBatch
  175. if enableGSO && c.bind != nil {
  176. bindBatch := c.bind.BatchSize()
  177. if effectiveBatch < bindBatch {
  178. if c.l != nil {
  179. c.l.WithFields(logrus.Fields{
  180. "requested": c.reqBatch,
  181. "effective": bindBatch,
  182. }).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
  183. }
  184. effectiveBatch = bindBatch
  185. }
  186. }
  187. c.batch = effectiveBatch
  188. if c.l != nil {
  189. c.l.WithFields(logrus.Fields{
  190. "enableGSO": enableGSO,
  191. "enableGRO": enableGRO,
  192. "gsoMaxSegments": maxSegments,
  193. }).Debug("configured wireguard UDP offload")
  194. }
  195. }
  196. func (c *WGConn) ReloadConfig(*config.C) {
  197. // WireGuard bind currently does not expose runtime configuration knobs.
  198. }
  199. func (c *WGConn) Close() error {
  200. var err error
  201. c.closeOnce.Do(func() {
  202. c.closed.Store(true)
  203. err = c.bind.Close()
  204. })
  205. return err
  206. }