فهرست منبع

:gear: Simplify DNS forwarding, respect user-options

The forward flag was ignored before. Simplifies overall DNS Querying
removing redundant client code. Also fix a bug while the message
response was parsed for forwarding requests
Ettore Di Giacinto 3 سال پیش
والد
کامیت
9ed2716b17
1فایلهای تغییر یافته به همراه33 افزوده شده و 36 حذف شده
  1. 33 36
      pkg/services/dns.go

+ 33 - 36
pkg/services/dns.go

@@ -18,7 +18,6 @@ package services
 import (
 	"context"
 	"fmt"
-	"net"
 	"regexp"
 	"time"
 
@@ -81,7 +80,8 @@ type dnsHandler struct {
 	cache     *lru.Cache
 }
 
-func (d dnsHandler) parseQuery(m *dns.Msg) {
+func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
+	response := m.Copy()
 	if len(m.Question) > 0 {
 		q := m.Question[0]
 		// Resolve the entry to an IP from the blockchain data
@@ -93,17 +93,20 @@ func (d dnsHandler) parseQuery(m *dns.Msg) {
 				if val, exists := res[dns.Type(q.Qtype)]; exists {
 					rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, dns.TypeToString[q.Qtype], val))
 					if err == nil {
-						m.Answer = append(m.Answer, rr)
-						return
+						response.Answer = append(m.Answer, rr)
+						return response
 					}
 				}
 			}
 		}
-		r, err := d.forwardQuery(m)
-		if err == nil {
-			m.Answer = r.Answer
+		if forward {
+			r, err := d.forwardQuery(m)
+			if err == nil {
+				response.Answer = r.Answer
+			}
 		}
 	}
+	return response
 }
 
 func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
@@ -111,33 +114,39 @@ func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
 		m := new(dns.Msg)
 		m.SetReply(r)
 		m.Compress = false
-
+		var resp *dns.Msg
 		switch r.Opcode {
 		case dns.OpcodeQuery:
-			d.parseQuery(m)
+			resp = d.parseQuery(r, d.forwarder)
 		}
 
-		w.WriteMsg(m)
+		w.WriteMsg(resp)
 	}
 }
 
 func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
-	mess := new(dns.Msg)
-	mess.Question = dnsMessage.Copy().Question
-	if len(mess.Question) > 0 {
-		if v, ok := d.cache.Get(mess.Question[0].String()); ok {
+	reqCopy := dnsMessage.Copy()
+	if len(reqCopy.Question) > 0 {
+		if v, ok := d.cache.Get(reqCopy.Question[0].String()); ok {
 			q := v.(*dns.Msg)
 			return q, nil
 		}
 	}
-
 	for _, server := range d.forward {
-		r, err := QueryDNS(d.ctx, mess, server)
+		r, err := QueryDNS(d.ctx, reqCopy, server)
+		if r != nil && len(r.Answer) == 0 && !r.MsgHdr.Truncated {
+			continue
+		}
+
 		if err != nil {
 			return nil, err
 		}
-		if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess {
-			d.cache.Add(mess.Question[0].String(), r)
+
+		if r.Rcode == dns.RcodeSuccess {
+			d.cache.Add(reqCopy.Question[0].String(), r)
+		}
+
+		if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess || err == nil {
 			return r, err
 		}
 	}
@@ -147,22 +156,10 @@ func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
 // QueryDNS queries a dns server with a dns message and return the answer
 // it is blocking.
 func QueryDNS(ctx context.Context, msg *dns.Msg, dnsServer string) (*dns.Msg, error) {
-	c := new(dns.Conn)
-	cc, _ := (&net.Dialer{Timeout: 35 * time.Second}).DialContext(ctx, "udp", dnsServer)
-	c.Conn = cc
-	defer c.Close()
-
-	err := c.SetWriteDeadline(time.Now().Add(30 * time.Second))
-	if err != nil {
-		return nil, err
-	}
-	err = c.WriteMsg(msg)
-	if err != nil {
-		return nil, err
-	}
-	err = c.SetReadDeadline(time.Now().Add(30 * time.Second))
-	if err != nil {
-		return nil, err
-	}
-	return c.ReadMsg()
+	client := &dns.Client{
+		Net:            "udp",
+		Timeout:        30 * time.Second,
+		SingleInflight: true}
+	r, _, err := client.Exchange(msg, dnsServer)
+	return r, err
 }