Browse Source

:gear: dns: add LRU cache

Ettore Di Giacinto 3 years ago
parent
commit
a00abda3f4
3 changed files with 40 additions and 38 deletions
  1. 1 0
      go.mod
  2. 18 5
      pkg/services/dns.go
  3. 21 33
      pkg/services/dns_test.go

+ 1 - 0
go.mod

@@ -10,6 +10,7 @@ require (
 	github.com/google/btree v1.0.1 // indirect
 	github.com/gookit/color v1.5.0 // indirect
 	github.com/hashicorp/errwrap v1.1.0 // indirect
+	github.com/hashicorp/golang-lru v0.5.4 // indirect
 	github.com/ipfs/go-cid v0.1.0 // indirect
 	github.com/ipfs/go-datastore v0.5.1 // indirect
 	github.com/ipfs/go-log v1.0.5

+ 18 - 5
pkg/services/dns.go

@@ -21,6 +21,7 @@ import (
 	"net"
 	"time"
 
+	lru "github.com/hashicorp/golang-lru"
 	"github.com/miekg/dns"
 	"github.com/mudler/edgevpn/pkg/blockchain"
 	"github.com/mudler/edgevpn/pkg/node"
@@ -33,15 +34,18 @@ const (
 
 // DNS returns a network service binding a dns blockchain resolver on listenAddr.
 // Takes an associated name for the addresses in the blockchain
-func DNS(listenAddr string, forwarder bool, forward []string) []node.Option {
+func DNS(listenAddr string, forwarder bool, forward []string, cacheSize int) []node.Option {
 	return []node.Option{
 		node.WithNetworkService(
 			func(ctx context.Context, c node.Config, n *node.Node, b *blockchain.Ledger) error {
 
 				server := &dns.Server{Addr: listenAddr, Net: "udp"}
-
+				cache, err := lru.New(cacheSize)
+				if err != nil {
+					return err
+				}
 				go func() {
-					dns.HandleFunc(".", dnsHandler{ctx, b, forwarder, forward}.handleDNSRequest())
+					dns.HandleFunc(".", dnsHandler{ctx, b, forwarder, forward, cache}.handleDNSRequest())
 					fmt.Println(server.ListenAndServe())
 				}()
 
@@ -67,6 +71,7 @@ type dnsHandler struct {
 	b         *blockchain.Ledger
 	forwarder bool
 	forward   []string
+	cache     *lru.Cache
 }
 
 func (d dnsHandler) parseQuery(m *dns.Msg) {
@@ -107,14 +112,22 @@ func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
 }
 
 func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
+	mess := new(dns.Msg)
+	mess.Question = dnsMessage.Copy().Question
+	if len(mess.Question) > 0 {
+		if v, ok := d.cache.Get(mess.Question[0].String()); ok {
+			q := v.(*dns.Msg)
+			return q, nil
+		}
+	}
+
 	for _, server := range d.forward {
-		mess := new(dns.Msg)
-		mess.Question = dnsMessage.Copy().Question
 		r, err := QueryDNS(d.ctx, mess, server)
 		if err != nil {
 			return nil, err
 		}
 		if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess {
+			d.cache.Add(mess.Question[0].String(), r)
 			return r, err
 		}
 	}

+ 21 - 33
pkg/services/dns_test.go

@@ -46,7 +46,7 @@ var _ = Describe("DNS service", func() {
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
 
-			opts := DNS("127.0.0.1:19192", true, []string{"8.8.8.8:53"})
+			opts := DNS("127.0.0.1:19192", true, []string{"8.8.8.8:53"}, 10)
 			opts = append(opts, node.FromBase64(true, true, token), node.WithStore(&blockchain.MemoryStore{}), l)
 			e := node.New(opts...)
 
@@ -57,43 +57,31 @@ var _ = Describe("DNS service", func() {
 
 			AnnounceDomain(ctx, ll, 15*time.Second, 10*time.Second, "test.foo", "2.2.2.2")
 
-			Eventually(func() string {
-				var s string
-				dnsMessage := new(dns.Msg)
-				dnsMessage.SetQuestion("google.com.", dns.TypeA)
+			searchDomain := func(d string) func() string {
+				return func() string {
+					var s string
+					dnsMessage := new(dns.Msg)
+					dnsMessage.SetQuestion(fmt.Sprintf("%s.", d), dns.TypeA)
 
-				r, err := QueryDNS(ctx, dnsMessage, "127.0.0.1:19192")
-				if r != nil {
-					answers := r.Answer
-					for _, a := range answers {
+					r, err := QueryDNS(ctx, dnsMessage, "127.0.0.1:19192")
+					if r != nil {
+						answers := r.Answer
+						for _, a := range answers {
 
-						s = a.String() + s
+							s = a.String() + s
+						}
 					}
-				}
-				if err != nil {
-					fmt.Println(err)
-				}
-				return s
-			}, 230*time.Second, 1*time.Second).Should(ContainSubstring("A"))
-
-			Eventually(func() string {
-				var s string
-				dnsMessage := new(dns.Msg)
-				dnsMessage.SetQuestion("test.foo.", dns.TypeA)
-
-				r, err := QueryDNS(ctx, dnsMessage, "127.0.0.1:19192")
-				if r != nil {
-					answers := r.Answer
-					for _, a := range answers {
-						s = a.String() + s
+					if err != nil {
+						fmt.Println(err)
 					}
+					return s
 				}
-				if err != nil {
-					fmt.Println(err)
-				}
-				//	r.Answer
-				return s
-			}, 230*time.Second, 1*time.Second).Should(ContainSubstring("2.2.2.2"))
+			}
+
+			Eventually(searchDomain("google.com"), 230*time.Second, 1*time.Second).Should(ContainSubstring("A"))
+			// We hit the same record again, this time it's faster as there is a cache
+			Eventually(searchDomain("google.com"), 1*time.Second, 1*time.Second).Should(ContainSubstring("A"))
+			Eventually(searchDomain("test.foo"), 230*time.Second, 1*time.Second).Should(ContainSubstring("2.2.2.2"))
 		})
 	})
 })