Browse Source

relay rework

JackDoan 3 days ago
parent
commit
188b20457e
1 changed files with 84 additions and 202 deletions
  1. 84 202
      outside.go

+ 84 - 202
outside.go

@@ -20,150 +20,86 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePacketFromRelay(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
-	//todo this is way too similar to readOutsidePacketsMany, find a way to eliminate
-	err := h.Parse(packet)
+// handleRelayPackets handles relay packets. Returns false if there's nothing left to do, true for continuing to process an unwrapped TerminalType packet
+// scratch must be large enough to contain a packet to be relayed if needed
+func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) bool {
+	var err error
+	// The entire body is sent as AD, not encrypted.
+	// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
+	// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
+	// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
+	// which will gracefully fail in the DecryptDanger call.
+	seg := *segment
+	signedPayload := seg[:len(*segment)-hostinfo.ConnectionState.dKey.Overhead()]
+	signatureValue := seg[len(*segment)-hostinfo.ConnectionState.dKey.Overhead():]
+	scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(scratch, signedPayload, signatureValue, h.MessageCounter, nb)
 	if err != nil {
-		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
-		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
-		}
-		return
-	}
-
-	//l.Error("in packet ", header, packet[HeaderLen:])
-	if !via.IsRelayed {
-		if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
-			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
-			}
-			return
-		}
-	}
-
-	var hostinfo *HostInfo
-	// verify if we've seen this index before, otherwise respond to the handshake initiation
-	if h.Type == header.Message && h.Subtype == header.MessageRelay {
-		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
-	} else {
-		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+		return false
 	}
+	// Successfully validated the thing. Get rid of the Relay header.
+	signedPayload = signedPayload[header.Len:]
+	// Pull the Roaming parts up here, and return in all call paths.
+	f.handleHostRoaming(hostinfo, *via)
+	// Track usage of both the HostInfo and the Relay for the received & authenticated packet
+	f.connectionManager.In(hostinfo)
+	f.connectionManager.RelayUsed(h.RemoteIndex)
 
-	var ci *ConnectionState
-	if hostinfo != nil {
-		ci = hostinfo.ConnectionState
+	relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
+	if !ok {
+		// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
+		// its internal mapping. This should never happen.
+		hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+		return false
 	}
 
-	switch h.Type {
-	case header.Message:
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		switch h.Subtype {
-		case header.MessageNone:
-			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
-				return
-			}
-		case header.MessageRelay:
-			//this packet already came to us via a relay
-			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("from", via).Debug("Refusing to process double relayed packet")
-			}
-			return
-		}
-
-	case header.LightHouse:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
+	switch relay.Type {
+	case TerminalType:
+		// If I am the target of this relay, process the unwrapped packet
+		// We need to re-write our variables to ensure this segment is correctly parsed.
+		// We could set up for a recursive call here, but this makes it easier to prove that we'll never stack-overflow
+		*via = ViaSender{
+			UdpAddr:   via.UdpAddr,
+			relayHI:   hostinfo,
+			remoteIdx: relay.RemoteIndex,
+			relay:     relay,
+			IsRelayed: true,
+		}
+		//mirrors the top of readOutsideSegment
+		err = h.Parse(signedPayload)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt lighthouse packet")
-			return
-		}
-
-		//TODO: assert via is not relayed
-		lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
-
-		// Fallthrough to the bottom to record incoming traffic
-
-	case header.Test:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
+			// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
+			if len(signedPayload) > 1 {
+				f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err)
+			}
+			return false
+		}
+		*segment = signedPayload
+		//continue flowing through readOutsideSegment()
+		return true
+	case ForwardingType:
+		// Find the target HostInfo relay object
+		targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt test packet")
-			return
-		}
-
-		if h.Subtype == header.TestRequest {
-			// This testRequest might be from TryPromoteBest, so we should roam
-			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, via)
-			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
+			hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
+			return false
 		}
 
