Browse Source

adding checks for DNS to ensure connectivity

afeiszli 3 years ago
parent
commit
cbf709166c
2 changed files with 53 additions and 3 deletions
  1. 18 3
      netclient/functions/daemon.go
  2. 35 0
      netclient/local/dns.go

+ 18 - 3
netclient/functions/daemon.go

@@ -280,13 +280,16 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
 			}
 			}
 			if newNode.DNSOn == "yes" {
 			if newNode.DNSOn == "yes" {
 				ncutils.Log("setting up DNS")
 				ncutils.Log("setting up DNS")
-				if err = local.UpdateDNS(cfg.Node.Interface, cfg.Network, cfg.Server.CoreDNSAddr); err != nil {
-					ncutils.Log("error applying dns" + err.Error())
+				for _, server := range cfg.Node.NetworkSettings.DefaultServerAddrs {
+					if server.IsLeader {
+						go setDNS(cfg.Node.Interface, cfg.Network, server.Address)
+						break
+					}
 				}
 				}
 			}
 			}
 		}
 		}
 		//deal with DNS
 		//deal with DNS
-		if newNode.DNSOn != "yes" && shouldDNSChange {
+		if newNode.DNSOn != "yes" && shouldDNSChange && cfg.Node.Interface != "" {
 			ncutils.Log("settng DNS off")
 			ncutils.Log("settng DNS off")
 			_, err := ncutils.RunCmd("/usr/bin/resolvectl revert "+cfg.Node.Interface, true)
 			_, err := ncutils.RunCmd("/usr/bin/resolvectl revert "+cfg.Node.Interface, true)
 			if err != nil {
 			if err != nil {
@@ -295,6 +298,18 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
 		}
 		}
 	}()
 	}()
 }
 }
+func setDNS(iface, network, address string) {
+	var reachable bool
+	for counter := 0; !reachable && counter < 5; counter++ {
+		reachable = local.IsDNSReachable(address)
+		time.Sleep(time.Second << 1)
+	}
+	if !reachable {
+		ncutils.Log("not setting dns, server unreachable: " + address)
+	} else if err := local.UpdateDNS(iface, network, address); err != nil {
+		ncutils.Log("error applying dns" + err.Error())
+	}
+}
 
 
 // UpdatePeers -- mqtt message handler for peers/<Network>/<NodeID> topic
 // UpdatePeers -- mqtt message handler for peers/<Network>/<NodeID> topic
 func UpdatePeers(client mqtt.Client, msg mqtt.Message) {
 func UpdatePeers(client mqtt.Client, msg mqtt.Message) {

+ 35 - 0
netclient/local/dns.go

@@ -1,8 +1,11 @@
 package local
 package local
 
 
 import (
 import (
+	"fmt"
+	"net"
 	"os"
 	"os"
 	"strings"
 	"strings"
+	"time"
 
 
 	//"github.com/davecgh/go-spew/spew"
 	//"github.com/davecgh/go-spew/spew"
 	"log"
 	"log"
@@ -11,6 +14,8 @@ import (
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 )
 )
 
 
+const DNS_UNREACHABLE_ERROR = "nameserver unreachable"
+
 // SetDNS - sets the DNS of a local machine
 // SetDNS - sets the DNS of a local machine
 func SetDNS(nameserver string) error {
 func SetDNS(nameserver string) error {
 	bytes, err := os.ReadFile("/etc/resolv.conf")
 	bytes, err := os.ReadFile("/etc/resolv.conf")
@@ -35,9 +40,21 @@ func SetDNS(nameserver string) error {
 
 
 // UpdateDNS - updates local DNS of client
 // UpdateDNS - updates local DNS of client
 func UpdateDNS(ifacename string, network string, nameserver string) error {
 func UpdateDNS(ifacename string, network string, nameserver string) error {
+	if ifacename == "" {
+		return fmt.Errorf("cannot set dns: interface name is blank")
+	}
+	if network == "" {
+		return fmt.Errorf("cannot set dns: network name is blank")
+	}
+	if nameserver == "" {
+		return fmt.Errorf("cannot set dns: nameserver is blank")
+	}
 	if ncutils.IsWindows() {
 	if ncutils.IsWindows() {
 		return nil
 		return nil
 	}
 	}
+	if !IsDNSReachable(nameserver) {
+		return fmt.Errorf(DNS_UNREACHABLE_ERROR + " : " + nameserver + ":53")
+	}
 	_, err := exec.LookPath("resolvectl")
 	_, err := exec.LookPath("resolvectl")
 	if err != nil {
 	if err != nil {
 		log.Println(err)
 		log.Println(err)
@@ -60,3 +77,21 @@ func UpdateDNS(ifacename string, network string, nameserver string) error {
 	}
 	}
 	return err
 	return err
 }
 }
+
+func IsDNSReachable(nameserver string) bool {
+	port := "53"
+	protocols := [2]string{"tcp", "udp"}
+	for _, proto := range protocols {
+		timeout := time.Second
+		conn, err := net.DialTimeout(proto, net.JoinHostPort(nameserver, port), timeout)
+		if err != nil {
+			return false
+		}
+		if conn != nil {
+			defer conn.Close()
+		} else {
+			return false
+		}
+	}
+	return true
+}