Răsfoiți Sursa

Resolve some todos (#1274)

Nate Brown 8 luni în urmă
părinte
comite
9d310e72c2
7 a modificat fișierele cu 86 adăugiri și 101 ștergeri
  1. 12 11
      connection_manager.go
  2. 5 9
      control.go
  3. 0 1
      firewall.go
  4. 1 1
      interface.go
  5. 3 2
      overlay/tun_windows.go
  6. 0 11
      pki.go
  7. 65 66
      ssh.go

+ 12 - 11
connection_manager.go

@@ -426,17 +426,17 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 
-	//TODO: current.vpnIp should become an array of vpnIps
+	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
+	// vpn addr is static across all tunnels for this host pair so lets
+	// use that to determine if we should consider swapping.
 	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
-		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
-		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
-		// The remotes vpn ip is lower than mine. I will not flip.
+		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	//TODO: we should favor v2 over v1 certificates if configured to send them
-
-	crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
+	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
@@ -495,13 +495,14 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
+	cs := n.intf.pki.getCertState()
+	curCrt := hostinfo.ConnectionState.myCert
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
 
-	//TODO: we should favor v2 over v1 certificates if configured to send them
-
 	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")

+ 5 - 9
control.go

@@ -133,9 +133,9 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
 	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
 	if found {
-		//TODO: we might have 2 certs....
-		//TODO: this should return our latest version cert
-		return c.f.pki.getDefaultCertificate().Copy()
+		// Only returning the default certificate since its impossible
+		// for any other host but ourselves to have more than 1
+		return c.f.pki.getCertState().GetDefaultCertificate().Copy()
 	}
 	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
@@ -228,13 +228,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 // the int returned is a count of tunnels closed
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	lighthouses := c.f.lightHouse.GetLighthouses()
-
 	shutdown := func(h *HostInfo) {
-		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
-				return
-			}
+		if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
+			return
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)

+ 0 - 1
firewall.go

@@ -23,7 +23,6 @@ import (
 )
 
 type FirewallInterface interface {
-	//TODO: name these better addr, localAddr. Are they vpnAddrs?
 	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 

+ 1 - 1
interface.go

@@ -419,7 +419,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
-			certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
+			certExpirationGauge.Update(int64(f.pki.getCertState().GetDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
 			//TODO: we should also report the default certificate version
 		}
 	}

+ 3 - 2
overlay/tun_windows.go

@@ -239,11 +239,12 @@ func (t *winTun) Close() error {
 	luid := winipcfg.LUID(t.tun.LUID())
 	_ = luid.FlushRoutes(windows.AF_INET)
 	_ = luid.FlushIPAddresses(windows.AF_INET)
-	/* We don't support IPV6 yet
+
 	_ = luid.FlushRoutes(windows.AF_INET6)
 	_ = luid.FlushIPAddresses(windows.AF_INET6)
-	*/
+
 	_ = luid.FlushDNS(windows.AF_INET)
+	_ = luid.FlushDNS(windows.AF_INET6)
 
 	return t.tun.Close()
 }

+ 0 - 11
pki.go

@@ -70,16 +70,6 @@ func (p *PKI) getCertState() *CertState {
 	return p.cs.Load()
 }
 
-// TODO: We should remove this
-func (p *PKI) getDefaultCertificate() cert.Certificate {
-	return p.cs.Load().GetDefaultCertificate()
-}
-
-// TODO: We should remove this
-func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
-	return p.cs.Load().getCertificate(v)
-}
-
 func (p *PKI) reload(c *config.C, initial bool) error {
 	err := p.reloadCerts(c, initial)
 	if err != nil {
@@ -300,7 +290,6 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
 		// Load the certificate
 		crt, rawCert, err = loadCertificate(rawCert)
 		if err != nil {
-			//TODO: check error
 			return nil, err
 		}
 

+ 65 - 66
ssh.go

@@ -320,7 +320,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-cert",
-		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
+		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintCertFlags{}
@@ -336,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-tunnel",
-		ShortDescription: "Prints json details about a tunnel for the provided vpn ip",
+		ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintTunnelFlags{}
@@ -364,7 +364,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "change-remote",
-		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip",
+		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshChangeRemoteFlags{}
@@ -378,7 +378,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "close-tunnel",
-		ShortDescription: "Closes a tunnel for the provided vpn ip",
+		ShortDescription: "Closes a tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshCloseTunnelFlags{}
@@ -392,7 +392,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "create-tunnel",
-		ShortDescription: "Creates a tunnel for the provided vpn ip and address",
+		ShortDescription: "Creates a tunnel for the provided vpn address",
 		Help:             "The lighthouses will be queried for real addresses but you can provide one as well.",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
@@ -407,8 +407,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "query-lighthouse",
-		ShortDescription: "Query the lighthouses for the provided vpn ip",
-		Help:             "This command is asynchronous. Only currently known udp ips will be printed.",
+		ShortDescription: "Query the lighthouses for the provided vpn address",
+		Help:             "This command is asynchronous. Only currently known udp addresses will be printed.",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
 			return sshQueryLighthouse(f, fs, a, w)
 		},
@@ -465,8 +465,8 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	}
 
 	type lighthouseInfo struct {
-		VpnIp string    `json:"vpnIp"`
-		Addrs *CacheMap `json:"addrs"`
+		VpnAddr string    `json:"vpnAddr"`
+		Addrs   *CacheMap `json:"addrs"`
 	}
 
 	lightHouse.RLock()
