浏览代码

Return NXDOMAIN for unknown labels

Ask Bjørn Hansen 13 年之前
父节点
当前提交
f1168cfb6c
共有 2 个文件被更改,包括 50 次插入21 次删除
  1. 15 3
      geodns.go
  2. 35 18
      serve.go

+ 15 - 3
geodns.go

@@ -8,6 +8,7 @@ import (
 	"io/ioutil"
 	"net"
 	"strconv"
+	//"strings"
 )
 
 type Options struct {
@@ -30,8 +31,9 @@ type Label struct {
 type labels map[string]*Label
 
 type Zone struct {
-	Origin string
-	Labels labels
+	Origin    string
+	Labels    labels
+	LenLabels int
 }
 
 var (
@@ -40,6 +42,13 @@ var (
 	flagrun = flag.Bool("run", false, "run server")
 )
 
+func (z *Zone) findLabels(s string) *Label {
+	if label, ok := z.Labels[s]; ok {
+		return label
+	}
+	return nil
+}
+
 func main() {
 
 	flag.Usage = func() {
@@ -50,7 +59,10 @@ func main() {
 	Zone := new(Zone)
 	Zone.Labels = make(labels)
 
-	Zone.Origin = "ntppool.org" // TODO, read multiple files etc
+	// BUG(ask) Doesn't read multiple .json zone files yet
+	Zone.Origin = "ntppool.org"
+	Zone.LenLabels = dns.LenLabels(Zone.Origin)
+
 	Options := new(Options)
 
 	//var objmap map[string]json.RawMessage

+ 35 - 18
serve.go

@@ -6,13 +6,23 @@ import (
 	"log"
 	"os"
 	"os/signal"
+	"strings"
 )
 
+func getQuestionName(z *Zone, req *dns.Msg) string {
+	lx := dns.SplitLabels(req.Question[0].Name)
+	ql := lx[0 : len(lx)-z.LenLabels-1]
+	//fmt.Println("LX:", ql, lx, z.LenLabels)
+	return strings.Join(ql, ".")
+}
+
 func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone, opt *Options) {
 	logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
 
 	fmt.Println("Got request", req)
 
+	label := getQuestionName(z, req)
+
 	raddr := w.RemoteAddr()
 
 	gi := setupGeoIP()
@@ -24,25 +34,32 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone, opt *Options) {
 	m.MsgHdr.Authoritative = true
 
 	// TODO: Function to find appropriate label with records
-	if region, ok := z.Labels[""]; ok {
-		fmt.Println("REG", region)
-		if region_rr := region.Records[req.Question[0].Qtype]; region_rr != nil {
-			//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
-			max := len(region_rr)
-			if max > 4 {
-				max = 4
-			}
-			servers := region_rr[0:max]
-			var rrs []dns.RR
-			for _, record := range servers {
-				rr := record.RR
-				fmt.Println("RR", rr)
-				rr.Header().Name = req.Question[0].Name
-				fmt.Println(rr)
-				rrs = append(rrs, rr)
-			}
-			m.Answer = rrs
+	labels := z.findLabels(label)
+	if labels == nil {
+		// return NXDOMAIN
+		m.SetRcode(req, dns.RcodeNameError)
+		ednsFromRequest(req, m)
+		w.Write(m)
+		return
+	}
+
+	//fmt.Println("REG", region)
+	if region_rr := labels.Records[req.Question[0].Qtype]; region_rr != nil {
+		//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
+		max := len(region_rr)
+		if max > 4 {
+			max = 4
+		}
+		servers := region_rr[0:max]
+		var rrs []dns.RR
+		for _, record := range servers {
+			rr := record.RR
+			fmt.Println("RR", rr)
+			rr.Header().Name = req.Question[0].Name
+			fmt.Println(rr)
+			rrs = append(rrs, rr)
 		}
+		m.Answer = rrs
 	}
 
 	ednsFromRequest(req, m)