Pārlūkot izejas kodu

add a little context to dns

Jay Wren 3 mēneši atpakaļ
vecāks
revīzija
5ceac2b078
2 mainītis faili ar 14 papildinājumiem un 13 dzēšanām
  1. 13 12
      dns_server.go
  2. 1 1
      main.go

+ 13 - 12
dns_server.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"net/netip"
@@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
 	}
 }
 
-func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+func (d *dnsRecords) query(q uint16, data string) netip.Addr {
 	data = strings.ToLower(data)
 	d.RLock()
 	defer d.RUnlock()
@@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
 	return netip.Addr{}
 }
 
-func (d *dnsRecords) QueryCert(data string) string {
+func (d *dnsRecords) queryCert(data string) string {
 	ip, err := netip.ParseAddr(data[:len(data)-1])
 	if err != nil {
 		return ""
@@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 		case dns.TypeA, dns.TypeAAAA:
 			qType := dns.TypeToString[q.Qtype]
 			d.l.Debugf("Query for %s %s", qType, q.Name)
-			ip := d.Query(q.Qtype, q.Name)
+			ip := d.query(q.Qtype, q.Name)
 			if ip.IsValid() {
 				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
@@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 				return
 			}
 			d.l.Debugf("Query for TXT %s", q.Name)
-			ip := d.QueryCert(q.Name)
+			ip := d.queryCert(q.Name)
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
@@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
 	dnsR = newDnsRecords(l, cs, hostMap)
 
 	// attach request handler func
 	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 	c.RegisterReloadCallback(func(c *config.C) {
-		reloadDns(l, c)
+		reloadDns(ctx, l, c)
 	})
 
 	return func() {
-		startDns(l, c)
+		startDns(ctx, l, c)
 	}
 }
 
@@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
 	return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
 }
 
-func startDns(l *logrus.Logger, c *config.C) {
+func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
 	dnsAddr = getDnsServerAddr(c)
 	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
 	l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
 	err := dnsServer.ListenAndServe()
-	defer dnsServer.Shutdown()
+	defer dnsServer.ShutdownContext(ctx)
 	if err != nil {
 		l.Errorf("Failed to start server: %s\n ", err.Error())
 	}
 }
 
-func reloadDns(l *logrus.Logger, c *config.C) {
+func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
 	if dnsAddr == getDnsServerAddr(c) {
 		l.Debug("No DNS server config change detected")
 		return
 	}
 
 	l.Debug("Restarting DNS server")
-	dnsServer.Shutdown()
-	go startDns(l, c)
+	dnsServer.ShutdownContext(ctx)
+	go startDns(ctx, l, c)
 }

+ 1 - 1
main.go

@@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
+		dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
 	}
 
 	return &Control{