-		// Fallthrough to the bottom to record incoming traffic
-
-		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
-		// are unauthenticated
-
-	case header.Handshake:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(via, packet, h)
-		return
-
-	case header.RecvError:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(via.UdpAddr, h)
-		return
-
-	case header.CloseTunnel:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		hostinfo.logger(f.l).WithField("from", via).
-			Info("Close tunnel received, tearing down.")
-
-		f.closeTunnel(hostinfo)
-		return
-
-	case header.Control:
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt Control packet")
-			return
+		// If that relay is Established, forward the payload through it
+		if targetRelay.State == Established {
+			switch targetRelay.Type {
+			case ForwardingType:
+				// Forward this packet through the relay tunnel, and find the target HostInfo
+				f.SendVia(targetHI, targetRelay, signedPayload, nb, scratch[:0], false) //todo it would be nice to queue this up and do it later, or at least avoid a memcpy of signedPayload
+			case TerminalType:
+				hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
+			default:
+				hostinfo.logger(f.l).WithField("targetRelay.Type", targetRelay.Type).Error("Unexpected Relay Type")
+			}
+		} else {
+			hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 		}
-
-		f.relayManager.HandleControlMsg(hostinfo, d, f)
-
-	default:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
-		return
 	}
-
-	f.handleHostRoaming(hostinfo, via)
-
-	f.connectionManager.In(hostinfo)
+	return false
 }
 
 func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
@@ -180,6 +116,11 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
 		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+		keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb)
+		if !keepGoing {
+			return
+		}
+
 	} else {
 		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
 	}
@@ -198,73 +139,15 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 		switch h.Subtype {
 		case header.MessageNone:
 			if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
+				//todo we've allocated a segment we aren't using.
+				//Unfortunately, we can't un-allocate it.
+				//Saving it for "next time" is also problematic.
+				//todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go?
 				return
 			}
 		case header.MessageRelay:
-			// The entire body is sent as AD, not encrypted.
-			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
-			// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
-			// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
-			// which will gracefully fail in the DecryptDanger call.
-			signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
-			signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
-			out.Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
-			if err != nil {
-				return
-			}
-			// Successfully validated the thing. Get rid of the Relay header.
-			signedPayload = signedPayload[header.Len:]
-			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, via)
-			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo)
-			f.connectionManager.RelayUsed(h.RemoteIndex)
-
-			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
-			if !ok {
-				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
-				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
-				return
-			}
-
-			switch relay.Type {
-			case TerminalType:
-				// If I am the target of this relay, process the unwrapped packet
-				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				via = ViaSender{
-					UdpAddr:   via.UdpAddr,
-					relayHI:   hostinfo,
-					remoteIdx: relay.RemoteIndex,
-					relay:     relay,
-					IsRelayed: true,
-				}
-				f.readOutsidePacketFromRelay(via, out.Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
-				return
-			case ForwardingType:
-				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
-				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
-					return
-				}
-
-				// If that relay is Established, forward the payload through it
-				if targetRelay.State == Established {
-					switch targetRelay.Type {
-					case ForwardingType:
-						// Forward this packet through the relay tunnel
-						// Find the target HostInfo
-						f.SendVia(targetHI, targetRelay, signedPayload, nb, out.Scratch, false)
-						return
-					case TerminalType:
-						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
-					}
-				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
-					return
-				}
-			}
+			f.l.Error("relayed messages cannot contain relay messages, dropping packet")
+			return
 		}
 
 	case header.LightHouse:
@@ -376,12 +259,11 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p
 
 		for segment := range pkt.Segments() {
 			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
-
-		}
-		_, err := f.readers[q].WriteOne(out[i], false, q)
-		if err != nil {
-			f.l.WithError(err).Error("Failed to write packet")
 		}
+		//_, err := f.readers[q].WriteOne(out[i], false, q)
+		//if err != nil {
+		//	f.l.WithError(err).Error("Failed to write packet")
+		//}
 	}
 }