JackDoan 1 hónapja
szülő
commit
e7423d39f9

+ 1 - 0
go.mod

@@ -6,6 +6,7 @@ require (
 	dario.cat/mergo v1.0.2
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
+	github.com/cilium/ebpf v0.12.3
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
 	github.com/gaissmai/bart v0.25.0

+ 6 - 1
go.sum

@@ -17,6 +17,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4=
+github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -24,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
+github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
 github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM=
 github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
@@ -78,8 +82,9 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=

+ 3 - 2
main.go

@@ -179,7 +179,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 		useWGDefault := runtime.GOOS == "linux"
 		useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault)
-		var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error)
+		var mkListener func(*logrus.Logger, netip.Addr, int, bool, int, int) (udp.Conn, error)
 		if useWG {
 			mkListener = udp.NewWireguardListener
 		} else {
@@ -188,10 +188,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 		for i := 0; i < routines; i++ {
 			l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
-			udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
+			udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64), i)
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
+			//todo set bpf on zeroth socket
 			udpServer.ReloadConfig(c)
 			if cfg, ok := udpServer.(interface {
 				ConfigureOffload(bool, bool, int)

+ 1 - 1
udp/udp_linux.go

@@ -32,7 +32,7 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
 	return ip, false
 }
 
-func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
 	af := unix.AF_INET6
 	if ip.Is4() {
 		af = unix.AF_INET

+ 5 - 4
udp/wireguard_conn_linux.go

@@ -27,13 +27,13 @@ type WGConn struct {
 	enableGRO bool
 	gsoMaxSeg int
 	closed    atomic.Bool
-
+	q         int
 	closeOnce sync.Once
 }
 
 // NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
-func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
-	bind := wgconn.NewStdNetBindForAddr(ip, multi)
+func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) {
+	bind := wgconn.NewStdNetBindForAddr(ip, multi, q)
 	recvers, actualPort, err := bind.Open(uint16(port))
 	if err != nil {
 		return nil, err
@@ -51,6 +51,7 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
 		reqBatch:  batch,
 		localIP:   ip,
 		localPort: actualPort,
+		q:         q,
 	}, nil
 }
 
@@ -71,7 +72,7 @@ func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
 	batchSize := c.batch
 	packets := make([][]byte, batchSize)
 	for i := range packets {
-		packets[i] = make([]byte, MTU)
+		packets[i] = make([]byte, 0xffff)
 	}
 	sizes := make([]int, batchSize)
 	endpoints := make([]wgconn.Endpoint, batchSize)

+ 19 - 5
wgstack/conn/bind_std.go

@@ -46,6 +46,7 @@ type StdNetBind struct {
 
 	blackhole4 bool
 	blackhole6 bool
+	q          int
 }
 
 // NewStdNetBind creates a bind that listens on all interfaces.
@@ -56,8 +57,9 @@ func NewStdNetBind() *StdNetBind {
 // NewStdNetBindForAddr creates a bind that listens on a specific address.
 // If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
 // IPv6 socket will be created.
-func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
+func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind {
 	b := NewStdNetBind()
+	b.q = q
 	//if addr.IsValid() {
 	//	if addr.IsUnspecified() {
 	//		// keep dual-stack defaults with empty listen addresses
@@ -147,12 +149,24 @@ func (e *StdNetEndpoint) DstToString() string {
 	return e.AddrPort.String()
 }
 
-func listenNet(network string, port int) (*net.UDPConn, int, error) {
-	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
+func listenNet(network string, port int, q int) (*net.UDPConn, int, error) {
+	lc := listenConfig(q)
+
+	conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
 	if err != nil {
 		return nil, 0, err
 	}
 
+	if q == 0 {
+		if EvilFdZero == 0 {
+			panic("fuck")
+		}
+		err = reusePortHax(EvilFdZero)
+		if err != nil {
+			return nil, 0, fmt.Errorf("reuse port hax: %v", err)
+		}
+	}
+
 	// Retrieve port.
 	laddr := conn.LocalAddr()
 	uaddr, err := net.ResolveUDPAddr(
@@ -185,13 +199,13 @@ again:
 	var v4pc *ipv4.PacketConn
 	var v6pc *ipv6.PacketConn
 
-	v4conn, port, err = listenNet("udp4", port)
+	v4conn, port, err = listenNet("udp4", port, s.q)
 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
 		return nil, 0, err
 	}
 
 	// Listen on the same port as we're using for ipv4.
-	v6conn, port, err = listenNet("udp6", port)
+	v6conn, port, err = listenNet("udp6", port, s.q)
 	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
 		v4conn.Close()
 		tries++

+ 181 - 1
wgstack/conn/controlfns.go

@@ -5,8 +5,12 @@
 package conn
 
 import (
+	"fmt"
 	"net"
 	"syscall"
+
+	"github.com/cilium/ebpf"
+	"github.com/cilium/ebpf/asm"
 )
 
 // UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
@@ -25,10 +29,169 @@ type controlFn func(network, address string, c syscall.RawConn) error
 // that can apply socket options.
 var controlFns = []controlFn{}
 
+// SO_ATTACH_REUSEPORT_EBPF is the socket option name that reusePortHax passes
+// to setsockopt at level SOL_SOCKET, together with an eBPF program descriptor.
+// NOTE(review): 52 presumably mirrors the Linux asm-generic socket.h value; some
+// architectures may number socket options differently — confirm, or take the
+// constant from golang.org/x/sys/unix instead of hard-coding it.
+const SO_ATTACH_REUSEPORT_EBPF = 52
+
+// createReuseportProgram assembles and loads a minimal SK_REUSEPORT eBPF
+// program meant to be attached with SO_ATTACH_REUSEPORT_EBPF.
+//
+// The program is a debugging aid. Each time the kernel consults it to pick a
+// socket from the reuseport group it writes "bpf !\n" through
+// bpf_trace_printk and returns SK_PASS without selecting a socket, which
+// leaves the kernel's default hash-based selection in effect. Returning 0
+// here would mean SK_DROP for this program type, i.e. every packet refused.
+//
+// Approaches tried and dropped while prototyping: returning skb->hash read
+// from a fixed context offset, returning get_prandom_u32(), and calling
+// bpf_sk_select_reuseport without a socket map. Steering to a specific socket
+// needs a reuseport sockarray map plus bpf_sk_select_reuseport.
+//
+// The caller owns the returned program and must Close it once it is no longer
+// needed. The "GPL" license string is required because bpf_trace_printk is a
+// GPL-only helper.
+// NOTE(review): loading presumably needs CAP_BPF/CAP_PERFMON (or
+// CAP_SYS_ADMIN on older kernels) — confirm for the deployment target.
+func createReuseportProgram() (*ebpf.Program, error) {
+	// SK_PASS from the kernel's enum sk_action.
+	const skPass = 1
+
+	instructions := asm.Instructions{
+		// R1 = frame pointer - 8: scratch space for the format string.
+		asm.Mov.Reg(asm.R1, asm.R10),
+		asm.Add.Imm(asm.R1, -8),
+
+		// Store immediates are only 32 bits wide, so the string is written
+		// four bytes at a time. The byte order below assumes a little-endian
+		// host.
+		asm.StoreImm(asm.R1, 0, 0x20667062, asm.Word), // "bpf "
+		asm.StoreImm(asm.R1, 4, 0x00000a21, asm.Word), // "!\n\0\0"
+
+		// bpf_trace_printk(fmt, fmt_size): R1 already points at the string and
+		// the size covers "bpf !\n" plus its terminating NUL.
+		asm.Mov.Imm(asm.R2, 7),
+		asm.FnTracePrintk.Call(),
+
+		// Accept the packet and let the kernel choose the socket.
+		asm.Mov.Imm(asm.R0, skPass),
+		asm.Return(),
+	}
+
+	return ebpf.NewProgram(&ebpf.ProgramSpec{
+		Type:         ebpf.SkReuseport,
+		Instructions: instructions,
+		License:      "GPL",
+	})
+}
+
+//func createReuseportProgram() (*ebpf.Program, error) {
+//	// Try offset 20 (common in newer kernels)
+//	instructions := asm.Instructions{
+//		asm.LoadMem(asm.R0, asm.R1, 20, asm.Word),
+//		asm.Return(),
+//	}
+//
+//	prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{
+//		Type:         ebpf.SkReuseport,
+//		Instructions: instructions,
+//		License:      "GPL",
+//	})
+//
+//	return prog, err
+//}
+
+// reusePortHax builds the reuseport eBPF program and attaches it to the socket
+// identified by fd using setsockopt(SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF).
+//
+// fd must refer to a socket that stays open for the duration of the call. An
+// error is returned when the program cannot be created or when the setsockopt
+// call fails; the underlying error remains reachable through errors.Is/As.
+func reusePortHax(fd uintptr) error {
+	prog, err := createReuseportProgram()
+	if err != nil {
+		return fmt.Errorf("failed to create eBPF program: %w", err)
+	}
+	// Once attached, the socket holds its own reference to the program, so the
+	// user-space handle is only needed for the setsockopt call. Closing it on
+	// every path releases the descriptor deterministically and keeps prog
+	// alive until the syscall has returned.
+	defer prog.Close()
+
+	if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, prog.FD()); err != nil {
+		return fmt.Errorf("attaching reuseport eBPF program: %w", err)
+	}
+	return nil
+}
+
+// EvilFdZero holds the raw socket descriptor captured by listenConfig's
+// Control hook when q == 0. listenNet later passes it to reusePortHax and
+// panics if it is still zero.
+// NOTE(review): this is unsynchronized package-level state, and the descriptor
+// is used after the RawConn.Control callback has returned — presumably a
+// prototype shortcut; confirm before relying on it.
+var EvilFdZero uintptr
+
 // listenConfig returns a net.ListenConfig that applies the controlFns to the
 // socket prior to bind. This is used to apply socket buffer sizing and packet
 // information OOB configuration for sticky sockets.
-func listenConfig() *net.ListenConfig {
+func listenConfig(q int) *net.ListenConfig {
 	return &net.ListenConfig{
 		Control: func(network, address string, c syscall.RawConn) error {
 			for _, fn := range controlFns {
@@ -36,6 +199,23 @@ func listenConfig() *net.ListenConfig {
 					return err
 				}
 			}
+
+			if q == 0 {
+				c.Control(func(fd uintptr) {
+					EvilFdZero = fd
+				})
+				//	var e error
+				//	err := c.Control(func(fd uintptr) {
+				//		e = reusePortHax(fd)
+				//	})
+				//	if err != nil {
+				//		return err
+				//	}
+				//	if e != nil {
+				//		return e
+				//	}
+			}
+
 			return nil
 		},
 	}

+ 1 - 0
wgstack/conn/controlfns_linux.go

@@ -30,6 +30,7 @@ func init() {
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)  //todo!!!
+				_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)      //todo!!!
 				_ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!!
 				//print(err.Error())
 			})

+ 6 - 2
wgstack/conn/features_linux.go

@@ -6,6 +6,7 @@
 package conn
 
 import (
+	"fmt"
 	"net"
 
 	"golang.org/x/sys/unix"
@@ -16,12 +17,15 @@ func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
 	if err != nil {
 		return
 	}
+	a := 0
 	err = rc.Control(func(fd uintptr) {
-		_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
-		txOffload = errSyscall == nil
+		a, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+
+		txOffload = err == nil
 		opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
 		rxOffload = errSyscall == nil && opt == 1
 	})
+	fmt.Printf("%d", a)
 	if err != nil {
 		return false, false
 	}