|
@@ -5,7 +5,6 @@ package udp
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
- "encoding/binary"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"net/netip"
|
|
@@ -16,11 +15,13 @@ import (
|
|
|
"github.com/rcrowley/go-metrics"
|
|
|
"github.com/sirupsen/logrus"
|
|
|
"github.com/slackhq/nebula/config"
|
|
|
+ "golang.org/x/net/ipv6"
|
|
|
"golang.org/x/sys/unix"
|
|
|
)
|
|
|
|
|
|
type StdConn struct {
|
|
|
- c *net.UDPConn
|
|
|
+ c *ipv6.PacketConn
|
|
|
+ uc *net.UDPConn
|
|
|
rc syscall.RawConn
|
|
|
isV4 bool
|
|
|
l *logrus.Logger
|
|
@@ -65,8 +66,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|
|
_ = c.Close()
|
|
|
return nil, fmt.Errorf("unable to open sysfd: %w", err)
|
|
|
}
|
|
|
-
|
|
|
- return &StdConn{c: uc, rc: rc, isV4: ip.Is4(), l: l, batch: batch}, err
|
|
|
+ return &StdConn{c: ipv6.NewPacketConn(c), rc: rc, uc: uc, isV4: ip.Is4(), l: l, batch: batch}, err
|
|
|
}
|
|
|
|
|
|
func (u *StdConn) Rebind() error {
|
|
@@ -143,36 +143,48 @@ func (u *StdConn) GetSoMark() (int, error) {
|
|
|
}
|
|
|
|
|
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|
|
- sa := u.c.LocalAddr()
|
|
|
+ sa := u.uc.LocalAddr()
|
|
|
return netip.ParseAddrPort(sa.String())
|
|
|
}
|
|
|
|
|
|
func (u *StdConn) ListenOut(r EncReader) {
|
|
|
var ip netip.Addr
|
|
|
+ var port int
|
|
|
+
|
|
|
+ //u.msgs, u.buffers, u.names = u.PrepareRawMessages(u.batch)
|
|
|
+ //read := u.ReadMulti
|
|
|
+ //if u.batch == 1 {
|
|
|
+ // read = u.ReadSingle
|
|
|
+ //}
|
|
|
|
|
|
- u.msgs, u.buffers, u.names = u.PrepareRawMessages(u.batch)
|
|
|
- read := u.ReadMulti
|
|
|
- if u.batch == 1 {
|
|
|
- read = u.ReadSingle
|
|
|
+ var err error
|
|
|
+ var n int
|
|
|
+ msgs := make([]ipv6.Message, u.batch)
|
|
|
+ for i := range msgs {
|
|
|
+ msgs[i].Buffers = [][]byte{make([]byte, MTU)}
|
|
|
}
|
|
|
|
|
|
for {
|
|
|
- read()
|
|
|
- if u.err != nil {
|
|
|
+ //read()
|
|
|
+ n, err = u.c.ReadBatch(msgs, 0)
|
|
|
+ if err != nil {
|
|
|
//TODO: remove logging, return error
|
|
|
- u.l.WithError(u.err).Error("udp socket is closed, exiting read loop")
|
|
|
+ u.l.WithError(err).Error("udp socket is closed, exiting read loop")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- for i := 0; i < int(u.n); i++ {
|
|
|
- // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
|
|
- if u.isV4 {
|
|
|
- ip, _ = netip.AddrFromSlice(u.names[i][4:8])
|
|
|
- } else {
|
|
|
- ip, _ = netip.AddrFromSlice(u.names[i][8:24])
|
|
|
+ for i := 0; i < n; i++ {
|
|
|
+ switch addr := msgs[i].Addr.(type) {
|
|
|
+ case *net.UDPAddr:
|
|
|
+ // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
|
|
+ ip, _ = netip.AddrFromSlice(addr.IP)
|
|
|
+ port = addr.Port
|
|
|
+ default:
|
|
|
+ //TODO: this is an error, return?
|
|
|
}
|
|
|
+
|
|
|
//u.l.Error("GOT A PACKET", msgs[i].Len)
|
|
|
- r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(u.names[i][2:4])), u.buffers[i][:u.msgs[i].Len])
|
|
|
+ r(netip.AddrPortFrom(ip.Unmap(), uint16(port)), msgs[i].Buffers[0][:msgs[i].N])
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -245,7 +257,7 @@ func (u *StdConn) innerReadMulti(fd uintptr) bool {
|
|
|
}
|
|
|
|
|
|
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
|
|
- _, err := u.c.WriteToUDPAddrPort(b, ip)
|
|
|
+ _, err := u.uc.WriteToUDPAddrPort(b, ip)
|
|
|
return err
|
|
|
}
|
|
|
|
|
@@ -318,7 +330,7 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|
|
}
|
|
|
|
|
|
func (u *StdConn) Close() error {
|
|
|
- err := u.c.Close()
|
|
|
+ err := u.uc.Close()
|
|
|
return err
|
|
|
}
|
|
|
|