| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- //go:build linux && !android && !e2e_testing
- package overlay
- import (
- "fmt"
- "sync"
- wgtun "github.com/slackhq/nebula/wgstack/tun"
- )
- type wireguardTunIO struct {
- dev wgtun.Device
- mtu int
- batchSize int
- readMu sync.Mutex
- readBuffers [][]byte
- readLens []int
- legacyBuf []byte
- writeMu sync.Mutex
- writeBuf []byte
- writeWrap [][]byte
- writeBuffers [][]byte
- }
- func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
- batch := dev.BatchSize()
- if batch <= 0 {
- batch = 1
- }
- if mtu <= 0 {
- mtu = DefaultMTU
- }
- return &wireguardTunIO{
- dev: dev,
- mtu: mtu,
- batchSize: batch,
- readLens: make([]int, batch),
- legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
- writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
- writeWrap: make([][]byte, 1),
- }
- }
- func (w *wireguardTunIO) Read(p []byte) (int, error) {
- w.readMu.Lock()
- defer w.readMu.Unlock()
- bufs := w.readBuffers
- if len(bufs) == 0 {
- bufs = [][]byte{w.legacyBuf}
- w.readBuffers = bufs
- }
- n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
- if err != nil {
- return 0, err
- }
- if n == 0 {
- return 0, nil
- }
- length := w.readLens[0]
- copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
- return length, nil
- }
- func (w *wireguardTunIO) Write(p []byte) (int, error) {
- if len(p) > w.mtu {
- return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
- }
- w.writeMu.Lock()
- defer w.writeMu.Unlock()
- buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
- for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
- buf[i] = 0
- }
- copy(buf[wgtun.VirtioNetHdrLen:], p)
- w.writeWrap[0] = buf
- n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
- if err != nil {
- return n, err
- }
- return len(p), nil
- }
- func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
- if pool == nil {
- return nil, fmt.Errorf("wireguard tun: packet pool is nil")
- }
- w.readMu.Lock()
- defer w.readMu.Unlock()
- if len(w.readBuffers) < w.batchSize {
- w.readBuffers = make([][]byte, w.batchSize)
- }
- if len(w.readLens) < w.batchSize {
- w.readLens = make([]int, w.batchSize)
- }
- packets := make([]*Packet, w.batchSize)
- requiredHeadroom := w.BatchHeadroom()
- requiredPayload := w.BatchPayloadCap()
- headroom := 0
- for i := 0; i < w.batchSize; i++ {
- pkt := pool.Get()
- if pkt == nil {
- releasePackets(packets[:i])
- return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
- }
- if pkt.Capacity() < requiredPayload {
- pkt.Release()
- releasePackets(packets[:i])
- return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
- }
- if i == 0 {
- headroom = pkt.Offset
- if headroom < requiredHeadroom {
- pkt.Release()
- releasePackets(packets[:i])
- return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
- }
- } else if pkt.Offset != headroom {
- pkt.Release()
- releasePackets(packets[:i])
- return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
- }
- packets[i] = pkt
- w.readBuffers[i] = pkt.Buf
- }
- n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
- if err != nil {
- releasePackets(packets)
- return nil, err
- }
- if n == 0 {
- releasePackets(packets)
- return nil, nil
- }
- for i := 0; i < n; i++ {
- packets[i].Len = w.readLens[i]
- }
- for i := n; i < w.batchSize; i++ {
- packets[i].Release()
- packets[i] = nil
- }
- return packets[:n], nil
- }
- func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
- if len(packets) == 0 {
- return 0, nil
- }
- requiredHeadroom := w.BatchHeadroom()
- offset := packets[0].Offset
- if offset < requiredHeadroom {
- releasePackets(packets)
- return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
- }
- for _, pkt := range packets {
- if pkt == nil {
- continue
- }
- if pkt.Offset != offset {
- releasePackets(packets)
- return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
- }
- limit := pkt.Offset + pkt.Len
- if limit > len(pkt.Buf) {
- releasePackets(packets)
- return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
- }
- }
- w.writeMu.Lock()
- defer w.writeMu.Unlock()
- if len(w.writeBuffers) < len(packets) {
- w.writeBuffers = make([][]byte, len(packets))
- }
- for i, pkt := range packets {
- if pkt == nil {
- w.writeBuffers[i] = nil
- continue
- }
- limit := pkt.Offset + pkt.Len
- w.writeBuffers[i] = pkt.Buf[:limit]
- }
- n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
- if err != nil {
- return n, err
- }
- releasePackets(packets)
- return n, nil
- }
- func (w *wireguardTunIO) BatchHeadroom() int {
- return wgtun.VirtioNetHdrLen
- }
- func (w *wireguardTunIO) BatchPayloadCap() int {
- return w.mtu
- }
- func (w *wireguardTunIO) BatchSize() int {
- return w.batchSize
- }
- func (w *wireguardTunIO) Close() error {
- return nil
- }
- func releasePackets(pkts []*Packet) {
- for _, pkt := range pkts {
- if pkt != nil {
- pkt.Release()
- }
- }
- }
|