فهرست منبع

cert-v2 chores (#1466)

Jack Doan 5 روز پیش
والد
کامیت
768325c9b4
5فایلهای تغییر یافته به همراه53 افزوده شده و 74 حذف شده
  1. 1 2
      cert/cert.go
  2. 1 0
      cert/errors.go
  3. 48 69
      lighthouse.go
  4. 0 1
      overlay/tun_linux.go
  5. 3 2
      pki.go

+ 1 - 2
cert/cert.go

@@ -135,8 +135,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
 	case Version2:
 		c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
 	default:
-		//TODO: CERT-V2 make a static var
-		return nil, fmt.Errorf("unknown certificate version %d", v)
+		return nil, ErrUnknownVersion
 	}
 
 	if err != nil {

+ 1 - 0
cert/errors.go

@@ -20,6 +20,7 @@ var (
 	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
 	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
 	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+	ErrUnknownVersion             = errors.New("certificate version unrecognized")
 
 	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")

+ 48 - 69
lighthouse.go

@@ -24,6 +24,7 @@ import (
 )
 
 var ErrHostNotKnown = errors.New("host not known")
+var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
 
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -710,14 +711,10 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 }
 
 func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
-	if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
-		return true
-	}
-	return false
+	_, ok := lh.GetLighthouses()[vpnAddr]
+	return ok
 }
 
-// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
-// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
 func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
 	l := lh.GetLighthouses()
 	for _, a := range vpnAddr {
@@ -1060,17 +1057,8 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 		return
 	}
 
-	useVersion := cert.Version1
-	var queryVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		queryVpnAddr = netip.AddrFrom4(b)
-		useVersion = 1
-	} else if n.Details.VpnAddr != nil {
-		queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		useVersion = 2
-	} else {
+	queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
 		}
@@ -1128,8 +1116,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
 			if ok {
 				whereToPunch = newDest
 			} else {
-				//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
-				//choosing to do nothing for now, but maybe we return an error?
+				if lhh.l.Level >= logrus.DebugLevel {
+					lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
+				}
 			}
 		}
 
@@ -1188,19 +1177,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
 				if !r.Is4() {
 					continue
 				}
-
 				b = r.As4()
 				n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
 			}
-
 		} else if v == cert.Version2 {
 			for _, r := range c.relay.relay {
 				n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
 			}
-
 		} else {
-			//TODO: CERT-V2 don't panic
-			panic("unsupported version")
+			if lhh.l.Level >= logrus.DebugLevel {
+				lhh.l.WithField("version", v).Debug("unsupported protocol version")
+			}
 		}
 	}
 }
@@ -1210,18 +1197,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
 		return
 	}
 
-	lhh.lh.Lock()
-
-	var certVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		certVpnAddr = netip.AddrFrom4(b)
-	} else if n.Details.VpnAddr != nil {
-		certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+	certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
+		}
+		return
 	}
 	relays := n.Details.GetRelays()
 
+	lhh.lh.Lock()
 	am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
 	am.Lock()
 	lhh.lh.Unlock()
@@ -1246,24 +1231,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 		return
 	}
 
-	var detailsVpnAddr netip.Addr
-	useVersion := cert.Version1
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		detailsVpnAddr = netip.AddrFrom4(b)
-		useVersion = cert.Version1
-	} else if n.Details.VpnAddr != nil {
-		detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		useVersion = cert.Version2
-	} else {
+	detailsVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
+			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostUpdateNotification")
 		}
-		return
 	}
 
-	//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
 	//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
 	//Simple check that the host sent this not someone else
 	if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
@@ -1320,8 +1294,16 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
 		return
 	}
 
+	detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
+		}
+		return
+	}
+
 	empty := []byte{0}
-	punch := func(vpnPeer netip.AddrPort) {
+	punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
 		if !vpnPeer.IsValid() {
 			return
 		}
@@ -1333,48 +1315,31 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
 		}()
 
 		if lhh.l.Level >= logrus.DebugLevel {
-			var logVpnAddr netip.Addr
-			if n.Details.OldVpnAddr != 0 {
-				b := [4]byte{}
-				binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-				logVpnAddr = netip.AddrFrom4(b)
-			} else if n.Details.VpnAddr != nil {
-				logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-			}
 			lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
 		}
 	}
 
 	for _, a := range n.Details.V4AddrPorts {
-		punch(protoV4AddrPortToNetAddrPort(a))
+		punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
 	}
 
 	for _, a := range n.Details.V6AddrPorts {
-		punch(protoV6AddrPortToNetAddrPort(a))
+		punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
 	}
 
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// of a double nat or other difficult scenario, this may help establish
 	// a tunnel.
 	if lhh.lh.punchy.GetRespond() {
-		var queryVpnAddr netip.Addr
-		if n.Details.OldVpnAddr != 0 {
-			b := [4]byte{}
-			binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-			queryVpnAddr = netip.AddrFrom4(b)
-		} else if n.Details.VpnAddr != nil {
-			queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		}
-
 		go func() {
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			if lhh.l.Level >= logrus.DebugLevel {
-				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
+				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
 			}
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// managed by a channel.
-			w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}()
 	}
 }
@@ -1453,3 +1418,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
 	}
 	return netip.Addr{}, false
 }
+
+func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
+	if d.OldVpnAddr != 0 {
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
+		detailsVpnAddr := netip.AddrFrom4(b)
+		return detailsVpnAddr, cert.Version1, nil
+	} else if d.VpnAddr != nil {
+		detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
+		return detailsVpnAddr, cert.Version2, nil
+	} else {
+		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
+	}
+}

+ 0 - 1
overlay/tun_linux.go

@@ -293,7 +293,6 @@ func (t *tun) addIPs(link netlink.Link) error {
 
 	//add all new addresses
 	for i := range newAddrs {
-		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
 		//AddrReplace still adds new IPs, but if their properties change it will change them as well
 		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
 			return err

+ 3 - 2
pki.go

@@ -173,7 +173,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 
 	p.cs.Store(newState)
 
-	//TODO: CERT-V2 newState needs a stringer that does json
 	if initial {
 		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
@@ -359,7 +358,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 			return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
 		}
 
-		//TODO: CERT-V2 make sure v2 has v1s address
+		if v1.Networks()[0] != v2.Networks()[0] {
+			return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
+		}
 
 		cs.initiatingVersion = dv
 	}