@@ -474,15 +474,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	x := 0
 	for k, v := range lightHouse.addrMap {
 		addrMap[x] = lighthouseInfo{
-			VpnIp: k.String(),
-			Addrs: v.CopyCache(),
+			VpnAddr: k.String(),
+			Addrs:   v.CopyCache(),
 		}
 		x++
 	}
 	lightHouse.RUnlock()
 
 	sort.Slice(addrMap, func(i, j int) bool {
-		return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
+		return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -503,7 +503,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			if err != nil {
 				return err
 			}
-			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b)))
 			if err != nil {
 				return err
 			}
@@ -541,20 +541,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter
 
 func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
 	var cm *CacheMap
-	rl := ifce.lightHouse.Query(vpnIp)
+	rl := ifce.lightHouse.Query(vpnAddr)
 	if rl != nil {
 		cm = rl.CopyCache()
 	}
@@ -569,21 +569,21 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
 	}
 
 	if !flags.LocalOnly {
@@ -610,24 +610,24 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 
-	hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
+	hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
@@ -640,7 +640,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
+	hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
 	if addr.IsValid() {
 		hostInfo.SetRemote(addr)
 	}
@@ -656,7 +656,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
 	if flags.Address == "" {
@@ -668,18 +668,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("Address could not be parsed")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
 	}
 
 	hostInfo.SetRemote(addr)
@@ -785,21 +785,20 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 		return nil
 	}
 
-	//TODO: This should return both certs
-	cert := ifce.pki.getDefaultCertificate()
+	cert := ifce.pki.getCertState().GetDefaultCertificate()
 	if len(a) > 0 {
-		vpnIp, err := netip.ParseAddr(a[0])
+		vpnAddr, err := netip.ParseAddr(a[0])
 		if err != nil {
-			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+			return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 		}
 
-		if !vpnIp.IsValid() {
-			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		if !vpnAddr.IsValid() {
+			return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 		}
 
-		hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
+		hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 		if hostInfo == nil {
-			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
 		}
 
 		cert = hostInfo.GetCert().Certificate
@@ -857,15 +856,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		Error          error
 		Type           string
 		State          string
-		PeerIp         netip.Addr
+		PeerAddr       netip.Addr
 		LocalIndex     uint32
 		RemoteIndex    uint32
 		RelayedThrough []netip.Addr
 	}
 
 	type RelayOutput struct {
-		NebulaIp    netip.Addr
-		RelayForIps []RelayFor
+		NebulaAddr    netip.Addr
+		RelayForAddrs []RelayFor
 	}
 
 	type CmdOutput struct {
@@ -881,16 +880,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	}
 
 	for k, v := range relays {
-		ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
+		ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]}
 		co.Relays = append(co.Relays, &ro)
 		relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
 		if relayHI == nil {
-			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
+			ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")})
 			continue
 		}
-		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
+		for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() {
 			rf := RelayFor{Error: nil}
-			r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
+			r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr)
 			if ok {
 				t := ""
 				switch r.Type {
@@ -914,19 +913,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 
 				rf.LocalIndex = r.LocalIndex
 				rf.RemoteIndex = r.RemoteIndex
-				rf.PeerIp = r.PeerAddr
+				rf.PeerAddr = r.PeerAddr
 				rf.Type = t
 				rf.State = s
 				if rf.LocalIndex != k {
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 				}
 			}
-			relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
+			relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr)
 			if relayedHI != nil {
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 			}
 
-			ro.RelayForIps = append(ro.RelayForIps, rf)
+			ro.RelayForAddrs = append(ro.RelayForAddrs, rf)
 		}
 	}
 	err := enc.Encode(co)
@@ -944,21 +943,21 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
 	}
 
 	enc := json.NewEncoder(w.GetWriter())