Pārlūkot izejas kodu

Serve all .json zones from the dns/ directory

Ask Bjørn Hansen 13 gadi atpakaļ
vecāks
revīzija
98d61de8a0
2 mainītis faili ar 58 papildinājumiem un 19 dzēšanām
  1. 37 11
      geodns.go
  2. 21 8
      serve.go

+ 37 - 11
geodns.go

@@ -7,8 +7,9 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net"
+	"path"
 	"strconv"
-	//"strings"
+	"strings"
 )
 
 type Options struct {
@@ -30,6 +31,8 @@ type Label struct {
 
 type labels map[string]*Label
 
+type Zones map[string]*Zone
+
 type Zone struct {
 	Origin    string
 	Labels    labels
@@ -57,18 +60,43 @@ func main() {
 	}
 	flag.Parse()
 
-	Zone := new(Zone)
-	Zone.Labels = make(labels)
+	dirName := "dns"
 
-	// BUG(ask) Doesn't read multiple .json zone files yet
-	Zone.Origin = "ntppool.org"
-	Zone.LenLabels = dns.LenLabels(Zone.Origin)
+	dir, err := ioutil.ReadDir(dirName)
+	if err != nil {
+		panic(err)
+	}
+
+	Zones := make(Zones)
+
+	for i, file := range dir {
+		fileName := file.Name()
+		if !strings.HasSuffix(strings.ToLower(fileName), ".json") {
+			continue
+		}
+		zoneName := fileName[0:strings.LastIndex(fileName, ".")]
+		fmt.Println("FILE:", i, file, zoneName)
+		config := readZoneFile(path.Join(dirName, fileName))
+		config.Origin = zoneName
+		Zones[zoneName] = config
+	}
+
+	fmt.Println("ZONES", Zones)
+
+	runServe(&Zones)
+}
+
+func readZoneFile(fileName string) *Zone {
 
-	b, err := ioutil.ReadFile("ntppool.org.json")
+	b, err := ioutil.ReadFile(fileName)
 	if err != nil {
 		panic(err)
 	}
 
+	Zone := new(Zone)
+	Zone.Labels = make(labels)
+	Zone.LenLabels = dns.LenLabels(Zone.Origin)
+
 	if err == nil {
 		var objmap map[string]interface{}
 		err := json.Unmarshal(b, &objmap)
@@ -105,7 +133,7 @@ func main() {
 
 	//fmt.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
 
-	runServe(Zone)
+	return Zone
 }
 
 func setupZoneData(data map[string]interface{}, Zone *Zone) {
@@ -127,12 +155,10 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 		for rType, dnsType := range recordTypes {
 
-			fmt.Println(rType, dnsType)
-
 			var rdata = dv.(map[string]interface{})[rType]
 
 			if rdata == nil {
-				fmt.Printf("No %s records for label %s\n", rType, dk)
+				//fmt.Printf("No %s records for label %s\n", rType, dk)
 				continue
 			}
 

+ 21 - 8
serve.go

@@ -12,12 +12,14 @@ import (
 func getQuestionName(z *Zone, req *dns.Msg) string {
 	lx := dns.SplitLabels(req.Question[0].Name)
 	ql := lx[0 : len(lx)-z.LenLabels-1]
-	//fmt.Println("LX:", ql, lx, z.LenLabels)
+	fmt.Println("LX:", ql, lx, z.LenLabels)
 	return strings.Join(ql, ".")
 }
 
-func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
-	logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
+func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone, zoneName string) {
+	logPrintf("[zone %s/%s] incoming %s %s %d from %s\n", zoneName, z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
+
+	fmt.Printf("ZONE DATA  %#v\n", z)
 
 	fmt.Println("Got request", req)
 
@@ -33,7 +35,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 	m.SetReply(req)
 	m.MsgHdr.Authoritative = true
 
-	// TODO: Function to find appropriate label with records
+	// TODO(ask): Function to find appropriate label with records based on the country/continent	
 	labels := z.findLabels(label)
 	if labels == nil {
 		// return NXDOMAIN
@@ -43,13 +45,15 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		return
 	}
 
-	//fmt.Println("REG", region)
+	fmt.Println("Has the label, looking for records")
+
 	if region_rr := labels.Records[req.Question[0].Qtype]; region_rr != nil {
 		//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
 		max := len(region_rr)
 		if max > 4 {
 			max = 4
 		}
+		// TODO(ask) Pick random servers based on weight, not just the first 'max' entries
 		servers := region_rr[0:max]
 		var rrs []dns.RR
 		for _, record := range servers {
@@ -62,15 +66,24 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		m.Answer = rrs
 	}
 
+	fmt.Println("Writing reply")
+
 	ednsFromRequest(req, m)
 	w.Write(m)
 	return
 }
 
-func runServe(Zone *Zone) {
+func runServe(Zones *Zones) {
 
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { serve(w, r, Zone) })
-	// Only listen on UDP
+	for zoneName, Zone := range *Zones {
+		// BUG(ask) For some reason the closure here gets setup so only the second zone gets used
+		fmt.Printf("Configuring zoneName %s %#v\n", zoneName, Zone)
+		dns.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
+			fmt.Println("Going to call serve with", zoneName)
+			serve(w, r, Zone, zoneName)
+		})
+	}
+	// Only listen on UDP for now
 	go func() {
 		if err := dns.ListenAndServe(*listen, "udp", nil); err != nil {
 			log.Fatalf("geodns: failed to setup %s %s", *listen, "udp")