Browse Source

Support ANY queries

Ask Bjørn Hansen 13 years ago
parent
commit
586b42ba92
2 changed files with 26 additions and 4 deletions
  1. 21 4
      picker.go
  2. 5 0
      types.go

+ 21 - 4
picker.go

@@ -1,15 +1,32 @@
 package main
 package main
 
 
 import (
 import (
+	"github.com/miekg/dns"
 	"math/rand"
 	"math/rand"
 )
 )
 
 
-func (label *Label) Picker(dnsType uint16, max int) Records {
+func (label *Label) Picker(qtype uint16, max int) Records {
 
 
-	if label_rr := label.Records[dnsType]; label_rr != nil {
+	if qtype == dns.TypeANY {
+		result := make([]Record, 0)
+		for rtype, _ := range label.Records {
+
+			rtype_records := label.Picker(rtype, max)
+
+			tmp_result := make(Records, len(result)+len(rtype_records))
+
+			copy(tmp_result, result)
+			copy(tmp_result[len(result):], rtype_records)
+			result = tmp_result
+		}
+
+		return result
+	}
+
+	if label_rr := label.Records[qtype]; label_rr != nil {
 
 
 		// not "balanced", just return all
 		// not "balanced", just return all
-		if label.Weight[dnsType] == 0 {
+		if label.Weight[qtype] == 0 {
 			return label_rr
 			return label_rr
 		}
 		}
 
 
@@ -21,7 +38,7 @@ func (label *Label) Picker(dnsType uint16, max int) Records {
 		servers := make([]Record, len(label_rr))
 		servers := make([]Record, len(label_rr))
 		copy(servers, label_rr)
 		copy(servers, label_rr)
 		result := make([]Record, max)
 		result := make([]Record, max)
-		sum := label.Weight[dnsType]
+		sum := label.Weight[qtype]
 
 
 		for si := 0; si < max; si++ {
 		for si := 0; si < max; si++ {
 			n := rand.Intn(sum + 1)
 			n := rand.Intn(sum + 1)

+ 5 - 0
types.go

@@ -52,6 +52,11 @@ func (z *Zone) SoaRR() dns.RR {
 
 
 func (z *Zone) findLabels(s, cc string, qtype uint16) *Label {
 func (z *Zone) findLabels(s, cc string, qtype uint16) *Label {
 
 
+	if qtype == dns.TypeANY {
+		// short-circuit mostly to avoid subtle bugs later
+		return z.Labels[s]
+	}
+
 	selectors := []string{}
 	selectors := []string{}
 
 
 	if len(cc) > 0 {
 	if len(cc) > 0 {