JackDoan hai 4 días
pai
achega
188b20457e
Modificáronse 1 ficheiros con 84 adicións e 202 borrados
  1. 84 202
      outside.go

+ 84 - 202
outside.go

@@ -20,150 +20,86 @@ const (
 	minFwPacketLen = 4
 	minFwPacketLen = 4
 )
 )
 
 
-func (f *Interface) readOutsidePacketFromRelay(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
-	//todo this is way too similar to readOutsidePacketsMany, find a way to eliminate
-	err := h.Parse(packet)
+// handleRelayPackets handles relay packets. Returns false if there's nothing left to do, true for continuing to process an unwrapped TerminalType packet
+// scratch must be large enough to contain a packet to be relayed if needed
+func (f *Interface) handleRelayPackets(via *ViaSender, hostinfo *HostInfo, segment *[]byte, scratch []byte, h *header.H, nb []byte) bool {
+	var err error
+	// The entire body is sent as AD, not encrypted.
+	// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
+	// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
+	// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
+	// which will gracefully fail in the DecryptDanger call.
+	seg := *segment
+	signedPayload := seg[:len(*segment)-hostinfo.ConnectionState.dKey.Overhead()]
+	signatureValue := seg[len(*segment)-hostinfo.ConnectionState.dKey.Overhead():]
+	scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(scratch, signedPayload, signatureValue, h.MessageCounter, nb)
 	if err != nil {
 	if err != nil {
-		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
-		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
-		}
-		return
-	}
-
-	//l.Error("in packet ", header, packet[HeaderLen:])
-	if !via.IsRelayed {
-		if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
-			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
-			}
-			return
-		}
-	}
-
-	var hostinfo *HostInfo
-	// verify if we've seen this index before, otherwise respond to the handshake initiation
-	if h.Type == header.Message && h.Subtype == header.MessageRelay {
-		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
-	} else {
-		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+		return false
 	}
 	}
+	// Successfully validated the thing. Get rid of the Relay header.
+	signedPayload = signedPayload[header.Len:]
+	// Pull the Roaming parts up here, and return in all call paths.
+	f.handleHostRoaming(hostinfo, *via)
+	// Track usage of both the HostInfo and the Relay for the received & authenticated packet
+	f.connectionManager.In(hostinfo)
+	f.connectionManager.RelayUsed(h.RemoteIndex)
 
 
-	var ci *ConnectionState
-	if hostinfo != nil {
-		ci = hostinfo.ConnectionState
+	relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
+	if !ok {
+		// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
+		// its internal mapping. This should never happen.
+		hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+		return false
 	}
 	}
 
 
-	switch h.Type {
-	case header.Message:
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		switch h.Subtype {
-		case header.MessageNone:
-			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
-				return
-			}
-		case header.MessageRelay:
-			//this packet already came to us via a relay
-			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("from", via).Debug("Refusing to process double relayed packet")
-			}
-			return
-		}
-
-	case header.LightHouse:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
+	switch relay.Type {
+	case TerminalType:
+		// If I am the target of this relay, process the unwrapped packet
+		// We need to re-write our variables to ensure this segment is correctly parsed.
+		// We could set up for a recursive call here, but this makes it easier to prove that we'll never stack-overflow
+		*via = ViaSender{
+			UdpAddr:   via.UdpAddr,
+			relayHI:   hostinfo,
+			remoteIdx: relay.RemoteIndex,
+			relay:     relay,
+			IsRelayed: true,
+		}
+		//mirrors the top of readOutsideSegment
+		err = h.Parse(signedPayload)
 		if err != nil {
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt lighthouse packet")
-			return
-		}
-
-		//TODO: assert via is not relayed
-		lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
-
-		// Fallthrough to the bottom to record incoming traffic
-
-	case header.Test:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
+			// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
+			if len(signedPayload) > 1 {
+				f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err)
+			}
+			return false
+		}
+		*segment = signedPayload
+		//continue flowing through readOutsideSegment()
+		return true
+	case ForwardingType:
+		// Find the target HostInfo relay object
+		targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 		if err != nil {
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt test packet")
-			return
-		}
-
-		if h.Subtype == header.TestRequest {
-			// This testRequest might be from TryPromoteBest, so we should roam
-			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, via)
-			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
+			hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
+			return false
 		}
 		}
 
 
