Nate Brown hai 3 meses
pai
achega
4848cf051a
Modificáronse 2 ficheiros con 73 adicións e 69 borrados
  1. 0 5
      interface.go
  2. 73 64
      udp/udp_linux.go

+ 0 - 5
interface.go

@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"io"
 	"net/netip"
-	"runtime"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -259,8 +258,6 @@ func (f *Interface) run() (func(), error) {
 }
 
 func (f *Interface) listenOut(i int) {
-	runtime.LockOSThread()
-
 	var li udp.Conn
 	if i > 0 {
 		li = f.writers[i]
@@ -284,8 +281,6 @@ func (f *Interface) listenOut(i int) {
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
-	runtime.LockOSThread()
-
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}

+ 73 - 64
udp/udp_linux.go

@@ -25,6 +25,13 @@ type StdConn struct {
 	isV4  bool
 	l     *logrus.Logger
 	batch int
+
+	// cached fields for reading packets
+	msgs    []rawMessage
+	buffers [][]byte
+	names   [][]byte
+	n       uintptr
+	err     error
 }
 
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -142,97 +149,99 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 
 func (u *StdConn) ListenOut(r EncReader) {
 	var ip netip.Addr
-	var n uintptr
-	var err error
-	msgs, buffers, names := u.PrepareRawMessages(u.batch)
+
+	u.msgs, u.buffers, u.names = u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
 	if u.batch == 1 {
 		read = u.ReadSingle
 	}
 
 	for {
-		read(msgs, &n, &err)
-		if err != nil {
-			u.l.WithError(err).Error("udp socket is closed, exiting read loop")
+		read()
+		if u.err != nil {
+			//TODO: remove logging, return error
+			u.l.WithError(u.err).Error("udp socket is closed, exiting read loop")
 			return
 		}
 
-		for i := 0; i < int(n); i++ {
+		for i := 0; i < int(u.n); i++ {
 			// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
 			if u.isV4 {
-				ip, _ = netip.AddrFromSlice(names[i][4:8])
+				ip, _ = netip.AddrFromSlice(u.names[i][4:8])
 			} else {
-				ip, _ = netip.AddrFromSlice(names[i][8:24])
+				ip, _ = netip.AddrFromSlice(u.names[i][8:24])
 			}
 			//u.l.Error("GOT A PACKET", msgs[i].Len)
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(u.names[i][2:4])), u.buffers[i][:u.msgs[i].Len])
 		}
 	}
 }
 
-func (u *StdConn) ReadSingle(msgs []rawMessage, n *uintptr, err *error) {
-	oErr := u.rc.Read(func(fd uintptr) bool {
-		in, _, nErr := unix.Syscall6(
-			unix.SYS_RECVMSG,
-			fd,
-			uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
-			0, 0, 0, 0,
-		)
-
-		if nErr == syscall.EAGAIN || nErr == syscall.EINTR {
-			// Retry read
-			return false
-
-		} else if nErr != 0 {
-			u.l.Errorf("READING FROM UDP SINGLE had an errno %d", nErr)
-			*err = &net.OpError{Op: "recvmsg", Err: nErr}
-			*n = 0
-			return true
-		}
-
-		msgs[0].Len = uint32(in)
-		*n = 1
-		return true
-	})
-
-	if *err == nil && oErr != nil {
-		*err = oErr
-		*n = 0
+func (u *StdConn) ReadSingle() {
+	err := u.rc.Read(u.innerReadSingle)
+	if u.err == nil && err != nil {
+		u.err = err
+		u.n = 0
 		return
 	}
 }
 
-func (u *StdConn) ReadMulti(msgs []rawMessage, n *uintptr, err *error) {
-	oErr := u.rc.Read(func(fd uintptr) bool {
-		var nErr syscall.Errno
-		*n, _, nErr = unix.Syscall6(
-			unix.SYS_RECVMMSG,
-			fd,
-			uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
-			uintptr(len(msgs)),
-			unix.MSG_WAITFORONE,
-			0, 0,
-		)
+func (u *StdConn) innerReadSingle(fd uintptr) bool {
+	in, _, err := unix.Syscall6(
+		unix.SYS_RECVMSG,
+		fd,
+		uintptr(unsafe.Pointer(&(u.msgs[0].Hdr))),
+		0, 0, 0, 0,
+	)
+
+	if err == syscall.EAGAIN || err == syscall.EINTR {
+		// Retry read
+		return false
+
+	} else if err != 0 {
+		u.l.Errorf("READING FROM UDP SINGLE had an errno %d", err)
+		u.err = &net.OpError{Op: "recvmsg", Err: err}
+		u.n = 0
+		return true
+	}
 
-		if nErr == syscall.EAGAIN || nErr == syscall.EINTR {
-			// Retry read
-			return false
+	u.msgs[0].Len = uint32(in)
+	u.n = 1
+	return true
+}
 
-		} else if nErr != 0 {
-			u.l.Errorf("READING FROM UDP MULTI had an errno %d", nErr)
-			*err = &net.OpError{Op: "recvmmsg", Err: nErr}
-			*n = 0
-			return true
-		}
+func (u *StdConn) ReadMulti() {
+	err := u.rc.Read(u.innerReadMulti)
+	if u.err == nil && err != nil {
+		u.err = err
+		u.n = 0
+		return
+	}
+}
 
+func (u *StdConn) innerReadMulti(fd uintptr) bool {
+	var err syscall.Errno
+	u.n, _, err = unix.Syscall6(
+		unix.SYS_RECVMMSG,
+		fd,
+		uintptr(unsafe.Pointer(&u.msgs[0])),
+		uintptr(len(u.msgs)),
+		unix.MSG_WAITFORONE,
+		0, 0,
+	)
+
+	if err == syscall.EAGAIN || err == syscall.EINTR {
+		// Retry read
+		return false
+
+	} else if err != 0 {
+		u.l.Errorf("READING FROM UDP MULTI had an errno %d", err)
+		u.err = &net.OpError{Op: "recvmmsg", Err: err}
+		u.n = 0
 		return true
-	})
-
-	if *err == nil && oErr != nil {
-		*err = oErr
-		*n = 0
-		return
 	}
+
+	return true
 }
 
 func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {