|
@@ -1,6 +1,7 @@
|
|
package nebula
|
|
package nebula
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "context"
|
|
"fmt"
|
|
"fmt"
|
|
"net"
|
|
"net"
|
|
"net/netip"
|
|
"net/netip"
|
|
@@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
|
|
|
|
|
+func (d *dnsRecords) query(q uint16, data string) netip.Addr {
|
|
data = strings.ToLower(data)
|
|
data = strings.ToLower(data)
|
|
d.RLock()
|
|
d.RLock()
|
|
defer d.RUnlock()
|
|
defer d.RUnlock()
|
|
@@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
|
return netip.Addr{}
|
|
return netip.Addr{}
|
|
}
|
|
}
|
|
|
|
|
|
-func (d *dnsRecords) QueryCert(data string) string {
|
|
|
|
|
|
+func (d *dnsRecords) queryCert(data string) string {
|
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
|
if err != nil {
|
|
if err != nil {
|
|
return ""
|
|
return ""
|
|
@@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|
case dns.TypeA, dns.TypeAAAA:
|
|
case dns.TypeA, dns.TypeAAAA:
|
|
qType := dns.TypeToString[q.Qtype]
|
|
qType := dns.TypeToString[q.Qtype]
|
|
d.l.Debugf("Query for %s %s", qType, q.Name)
|
|
d.l.Debugf("Query for %s %s", qType, q.Name)
|
|
- ip := d.Query(q.Qtype, q.Name)
|
|
|
|
|
|
+ ip := d.query(q.Qtype, q.Name)
|
|
if ip.IsValid() {
|
|
if ip.IsValid() {
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
|
if err == nil {
|
|
if err == nil {
|
|
@@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
d.l.Debugf("Query for TXT %s", q.Name)
|
|
d.l.Debugf("Query for TXT %s", q.Name)
|
|
- ip := d.QueryCert(q.Name)
|
|
|
|
|
|
+ ip := d.queryCert(q.Name)
|
|
if ip != "" {
|
|
if ip != "" {
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
|
if err == nil {
|
|
if err == nil {
|
|
@@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|
w.WriteMsg(m)
|
|
w.WriteMsg(m)
|
|
}
|
|
}
|
|
|
|
|
|
-func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
|
|
|
|
+func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
dnsR = newDnsRecords(l, cs, hostMap)
|
|
dnsR = newDnsRecords(l, cs, hostMap)
|
|
|
|
|
|
// attach request handler func
|
|
// attach request handler func
|
|
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
|
|
|
|
c.RegisterReloadCallback(func(c *config.C) {
|
|
c.RegisterReloadCallback(func(c *config.C) {
|
|
- reloadDns(l, c)
|
|
|
|
|
|
+ reloadDns(ctx, l, c)
|
|
})
|
|
})
|
|
|
|
|
|
return func() {
|
|
return func() {
|
|
- startDns(l, c)
|
|
|
|
|
|
+ startDns(ctx, l, c)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
|
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
|
}
|
|
}
|
|
|
|
|
|
-func startDns(l *logrus.Logger, c *config.C) {
|
|
|
|
|
|
+func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
|
dnsAddr = getDnsServerAddr(c)
|
|
dnsAddr = getDnsServerAddr(c)
|
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
err := dnsServer.ListenAndServe()
|
|
err := dnsServer.ListenAndServe()
|
|
- defer dnsServer.Shutdown()
|
|
|
|
|
|
+ defer dnsServer.ShutdownContext(ctx)
|
|
if err != nil {
|
|
if err != nil {
|
|
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func reloadDns(l *logrus.Logger, c *config.C) {
|
|
|
|
|
|
+func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
|
|
if dnsAddr == getDnsServerAddr(c) {
|
|
if dnsAddr == getDnsServerAddr(c) {
|
|
l.Debug("No DNS server config change detected")
|
|
l.Debug("No DNS server config change detected")
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
l.Debug("Restarting DNS server")
|
|
l.Debug("Restarting DNS server")
|
|
- dnsServer.Shutdown()
|
|
|
|
- go startDns(l, c)
|
|
|
|
|
|
+ dnsServer.ShutdownContext(ctx)
|
|
|
|
+ go startDns(ctx, l, c)
|
|
}
|
|
}
|