2
0
Эх сурвалжийг харах

Use registered io on Windows when possible (#905)

Nate Brown 2 жил өмнө
parent
commit
a3e59a38ef

+ 3 - 1
Makefile

@@ -12,6 +12,8 @@ ifeq ($(OS),Windows_NT)
 	GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
 	NEBULA_CMD_SUFFIX = .exe
 	NULL_FILE = nul
+	# RIO on windows does pointer stuff that makes go vet angry
+	VET_FLAGS = -unsafeptr=false
 else
 	GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
 	GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
@@ -143,7 +145,7 @@ build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
 	cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
 
 vet:
-	go vet -v ./...
+	go vet $(VET_FLAGS) -v ./...
 
 test:
 	go test -v ./...

+ 1 - 0
go.mod

@@ -26,6 +26,7 @@ require (
 	golang.org/x/sys v0.8.0
 	golang.org/x/term v0.8.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
+	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	google.golang.org/protobuf v1.30.0
 	gopkg.in/yaml.v2 v2.4.0

+ 2 - 0
go.sum

@@ -219,6 +219,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
+golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
+golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
 golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
 golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
 google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=

+ 7 - 0
interface.go

@@ -413,6 +413,13 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 
+	for _, u := range f.writers {
+		err := u.Close()
+		if err != nil {
+			f.l.WithError(err).Error("Error while closing udp socket")
+		}
+	}
+
 	// Release the tun device
 	return f.inside.Close()
 }

+ 4 - 0
udp/conn.go

@@ -26,6 +26,7 @@ type Conn interface {
 	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
 	WriteTo(b []byte, addr *Addr) error
 	ReloadConfig(c *config.C)
+	Close() error
 }
 
 type NoopConn struct{}
@@ -45,3 +46,6 @@ func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 }
+func (NoopConn) Close() error {
+	return nil
+}

+ 5 - 0
udp/udp_android.go

