Browse Source

:gear: Add logger to DNS and fixup msg response

Ettore Di Giacinto 3 years ago
parent
commit
2095720230
4 changed files with 19 additions and 14 deletions
  1. 1 1
      cmd/dns.go
  2. 1 1
      cmd/main.go
  3. 15 10
      pkg/services/dns.go
  4. 2 2
      pkg/services/dns_test.go

+ 1 - 1
cmd/dns.go

@@ -61,7 +61,7 @@ func DNS() cli.Command {
 			dns := c.String("listen")
 			dns := c.String("listen")
 			// Adds DNS Server
 			// Adds DNS Server
 			o = append(o,
 			o = append(o,
-				services.DNS(dns,
+				services.DNS(ll, dns,
 					c.Bool("dns-forwarder"),
 					c.Bool("dns-forwarder"),
 					c.StringSlice("dns-forward-server"),
 					c.StringSlice("dns-forward-server"),
 					c.Int("dns-cache-size"),
 					c.Int("dns-cache-size"),

+ 1 - 1
cmd/main.go

@@ -201,7 +201,7 @@ func Main() func(c *cli.Context) error {
 		if dns != "" {
 		if dns != "" {
 			// Adds DNS Server
 			// Adds DNS Server
 			o = append(o,
 			o = append(o,
-				services.DNS(dns,
+				services.DNS(ll, dns,
 					c.Bool("dns-forwarder"),
 					c.Bool("dns-forwarder"),
 					c.StringSlice("dns-forward-server"),
 					c.StringSlice("dns-forward-server"),
 					c.Int("dns-cache-size"),
 					c.Int("dns-cache-size"),

+ 15 - 10
pkg/services/dns.go

@@ -22,6 +22,7 @@ import (
 	"time"
 	"time"
 
 
 	lru "github.com/hashicorp/golang-lru"
 	lru "github.com/hashicorp/golang-lru"
+	"github.com/ipfs/go-log"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/mudler/edgevpn/pkg/blockchain"
 	"github.com/mudler/edgevpn/pkg/blockchain"
 	"github.com/mudler/edgevpn/pkg/node"
 	"github.com/mudler/edgevpn/pkg/node"
@@ -30,7 +31,7 @@ import (
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
 )
 )
 
 
-func DNSNetworkService(listenAddr string, forwarder bool, forward []string, cacheSize int) node.NetworkService {
+func DNSNetworkService(ll log.StandardLogger, listenAddr string, forwarder bool, forward []string, cacheSize int) node.NetworkService {
 	return func(ctx context.Context, c node.Config, n *node.Node, b *blockchain.Ledger) error {
 	return func(ctx context.Context, c node.Config, n *node.Node, b *blockchain.Ledger) error {
 		server := &dns.Server{Addr: listenAddr, Net: "udp"}
 		server := &dns.Server{Addr: listenAddr, Net: "udp"}
 		cache, err := lru.New(cacheSize)
 		cache, err := lru.New(cacheSize)
@@ -38,7 +39,7 @@ func DNSNetworkService(listenAddr string, forwarder bool, forward []string, cach
 			return err
 			return err
 		}
 		}
 		go func() {
 		go func() {
-			dns.HandleFunc(".", dnsHandler{ctx, b, forwarder, forward, cache}.handleDNSRequest())
+			dns.HandleFunc(".", dnsHandler{ctx, b, forwarder, forward, cache, ll}.handleDNSRequest())
 			fmt.Println(server.ListenAndServe())
 			fmt.Println(server.ListenAndServe())
 		}()
 		}()
 
 
@@ -53,9 +54,9 @@ func DNSNetworkService(listenAddr string, forwarder bool, forward []string, cach
 
 
 // DNS returns a network service binding a dns blockchain resolver on listenAddr.
 // DNS returns a network service binding a dns blockchain resolver on listenAddr.
 // Takes an associated name for the addresses in the blockchain
 // Takes an associated name for the addresses in the blockchain
-func DNS(listenAddr string, forwarder bool, forward []string, cacheSize int) []node.Option {
+func DNS(ll log.StandardLogger, listenAddr string, forwarder bool, forward []string, cacheSize int) []node.Option {
 	return []node.Option{
 	return []node.Option{
-		node.WithNetworkService(DNSNetworkService(listenAddr, forwarder, forward, cacheSize)),
+		node.WithNetworkService(DNSNetworkService(ll, listenAddr, forwarder, forward, cacheSize)),
 	}
 	}
 }
 }
 
 
@@ -78,10 +79,12 @@ type dnsHandler struct {
 	forwarder bool
 	forwarder bool
 	forward   []string
 	forward   []string
 	cache     *lru.Cache
 	cache     *lru.Cache
+	ll        log.StandardLogger
 }
 }
 
 
 func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
 func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
 	response := m.Copy()
 	response := m.Copy()
+	d.ll.Debug("Received DNS request", m)
 	if len(m.Question) > 0 {
 	if len(m.Question) > 0 {
 		q := m.Question[0]
 		q := m.Question[0]
 		// Resolve the entry to an IP from the blockchain data
 		// Resolve the entry to an IP from the blockchain data
@@ -94,16 +97,19 @@ func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
 					rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, dns.TypeToString[q.Qtype], val))
 					rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, dns.TypeToString[q.Qtype], val))
 					if err == nil {
 					if err == nil {
 						response.Answer = append(m.Answer, rr)
 						response.Answer = append(m.Answer, rr)
+						d.ll.Debug("Response from blockchain", response)
 						return response
 						return response
 					}
 					}
 				}
 				}
 			}
 			}
 		}
 		}
 		if forward {
 		if forward {
+			d.ll.Debug("Forwarding DNS request", m)
 			r, err := d.forwardQuery(m)
 			r, err := d.forwardQuery(m)
 			if err == nil {
 			if err == nil {
 				response.Answer = r.Answer
 				response.Answer = r.Answer
 			}
 			}
+			d.ll.Debug("Response from fw server", r)
 		}
 		}
 	}
 	}
 	return response
 	return response
@@ -111,15 +117,13 @@ func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
 
 
 func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
 func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
 	return func(w dns.ResponseWriter, r *dns.Msg) {
 	return func(w dns.ResponseWriter, r *dns.Msg) {
-		m := new(dns.Msg)
-		m.SetReply(r)
-		m.Compress = false
 		var resp *dns.Msg
 		var resp *dns.Msg
 		switch r.Opcode {
 		switch r.Opcode {
 		case dns.OpcodeQuery:
 		case dns.OpcodeQuery:
 			resp = d.parseQuery(r, d.forwarder)
 			resp = d.parseQuery(r, d.forwarder)
 		}
 		}
-
+		resp.SetReply(r)
+		resp.Compress = false
 		w.WriteMsg(resp)
 		w.WriteMsg(resp)
 	}
 	}
 }
 }
@@ -129,6 +133,7 @@ func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
 	if len(reqCopy.Question) > 0 {
 	if len(reqCopy.Question) > 0 {
 		if v, ok := d.cache.Get(reqCopy.Question[0].String()); ok {
 		if v, ok := d.cache.Get(reqCopy.Question[0].String()); ok {
 			q := v.(*dns.Msg)
 			q := v.(*dns.Msg)
+			q.Id = reqCopy.Id
 			return q, nil
 			return q, nil
 		}
 		}
 	}
 	}
@@ -139,14 +144,14 @@ func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
 		}
 		}
 
 
 		if err != nil {
 		if err != nil {
-			return nil, err
+			continue
 		}
 		}
 
 
 		if r.Rcode == dns.RcodeSuccess {
 		if r.Rcode == dns.RcodeSuccess {
 			d.cache.Add(reqCopy.Question[0].String(), r)
 			d.cache.Add(reqCopy.Question[0].String(), r)
 		}
 		}
 
 
-		if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess || err == nil {
+		if r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess || err == nil {
 			return r, err
 			return r, err
 		}
 		}
 	}
 	}

+ 2 - 2
pkg/services/dns_test.go

@@ -47,9 +47,9 @@ var _ = Describe("DNS service", func() {
 			ctx, cancel := context.WithCancel(context.Background())
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
 			defer cancel()
 
 
-			opts := DNS("127.0.0.1:19192", true, []string{"8.8.8.8:53"}, 10)
+			opts := DNS(logg, "127.0.0.1:19192", true, []string{"8.8.8.8:53"}, 10)
 			opts = append(opts, node.FromBase64(true, true, token), node.WithStore(&blockchain.MemoryStore{}), l)
 			opts = append(opts, node.FromBase64(true, true, token), node.WithStore(&blockchain.MemoryStore{}), l)
-			e, _  := node.New(opts...)
+			e, _ := node.New(opts...)
 
 
 			e.Start(ctx)
 			e.Start(ctx)
 			e2.Start(ctx)
 			e2.Start(ctx)