|
@@ -18,7 +18,6 @@ package services
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
- "net"
|
|
|
"regexp"
|
|
|
"time"
|
|
|
|
|
@@ -81,7 +80,8 @@ type dnsHandler struct {
|
|
|
cache *lru.Cache
|
|
|
}
|
|
|
|
|
|
-func (d dnsHandler) parseQuery(m *dns.Msg) {
|
|
|
+func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
|
|
|
+ response := m.Copy()
|
|
|
if len(m.Question) > 0 {
|
|
|
q := m.Question[0]
|
|
|
// Resolve the entry to an IP from the blockchain data
|
|
@@ -93,17 +93,20 @@ func (d dnsHandler) parseQuery(m *dns.Msg) {
|
|
|
if val, exists := res[dns.Type(q.Qtype)]; exists {
|
|
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, dns.TypeToString[q.Qtype], val))
|
|
|
if err == nil {
|
|
|
- m.Answer = append(m.Answer, rr)
|
|
|
- return
|
|
|
+ response.Answer = append(m.Answer, rr)
|
|
|
+ return response
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- r, err := d.forwardQuery(m)
|
|
|
- if err == nil {
|
|
|
- m.Answer = r.Answer
|
|
|
+ if forward {
|
|
|
+ r, err := d.forwardQuery(m)
|
|
|
+ if err == nil {
|
|
|
+ response.Answer = r.Answer
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
+ return response
|
|
|
}
|
|
|
|
|
|
func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
|
|
@@ -111,33 +114,39 @@ func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
|
|
|
m := new(dns.Msg)
|
|
|
m.SetReply(r)
|
|
|
m.Compress = false
|
|
|
-
|
|
|
+ var resp *dns.Msg
|
|
|
switch r.Opcode {
|
|
|
case dns.OpcodeQuery:
|
|
|
- d.parseQuery(m)
|
|
|
+ resp = d.parseQuery(r, d.forwarder)
|
|
|
}
|
|
|
|
|
|
- w.WriteMsg(m)
|
|
|
+ w.WriteMsg(resp)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
|
|
|
- mess := new(dns.Msg)
|
|
|
- mess.Question = dnsMessage.Copy().Question
|
|
|
- if len(mess.Question) > 0 {
|
|
|
- if v, ok := d.cache.Get(mess.Question[0].String()); ok {
|
|
|
+ reqCopy := dnsMessage.Copy()
|
|
|
+ if len(reqCopy.Question) > 0 {
|
|
|
+ if v, ok := d.cache.Get(reqCopy.Question[0].String()); ok {
|
|
|
q := v.(*dns.Msg)
|
|
|
return q, nil
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
for _, server := range d.forward {
|
|
|
- r, err := QueryDNS(d.ctx, mess, server)
|
|
|
+ r, err := QueryDNS(d.ctx, reqCopy, server)
|
|
|
+ if r != nil && len(r.Answer) == 0 && !r.MsgHdr.Truncated {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess {
|
|
|
- d.cache.Add(mess.Question[0].String(), r)
|
|
|
+
|
|
|
+ if r.Rcode == dns.RcodeSuccess {
|
|
|
+ d.cache.Add(reqCopy.Question[0].String(), r)
|
|
|
+ }
|
|
|
+
|
|
|
+ if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess || err == nil {
|
|
|
return r, err
|
|
|
}
|
|
|
}
|
|
@@ -147,22 +156,10 @@ func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
|
|
|
// QueryDNS queries a dns server with a dns message and return the answer
|
|
|
// it is blocking.
|
|
|
func QueryDNS(ctx context.Context, msg *dns.Msg, dnsServer string) (*dns.Msg, error) {
|
|
|
- c := new(dns.Conn)
|
|
|
- cc, _ := (&net.Dialer{Timeout: 35 * time.Second}).DialContext(ctx, "udp", dnsServer)
|
|
|
- c.Conn = cc
|
|
|
- defer c.Close()
|
|
|
-
|
|
|
- err := c.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- err = c.WriteMsg(msg)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- err = c.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return c.ReadMsg()
|
|
|
+ client := &dns.Client{
|
|
|
+ Net: "udp",
|
|
|
+ Timeout: 30 * time.Second,
|
|
|
+ SingleInflight: true}
|
|
|
+ r, _, err := client.Exchange(msg, dnsServer)
|
|
|
+ return r, err
|
|
|
}
|