@@ -8,9 +8,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {

+ 5 - 0
udp/udp_darwin.go

@@ -10,9 +10,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {

+ 5 - 0
udp/udp_freebsd.go

@@ -10,9 +10,14 @@ import (
 	"net"
 	"syscall"
 
+	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {

+ 5 - 3
udp/udp_generic.go

@@ -23,7 +23,9 @@ type GenericConn struct {
 	l *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+var _ Conn = &GenericConn{}
+
+func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	if err != nil {
@@ -80,8 +82,8 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to read packets")
-			continue
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
 		}
 
 		udpAddr.IP = rua.IP

+ 7 - 2
udp/udp_linux.go

@@ -137,8 +137,8 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 	for {
 		n, err := read(msgs)
 		if err != nil {
-			u.l.WithError(err).Error("Failed to read packets")
-			continue
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
 		}
 
 		//metric.Update(int64(n))
@@ -262,6 +262,11 @@ func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
 	return nil
 }
 
+func (u *StdConn) Close() error {
+	//TODO: this will not interrupt the read loop
+	return syscall.Close(u.sysFd)
+}
+
 func NewUDPStatsEmitter(udpConns []Conn) func() {
 	// Check if our kernel supports SO_MEMINFO before registering the gauges
 	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge

+ 403 - 0
udp/udp_rio_windows.go

@@ -0,0 +1,403 @@
+//go:build !e2e_testing
+// +build !e2e_testing
+
+// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
+
+package udp
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"unsafe"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+// Assert we meet the standard conn interface
+var _ Conn = &RIOConn{}
+
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+const (
+	packetsPerRing = 1024
+	bytesPerPacket = 2048 - 32
+	receiveSpins   = 15
+)
+
+type ringPacket struct {
+	addr windows.RawSockaddrInet6
+	data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+	packets    uintptr
+	head, tail uint32
+	id         winrio.BufferId
+	iocp       windows.Handle
+	isFull     bool
+	cq         winrio.Cq
+	mu         sync.Mutex
+	overlapped windows.Overlapped
+}
+
+type RIOConn struct {
+	isOpen  atomic.Bool
+	l       *logrus.Logger
+	sock    windows.Handle
+	rx, tx  ringBuffer
+	rq      winrio.Rq
+	results [packetsPerRing]winrio.Result
+}
+
+func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+	if !winrio.Initialize() {
+		return nil, errors.New("could not initialize winrio")
+	}
+
+	u := &RIOConn{l: l}
+
+	addr := [16]byte{}
+	copy(addr[:], ip.To16())
+	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+	if err != nil {
+		return nil, fmt.Errorf("bind: %w", err)
+	}
+
+	for i := 0; i < packetsPerRing; i++ {
+		err = u.insertReceiveRequest()
+		if err != nil {
+			return nil, fmt.Errorf("init rx ring: %w", err)
+		}
+	}
+
+	u.isOpen.Store(true)
+	return u, nil
+}
+
+func (u *RIOConn) bind(sa windows.Sockaddr) error {
+	var err error
+	u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+	if err != nil {
+		return err
+	}
+
+	// Enable v4 for this socket
+	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+
+	err = u.rx.Open()
+	if err != nil {
+		return err
+	}
+
+	err = u.tx.Open()
+	if err != nil {
+		return err
+	}
+
+	u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
+	if err != nil {
+		return err
+	}
+
+	err = windows.Bind(u.sock, sa)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{IP: make([]byte, 16)}
+	nb := make([]byte, 12, 12)
+
+	for {
+		// Just read one packet at a time
+		n, rua, err := u.receive(buffer)
+		if err != nil {
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
+		}
+
+		udpAddr.IP = rua.Addr[:]
+		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
+		p[0] = byte(rua.Port >> 8)
+		p[1] = byte(rua.Port)
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+	}
+}
+
+func (u *RIOConn) insertReceiveRequest() error {
+	packet := u.rx.Push()
+	dataBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
+		Length: uint32(len(packet.data)),
+	}
+	addressBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
+	if !u.isOpen.Load() {
+		return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+	}
+
+	u.rx.mu.Lock()
+	defer u.rx.mu.Unlock()
+
+	var err error
+	var count uint32
+	var results [1]winrio.Result
+
+retry:
+	count = 0
+	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+		if tries > 0 {
+			if !u.isOpen.Load() {
+				return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+			}
+			procyield(1)
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+	}
+
+	if count == 0 {
+		err = winrio.Notify(u.rx.cq)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+
+		if !u.isOpen.Load() {
+			return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+		if count == 0 {
+			return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
+
+		}
+	}
+
+	u.rx.Return(1)
+	err = u.insertReceiveRequest()
+	if err != nil {
+		return 0, windows.RawSockaddrInet6{}, err
+	}
+
+	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
+	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
+	// attacker bandwidth, just like the rest of the receive path.
+	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
+		goto retry
+	}
+
+	if results[0].Status != 0 {
+		return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
+	}
+
+	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+	ep := packet.addr
+	n := copy(buf, packet.data[:results[0].BytesTransferred])
+	return n, ep, nil
+}
+
+func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+	if !u.isOpen.Load() {
+		return net.ErrClosed
+	}
+
+	if len(buf) > bytesPerPacket {
+		return io.ErrShortBuffer
+	}
+
+	u.tx.mu.Lock()
+	defer u.tx.mu.Unlock()
+
+	count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
+	if count == 0 && u.tx.isFull {
+		err := winrio.Notify(u.tx.cq)
+		if err != nil {
+			return err
+		}
+
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return err
+		}
+
+		if !u.isOpen.Load() {
+			return net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
+		if count == 0 {
+			return io.ErrNoProgress
+		}
+	}
+
+	if count > 0 {
+		u.tx.Return(count)
+	}
+
+	packet := u.tx.Push()
+	packet.addr.Family = windows.AF_INET6
+	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
+	p[0] = byte(addr.Port >> 8)
+	p[1] = byte(addr.Port)
+	copy(packet.addr.Addr[:], addr.IP.To16())
+	copy(packet.data[:], buf)
+
+	dataBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
+		Length: uint32(len(buf)),
+	}
+
+	addressBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (u *RIOConn) LocalAddr() (*Addr, error) {
+	sa, err := windows.Getsockname(u.sock)
+	if err != nil {
+		return nil, err
+	}
+
+	v6 := sa.(*windows.SockaddrInet6)
+	return &Addr{
+		IP:   v6.Addr[:],
+		Port: uint16(v6.Port),
+	}, nil
+}
+
+func (u *RIOConn) Rebind() error {
+	return nil
+}
+
+func (u *RIOConn) ReloadConfig(*config.C) {}
+
+func (u *RIOConn) Close() error {
+	if !u.isOpen.CompareAndSwap(true, false) {
+		return nil
+	}
+
+	windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
+	windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
+
+	u.rx.CloseAndZero()
+	u.tx.CloseAndZero()
+	if u.sock != 0 {
+		windows.CloseHandle(u.sock)
+	}
+	return nil
+}
+
+func (ring *ringBuffer) Push() *ringPacket {
+	for ring.isFull {
+		panic("ring is full")
+	}
+	ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+	ring.tail += 1
+	if ring.tail%packetsPerRing == ring.head%packetsPerRing {
+		ring.isFull = true
+	}
+	return ret
+}
+
+func (ring *ringBuffer) Return(count uint32) {
+	if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
+		return
+	}
+	ring.head += count
+	ring.isFull = false
+}
+
+func (ring *ringBuffer) CloseAndZero() {
+	if ring.cq != 0 {
+		winrio.CloseCompletionQueue(ring.cq)
+		ring.cq = 0
+	}
+
+	if ring.iocp != 0 {
+		windows.CloseHandle(ring.iocp)
+		ring.iocp = 0
+	}
+
+	if ring.id != 0 {
+		winrio.DeregisterBuffer(ring.id)
+		ring.id = 0
+	}
+
+	if ring.packets != 0 {
+		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+		ring.packets = 0
+	}
+
+	ring.head = 0
+	ring.tail = 0
+	ring.isFull = false
+}
+
+func (ring *ringBuffer) Open() error {
+	var err error
+	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+	if err != nil {
+		return err
+	}
+
+	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+	if err != nil {
+		return err
+	}
+
+	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+	if err != nil {
+		return err
+	}
+
+	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

+ 6 - 0
udp/udp_tester.go

@@ -140,3 +140,9 @@ func (u *TesterConn) LocalAddr() (*Addr, error) {
 func (u *TesterConn) Rebind() error {
 	return nil
 }
+
+func (u *TesterConn) Close() error {
+	close(u.RxPackets)
+	close(u.TxPackets)
+	return nil
+}

+ 19 - 2
udp/udp_windows.go

@@ -3,14 +3,31 @@
 
 package udp
 
-// Windows support is primarily implemented in udp_generic, besides NewListenConfig
-
 import (
 	"fmt"
 	"net"
 	"syscall"
+
+	"github.com/sirupsen/logrus"
 )
 
+func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+	if multi {
+		//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
+		// The udp stack would need to be reworked to hide away the implementation differences between
+		// Windows and Linux
+		return nil, fmt.Errorf("multiple udp listeners not supported on windows")
+	}
+
+	rc, err := NewRIOListener(l, ip, port)
+	if err == nil {
+		return rc, nil
+	}
+
+	l.WithError(err).Error("Falling back to standard udp sockets")
+	return NewGenericListener(l, ip, port, multi, batch)
+}
+
 func NewListenConfig(multi bool) net.ListenConfig {
 	return net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {