Просмотр исходного кода

remove awful per-packet scratch buf

JackDoan 4 дней назад
Родитель
Сommit
aeded87e71
4 измененных файлов с 33 добавлено и 28 удалено
  1. 2 4
      interface.go
  2. 10 13
      outside.go
  3. 1 1
      overlay/vhostnet/device.go
  4. 20 10
      packet/outpacket.go

+ 2 - 4
interface.go

@@ -291,16 +291,14 @@ func (f *Interface) listenOut(q int) {
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
+	scratch := make([]byte, udp.MTU)
 
 	toSend := make([][]byte, batch)
 
 	li.ListenOut(func(pkts []*packet.UDPPacket) {
 		toSend = toSend[:0]
-		for i := range outPackets {
-			outPackets[i].SegCounter = 0
-		}
 
-		f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
+		f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, scratch, q, ctCache.Get(f.l), time.Now())
 		//we opportunistically tx, but try to also send stragglers
 		if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
 			f.l.WithError(err).Error("Failed to send packets")

+ 10 - 13
outside.go

@@ -102,7 +102,7 @@ func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segme
 	return false
 }
 
-func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := h.Parse(segment)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
@@ -116,7 +116,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
 		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
-		keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb)
+		keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, scratch[:0], h, nb)
 		if !keepGoing {
 			return
 		}
@@ -139,10 +139,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 		switch h.Subtype {
 		case header.MessageNone:
 			if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
-				//todo we've allocated a segment we aren't using.
-				//Unfortunately, we can't un-allocate it.
-				//Saving it for "next time" is also problematic.
-				//todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go?
+				out.DestroyLastSegment() //prevent a rejected segment from being used
 				return
 			}
 		case header.MessageRelay:
@@ -156,7 +153,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 			return
 		}
 
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr).
 				WithField("packet", segment).
@@ -174,7 +171,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 			return
 		}
 
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
 				WithField("packet", segment).
@@ -186,7 +183,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
 			f.handleHostRoaming(hostinfo, via)
-			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out.Scratch)
+			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, scratch)
 		}
 
 		// Fallthrough to the bottom to record incoming traffic
@@ -221,7 +218,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 			return
 		}
 
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, scratch, segment, h, nb)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
 				WithField("packet", segment).
@@ -242,9 +239,9 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	f.connectionManager.In(hostinfo)
 }
 
-func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, scratch []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	for i, pkt := range packets {
-		out[i].Scratch = out[i].Scratch[:0]
+		scratch = scratch[:0]
 		via := ViaSender{UdpAddr: pkt.AddrPort()}
 
 		//l.Error("in packet ", header, packet[HeaderLen:])
@@ -258,7 +255,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p
 		}
 
 		for segment := range pkt.Segments() {
-			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
+			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, scratch, q, localCache, now)
 		}
 		//_, err := f.readers[q].WriteOne(out[i], false, q)
 		//if err != nil {

+ 1 - 1
overlay/vhostnet/device.go

@@ -122,7 +122,7 @@ func NewDevice(options ...Option) (*Device, error) {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
 	}
 	if err = dev.prefillTxQueue(); err != nil {
-		return nil, fmt.Errorf("refill tx queue: %w", err)
+		return nil, fmt.Errorf("prefill tx queue: %w", err)
 	}
 
 	// Make sure to clean up even when the device gets garbage collected without

+ 20 - 10
packet/outpacket.go

@@ -6,15 +6,14 @@ import (
 )
 
 type OutPacket struct {
-	Segments        [][]byte
+	Segments [][]byte
+	// SegmentHeaders maps to the first virtio.NetHdrSize+14 bytes of Segments[n]
+	SegmentHeaders [][]byte
+	// SegmentPayloads maps to the remaining bytes of Segments[n]
 	SegmentPayloads [][]byte
-	SegmentHeaders  [][]byte
-	SegmentIDs      []uint16
-
-	SegSize    int
-	SegCounter int
-
-	Scratch []byte
+	// SegmentIDs is the list of underlying buffer IDs of Segments.
+	// SegmentIDs, Segments, SegmentHeaders, SegmentPayloads should all have the same length at all times!
+	SegmentIDs []uint16
 }
 
 func NewOut() *OutPacket {
@@ -23,7 +22,6 @@ func NewOut() *OutPacket {
 	out.SegmentHeaders = make([][]byte, 0, 64)
 	out.SegmentPayloads = make([][]byte, 0, 64)
 	out.SegmentIDs = make([]uint16, 0, 64)
-	out.Scratch = make([]byte, Size)
 	return out
 }
 
@@ -32,7 +30,19 @@ func (pkt *OutPacket) Reset() {
 	pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
 	pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
 	pkt.SegmentIDs = pkt.SegmentIDs[:0]
-	pkt.SegSize = 0
+}
+
+// DestroyLastSegment removes the contents of the last segment in the list.
+// Use this to handle firewall drops or similar, but still hand the segment buffer back to the underlying driver.
+// Implementations shall discard zero-length segments internally.
+func (pkt *OutPacket) DestroyLastSegment() {
+	if len(pkt.Segments) == 0 {
+		return
+	}
+	lastSeg := len(pkt.SegmentIDs) - 1
+	pkt.SegmentPayloads[lastSeg] = pkt.SegmentPayloads[lastSeg][:0]
+	pkt.SegmentHeaders[lastSeg] = pkt.SegmentHeaders[lastSeg][:0]
+	pkt.Segments[lastSeg] = pkt.Segments[lastSeg][:0]
 }
 
 func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {