@@ -0,0 +1,403 @@
+//go:build !e2e_testing
+// +build !e2e_testing
+
+// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
+
+package udp
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"unsafe"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
+
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/conn/winrio"
+)
+
+// Assert we meet the standard conn interface
+var _ Conn = &RIOConn{}
+
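+// procyield busy-spins for roughly the given number of cycles (a PAUSE-style
+// wait); //go:linkname binds this declaration to the runtime's unexported
+// implementation.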
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+const (
+	packetsPerRing = 1024
+	bytesPerPacket = 2048 - 32
+	receiveSpins   = 15
+)
+
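With these constants each ringPacket is a 28-byte sockaddr plus 2016 bytes of
payload; the 32 bytes shaved off 2048 are headroom for the address, keeping
each slot within 2 KiB. Back-of-the-envelope (a sketch; the 28 bytes assumes
unsafe.Sizeof(windows.RawSockaddrInet6{}) on common targets):

	per packet: 28 + (2048 - 32) = 2044 bytes
	per ring:   1024 * 2044 ≈ 2.0 MiB, so rx + tx pin ≈ 4 MiB per listener
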
+type ringPacket struct {
+	addr windows.RawSockaddrInet6
+	data [bytesPerPacket]byte
+}
+
+type ringBuffer struct {
+	packets    uintptr
+	head, tail uint32
+	id         winrio.BufferId
+	iocp       windows.Handle
+	isFull     bool
+	cq         winrio.Cq
+	mu         sync.Mutex
+	overlapped windows.Overlapped
+}
+
+type RIOConn struct {
+	isOpen  atomic.Bool
+	l       *logrus.Logger
+	sock    windows.Handle
+	rx, tx  ringBuffer
+	rq      winrio.Rq
+	results [packetsPerRing]winrio.Result
+}
+
+func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+	if !winrio.Initialize() {
+		return nil, errors.New("could not initialize winrio")
+	}
+
+	u := &RIOConn{l: l}
+
+	addr := [16]byte{}
+	copy(addr[:], ip.To16())
+	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+	if err != nil {
+		return nil, fmt.Errorf("bind: %w", err)
+	}
+
+	for i := 0; i < packetsPerRing; i++ {
+		err = u.insertReceiveRequest()
+		if err != nil {
+			return nil, fmt.Errorf("init rx ring: %w", err)
+		}
+	}
+
+	u.isOpen.Store(true)
+	return u, nil
+}
+
+func (u *RIOConn) bind(sa windows.Sockaddr) error {
+	var err error
+	u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+	if err != nil {
+		return err
+	}
+
+	// Enable v4 for this socket
+	syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
+
+	err = u.rx.Open()
+	if err != nil {
+		return err
+	}
+
+	err = u.tx.Open()
+	if err != nil {
+		return err
+	}
+
+	u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
+	if err != nil {
+		return err
+	}
+
+	err = windows.Bind(u.sock, sa)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
+	plaintext := make([]byte, MTU)
+	buffer := make([]byte, MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	udpAddr := &Addr{IP: make([]byte, 16)}
+	nb := make([]byte, 12)
+
+	for {
+		// Just read one packet at a time
+		n, rua, err := u.receive(buffer)
+		if err != nil {
+			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
+			return
+		}
+
+		udpAddr.IP = rua.Addr[:]
+		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
+		p[0] = byte(rua.Port >> 8)
+		p[1] = byte(rua.Port)
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+	}
+}
+
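The two byte stores into udpAddr.Port above decode the sockaddr's
network-byte-order port into a host-order value, and the double swap works out
on either endianness. An equivalent form without the unsafe pointer write (a
sketch, not part of this change; uses encoding/binary, and binary.NativeEndian
needs Go 1.21+):

	// portToHost re-serializes the uint16 exactly as it sits in the sockaddr,
	// then decodes those bytes as big-endian (network order).
	func portToHost(raw uint16) uint16 {
		var b [2]byte
		binary.NativeEndian.PutUint16(b[:], raw)
		return binary.BigEndian.Uint16(b[:])
	}

	// On a little-endian host, port 4242 (0x1092) reads out of the sockaddr
	// as 0x9210, and portToHost(0x9210) == 4242.
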
+func (u *RIOConn) insertReceiveRequest() error {
+	packet := u.rx.Push()
+	dataBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
+		Length: uint32(len(packet.data)),
+	}
+	addressBuffer := &winrio.Buffer{
+		Id:     u.rx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
+}
+
+func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
+	if !u.isOpen.Load() {
+		return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+	}
+
+	u.rx.mu.Lock()
+	defer u.rx.mu.Unlock()
+
+	var err error
+	var count uint32
+	var results [1]winrio.Result
+
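+	// Fast path: poll the completion queue for a bounded number of spins
+	// before arming a kernel notification and blocking below.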
+retry:
+	count = 0
+	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
+		if tries > 0 {
+			if !u.isOpen.Load() {
+				return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+			}
+			procyield(1)
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+	}
+
+	if count == 0 {
+		err = winrio.Notify(u.rx.cq)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return 0, windows.RawSockaddrInet6{}, err
+		}
+
+		if !u.isOpen.Load() {
+			return 0, windows.RawSockaddrInet6{}, net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.rx.cq, results[:])
+		if count == 0 {
+			return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
+		}
+	}
+
+	u.rx.Return(1)
+	err = u.insertReceiveRequest()
+	if err != nil {
+		return 0, windows.RawSockaddrInet6{}, err
+	}
+
+	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
+	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
+	// attacker bandwidth, just like the rest of the receive path.
+	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
+		goto retry
+	}
+
+	if results[0].Status != 0 {
+		return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
+	}
+
+	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
+	ep := packet.addr
+	n := copy(buf, packet.data[:results[0].BytesTransferred])
+	return n, ep, nil
+}
+
+func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+	if !u.isOpen.Load() {
+		return net.ErrClosed
+	}
+
+	if len(buf) > bytesPerPacket {
+		return io.ErrShortBuffer
+	}
+
+	u.tx.mu.Lock()
+	defer u.tx.mu.Unlock()
+
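+	// Reclaim any sends that have already completed; block on the completion
+	// queue only when the tx ring has no free slots left.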
+	count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
+	if count == 0 && u.tx.isFull {
+		err := winrio.Notify(u.tx.cq)
+		if err != nil {
+			return err
+		}
+
+		var bytes uint32
+		var key uintptr
+		var overlapped *windows.Overlapped
+		err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
+		if err != nil {
+			return err
+		}
+
+		if !u.isOpen.Load() {
+			return net.ErrClosed
+		}
+
+		count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
+		if count == 0 {
+			return io.ErrNoProgress
+		}
+	}
+
+	if count > 0 {
+		u.tx.Return(count)
+	}
+
+	packet := u.tx.Push()
+	packet.addr.Family = windows.AF_INET6
+	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
+	p[0] = byte(addr.Port >> 8)
+	p[1] = byte(addr.Port)
+	copy(packet.addr.Addr[:], addr.IP.To16())
+	copy(packet.data[:], buf)
+
+	dataBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
+		Length: uint32(len(buf)),
+	}
+
+	addressBuffer := &winrio.Buffer{
+		Id:     u.tx.id,
+		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
+		Length: uint32(unsafe.Sizeof(packet.addr)),
+	}
+
+	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
+}
+
+func (u *RIOConn) LocalAddr() (*Addr, error) {
+	sa, err := windows.Getsockname(u.sock)
+	if err != nil {
+		return nil, err
+	}
+
+	v6 := sa.(*windows.SockaddrInet6)
+	return &Addr{
+		IP:   v6.Addr[:],
+		Port: uint16(v6.Port),
+	}, nil
+}
+
+func (u *RIOConn) Rebind() error {
+	return nil
+}
+
+func (u *RIOConn) ReloadConfig(*config.C) {}
+
+func (u *RIOConn) Close() error {
+	if !u.isOpen.CompareAndSwap(true, false) {
+		return nil
+	}
+
+	windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
+	windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
+
+	u.rx.CloseAndZero()
+	u.tx.CloseAndZero()
+	if u.sock != 0 {
+		windows.CloseHandle(u.sock)
+	}
+	return nil
+}
+
+func (ring *ringBuffer) Push() *ringPacket {
+	if ring.isFull {
+		panic("ring is full")
+	}
+	ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
+	ring.tail += 1
+	if ring.tail%packetsPerRing == ring.head%packetsPerRing {
+		ring.isFull = true
+	}
+	return ret
+}
+
+func (ring *ringBuffer) Return(count uint32) {
+	if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
+		return
+	}
+	ring.head += count
+	ring.isFull = false
+}
+
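head and tail are free-running counters reduced mod packetsPerRing, so
head == tail is ambiguous between "empty" and "full"; the isFull flag breaks
the tie. A toy model of the same bookkeeping (a sketch, 4 slots instead of
1024):

	const slots = 4

	type ring struct {
		head, tail uint32
		isFull     bool
	}

	func (r *ring) push() uint32 { // hands out the next slot index
		if r.isFull {
			panic("ring is full")
		}
		i := r.tail % slots
		r.tail++
		if r.tail%slots == r.head%slots {
			r.isFull = true // head == tail now means full, not empty
		}
		return i
	}

	func (r *ring) ret(count uint32) { // mirrors Return above
		if r.head%slots == r.tail%slots && !r.isFull {
			return // already empty
		}
		r.head += count
		r.isFull = false
	}

Pushing four slots yields 0, 1, 2, 3 and sets isFull; ret(2) frees two slots
and the next push wraps around to slot 0 again.
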
+func (ring *ringBuffer) CloseAndZero() {
+	if ring.cq != 0 {
+		winrio.CloseCompletionQueue(ring.cq)
+		ring.cq = 0
+	}
+
+	if ring.iocp != 0 {
+		windows.CloseHandle(ring.iocp)
+		ring.iocp = 0
+	}
+
+	if ring.id != 0 {
+		winrio.DeregisterBuffer(ring.id)
+		ring.id = 0
+	}
+
+	if ring.packets != 0 {
+		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
+		ring.packets = 0
+	}
+
+	ring.head = 0
+	ring.tail = 0
+	ring.isFull = false
+}
+
+func (ring *ringBuffer) Open() error {
+	var err error
+	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
+	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+	if err != nil {
+		return err
+	}
+
+	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
+	if err != nil {
+		return err
+	}
+
+	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
+	if err != nil {
+		return err
+	}
+
+	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
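
Taken together: construct the listener, run the read loop on its own
goroutine, and write through the same connection. A minimal wiring sketch
(hypothetical; reader, lhf, cache, the port, and the remote *Addr all come
from the rest of nebula):

	func exampleUsage(l *logrus.Logger, reader EncReader, lhf LightHouseHandlerFunc,
		cache *firewall.ConntrackCacheTicker, remote *Addr) error {
		conn, err := NewRIOListener(l, net.IPv6zero, 4242)
		if err != nil {
			return err
		}
		defer conn.Close()

		// The read loop exits once Close flips isOpen and posts to the rx iocp.
		go conn.ListenOut(reader, lhf, cache, 0)

		return conn.WriteTo([]byte("hello"), remote)
	}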