-		// Fallthrough to the bottom to record incoming traffic
-
-		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
-		// are unauthenticated
-
-	case header.Handshake:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(via, packet, h)
-		return
-
-	case header.RecvError:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(via.UdpAddr, h)
-		return
-
-	case header.CloseTunnel:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		hostinfo.logger(f.l).WithField("from", via).
-			Info("Close tunnel received, tearing down.")
-
-		f.closeTunnel(hostinfo)
-		return
-
-	case header.Control:
-		if !f.handleEncrypted(ci, via, h) {
-			return
-		}
-
-		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("from", via).
-				WithField("packet", packet).
-				Error("Failed to decrypt Control packet")
-			return
+		// If that relay is Established, forward the payload through it
+		if targetRelay.State == Established {
+			switch targetRelay.Type {
+			case ForwardingType:
+				// Forward this packet through the relay tunnel, and find the target HostInfo
+				f.SendVia(targetHI, targetRelay, signedPayload, nb, scratch[:0], false) //todo it would be nice to queue this up and do it later, or at least avoid a memcpy of signedPayload
+			case TerminalType:
+				hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
+			default:
+				hostinfo.logger(f.l).WithField("targetRelay.Type", targetRelay.Type).Error("Unexpected Relay Type")
+			}
+		} else {
+			hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 		}
 		}
-
-		f.relayManager.HandleControlMsg(hostinfo, d, f)
-
-	default:
-		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
-		return
 	}
 	}
-
-	f.handleHostRoaming(hostinfo, via)
-
-	f.connectionManager.In(hostinfo)
+	return false
 }
 }
 
 
 func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
@@ -180,6 +116,11 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
 		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
 		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+		keepGoing := f.handleRelayPackets(&via, hostinfo, &segment, out.Scratch[:0], h, nb)
+		if !keepGoing {
+			return
+		}
+
 	} else {
 	} else {
 		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
 		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
 	}
 	}
@@ -198,73 +139,15 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 		switch h.Subtype {
 		switch h.Subtype {
 		case header.MessageNone:
 		case header.MessageNone:
 			if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
 			if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
+				//todo we've allocated a segment we aren't using.
+				//Unfortunately, we can't un-allocate it.
+				//Saving it for "next time" is also problematic.
+				//todo we need to give the segment back, but we don't want to actually send the packet to the tun. blanking the slice is probably the way to go?
 				return
 				return
 			}
 			}
 		case header.MessageRelay:
 		case header.MessageRelay:
-			// The entire body is sent as AD, not encrypted.
-			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
-			// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
-			// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
-			// which will gracefully fail in the DecryptDanger call.
-			signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
-			signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
-			out.Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
-			if err != nil {
-				return
-			}
-			// Successfully validated the thing. Get rid of the Relay header.
-			signedPayload = signedPayload[header.Len:]
-			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, via)
-			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo)
-			f.connectionManager.RelayUsed(h.RemoteIndex)
-
-			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
-			if !ok {
-				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
-				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
-				return
-			}
-
-			switch relay.Type {
-			case TerminalType:
-				// If I am the target of this relay, process the unwrapped packet
-				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				via = ViaSender{
-					UdpAddr:   via.UdpAddr,
-					relayHI:   hostinfo,
-					remoteIdx: relay.RemoteIndex,
-					relay:     relay,
-					IsRelayed: true,
-				}
-				f.readOutsidePacketFromRelay(via, out.Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
-				return
-			case ForwardingType:
-				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
-				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
-					return
-				}
-
-				// If that relay is Established, forward the payload through it
-				if targetRelay.State == Established {
-					switch targetRelay.Type {
-					case ForwardingType:
-						// Forward this packet through the relay tunnel
-						// Find the target HostInfo
-						f.SendVia(targetHI, targetRelay, signedPayload, nb, out.Scratch, false)
-						return
-					case TerminalType:
-						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
-					}
-				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
-					return
-				}
-			}
+			f.l.Error("relayed messages cannot contain relay messages, dropping packet")
+			return
 		}
 		}
 
 
 	case header.LightHouse:
 	case header.LightHouse:
@@ -376,12 +259,11 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*p
 
 
 		for segment := range pkt.Segments() {
 		for segment := range pkt.Segments() {
 			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
 			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
-
-		}
-		_, err := f.readers[q].WriteOne(out[i], false, q)
-		if err != nil {
-			f.l.WithError(err).Error("Failed to write packet")
 		}
 		}
+		//_, err := f.readers[q].WriteOne(out[i], false, q)
+		//if err != nil {
+		//	f.l.WithError(err).Error("Failed to write packet")
+		//}
 	}
 	}
 }
 }