| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- //go:build linux && !android && !e2e_testing
- package udp
- import (
- "errors"
- "net"
- "net/netip"
- "sync"
- "sync/atomic"
- "github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/config"
- wgconn "github.com/slackhq/nebula/wgstack/conn"
- )
- // WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
- type WGConn struct {
- l *logrus.Logger
- bind *wgconn.StdNetBind
- recvers []wgconn.ReceiveFunc
- batch int
- reqBatch int
- localIP netip.Addr
- localPort uint16
- enableGSO bool
- enableGRO bool
- gsoMaxSeg int
- closed atomic.Bool
- q int
- closeOnce sync.Once
- }
- // NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
- func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
- bind := wgconn.NewStdNetBindForAddr(ip, multi, q)
- recvers, actualPort, err := bind.Open(uint16(port))
- if err != nil {
- return nil, err
- }
- if batch <= 0 {
- batch = bind.BatchSize()
- } else if batch > bind.BatchSize() {
- batch = bind.BatchSize()
- }
- return &WGConn{
- l: l,
- bind: bind,
- recvers: recvers,
- batch: batch,
- reqBatch: batch,
- localIP: ip,
- localPort: actualPort,
- q: q,
- }, nil
- }
- func (c *WGConn) Rebind() error {
- // WireGuard's bind does not support rebinding in place.
- return nil
- }
- func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
- if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
- // Fallback to wildcard IPv4 for display purposes.
- return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
- }
- return netip.AddrPortFrom(c.localIP, c.localPort), nil
- }
- func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
- batchSize := c.batch
- packets := make([][]byte, batchSize)
- for i := range packets {
- packets[i] = make([]byte, 0xffff)
- }
- sizes := make([]int, batchSize)
- endpoints := make([]wgconn.Endpoint, batchSize)
- for {
- if c.closed.Load() {
- return
- }
- n, err := fn(packets, sizes, endpoints)
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- return
- }
- if c.l != nil {
- c.l.WithError(err).Debug("wireguard UDP listener receive error")
- }
- continue
- }
- for i := 0; i < n; i++ {
- if sizes[i] == 0 {
- continue
- }
- stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
- if !ok {
- if c.l != nil {
- c.l.Warn("wireguard UDP listener received unexpected endpoint type")
- }
- continue
- }
- addr := stdEp.AddrPort
- r(addr, packets[i][:sizes[i]])
- endpoints[i] = nil
- }
- }
- }
- func (c *WGConn) ListenOut(r EncReader) {
- for _, fn := range c.recvers {
- go c.listen(fn, r)
- }
- }
- func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
- if len(b) == 0 {
- return nil
- }
- if c.closed.Load() {
- return net.ErrClosed
- }
- ep := &wgconn.StdNetEndpoint{AddrPort: addr}
- return c.bind.Send([][]byte{b}, ep)
- }
- func (c *WGConn) WriteBatch(datagrams []Datagram) error {
- if len(datagrams) == 0 {
- return nil
- }
- if c.closed.Load() {
- return net.ErrClosed
- }
- max := c.batch
- if max <= 0 {
- max = len(datagrams)
- if max == 0 {
- max = 1
- }
- }
- bufs := make([][]byte, 0, max)
- var (
- current netip.AddrPort
- endpoint *wgconn.StdNetEndpoint
- haveAddr bool
- )
- flush := func() error {
- if len(bufs) == 0 || endpoint == nil {
- bufs = bufs[:0]
- return nil
- }
- err := c.bind.Send(bufs, endpoint)
- bufs = bufs[:0]
- return err
- }
- for _, d := range datagrams {
- if len(d.Payload) == 0 || !d.Addr.IsValid() {
- continue
- }
- if !haveAddr || d.Addr != current {
- if err := flush(); err != nil {
- return err
- }
- current = d.Addr
- endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
- haveAddr = true
- }
- bufs = append(bufs, d.Payload)
- if len(bufs) >= max {
- if err := flush(); err != nil {
- return err
- }
- }
- }
- return flush()
- }
- func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
- c.enableGSO = enableGSO
- c.enableGRO = enableGRO
- if maxSegments <= 0 {
- maxSegments = 1
- } else if maxSegments > wgconn.IdealBatchSize {
- maxSegments = wgconn.IdealBatchSize
- }
- c.gsoMaxSeg = maxSegments
- effectiveBatch := c.reqBatch
- if enableGSO && c.bind != nil {
- bindBatch := c.bind.BatchSize()
- if effectiveBatch < bindBatch {
- if c.l != nil {
- c.l.WithFields(logrus.Fields{
- "requested": c.reqBatch,
- "effective": bindBatch,
- }).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
- }
- effectiveBatch = bindBatch
- }
- }
- c.batch = effectiveBatch
- if c.l != nil {
- c.l.WithFields(logrus.Fields{
- "enableGSO": enableGSO,
- "enableGRO": enableGRO,
- "gsoMaxSegments": maxSegments,
- }).Debug("configured wireguard UDP offload")
- }
- }
- func (c *WGConn) ReloadConfig(*config.C) {
- // WireGuard bind currently does not expose runtime configuration knobs.
- }
- func (c *WGConn) Close() error {
- var err error
- c.closeOnce.Do(func() {
- c.closed.Store(true)
- err = c.bind.Close()
- })
- return err
- }
|