Преглед на файлове

Basic AAAA support, move serve function to separate file, start geoip support

Ask Bjørn Hansen преди 13 години
родител
ревизия
9cbea63af7
променени са 3 файла, в които са добавени 183 реда и са изтрити 112 реда
  1. 79 112
      geodns.go
  2. 17 0
      geoip.go
  3. 87 0
      serve.go

+ 79 - 112
geodns.go

@@ -6,10 +6,7 @@ import (
 	"flag"
 	"fmt"
 	"io/ioutil"
-	"log"
 	"net"
-	"os"
-	"os/signal"
 	"strconv"
 )
 
@@ -43,50 +40,6 @@ var (
 	flagrun = flag.Bool("run", false, "run server")
 )
 
-func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone, opt *Options) {
-	logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
-
-	fmt.Println("Got request", req)
-
-	m := new(dns.Msg)
-	m.SetReply(req)
-	m.MsgHdr.Authoritative = true
-
-	// TODO: Function to find appropriate label with records
-	if region, ok := z.Labels[""]; ok {
-		if region_rr := region.Records[req.Question[0].Qtype]; region_rr != nil {
-			//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
-			max := len(region_rr)
-			if max > 4 {
-				max = 4
-			}
-			servers := region_rr[0:max]
-			var rrs []dns.RR
-			for _, record := range servers {
-				rr := record.RR
-				rr.Header().Name = req.Question[0].Name
-				fmt.Println(rr)
-				rrs = append(rrs, rr)
-			}
-			m.Answer = rrs
-		}
-	}
-
-	ednsFromRequest(req, m)
-	w.Write(m)
-	return
-}
-
-func ednsFromRequest(req, m *dns.Msg) {
-	for _, r := range req.Extra {
-		if r.Header().Rrtype == dns.TypeOPT {
-			m.SetEdns0(4096, r.(*dns.RR_OPT).Do())
-			return
-		}
-	}
-	return
-}
-
 func main() {
 
 	flag.Usage = func() {
@@ -115,6 +68,8 @@ func main() {
 		}
 		//fmt.Println(objmap)
 
+		var data map[string]interface{}
+
 		for k, v := range objmap {
 			fmt.Printf("k: %s v: %#v, T: %T\n", k, v, v)
 
@@ -129,99 +84,111 @@ func main() {
 				continue
 
 			case "data":
+				data = v.(map[string]interface{})
+			}
+		}
 
-				// fmt.Println("V", v)
+		setupZoneData(data, Zone, Options)
 
-				var data map[string]interface{}
-				data = v.(map[string]interface{})
-				//fmt.Println("DATA", data)
+	}
+
+	//fmt.Printf("ZO T: %T %s\n", Zones["0.us"], Zones["0.us"])
 
-				for dk, dv := range data {
+	//fmt.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
 
-					fmt.Printf("K %s V %s TYPE-V %T\n", dk, dv, dv)
+	runServe(Zone, Options)
+}
 
-					Zone.Labels[dk] = new(Label)
-					label := Zone.Labels[dk]
-					//make([]Server, len(Records))
+func setupZoneData(data map[string]interface{}, Zone *Zone, Options *Options) {
 
-					var a = dv.(map[string]interface{})["a"]
+	var recordTypes = map[string]uint16{
+		"a":    dns.TypeA,
+		"aaaa": dns.TypeAAAA,
+	}
 
-					if a == nil {
-						fmt.Println("No A records, continue..")
-						continue
-					}
+	for dk, dv := range data {
 
-					//					fmt.Println("A", a)
-					fmt.Printf("A %s TYPE-A %T\n", a, a)
+		fmt.Printf("K %s V %s TYPE-V %T\n", dk, dv, dv)
 
-					Records := make(map[string][]interface{})
+		Zone.Labels[dk] = new(Label)
+		label := Zone.Labels[dk]
+		//make([]Server, len(Records))
 
-					Records["a"] = a.([]interface{})
+		for rType, dnsType := range recordTypes {
+			fmt.Println(rType, dnsType)
 
-					//fmt.Printf("RECORDS %s TYPE-REC %T\n", Records, Records)
+			var rdata = dv.(map[string]interface{})[rType]
 
-					if label.Records == nil {
-						label.Records = make(map[uint16][]Record)
-					}
+			if rdata == nil {
+				fmt.Printf("No %s records for label %s", rType, dk)
+				continue
+			}
 
-					label.Records[dns.TypeA] = make([]Record, len(Records["a"]))
+			fmt.Printf("rdata %s TYPE-R %T\n", rdata, rdata)
 
-					for i := 0; i < len(Records["a"]); i++ {
-						foo := Records["a"][i].([]interface{})
-						//fmt.Printf("FOO TYPE %T %s\n", foo, foo)
-						record := new(Record)
-						ip := foo[0].(string)
+			Records := make(map[string][]interface{})
 
-						record.Weight, err = strconv.Atoi(foo[1].(string))
+			Records[rType] = rdata.([]interface{})
 
-						var h dns.RR_Header
-						h.Ttl = uint32(Options.Ttl)
-						h.Class = dns.ClassINET
+			//fmt.Printf("RECORDS %s TYPE-REC %T\n", Records, Records)
 
-						h.Rrtype = dns.TypeA
+			if label.Records == nil {
+				label.Records = make(map[uint16][]Record)
+			}
 
-						rr := new(dns.RR_A)
-						rr.Hdr = h
-						rr.A = net.ParseIP(ip)
-						if rr.A == nil {
-							panic("Bad A record")
-						}
-						record.RR = rr
-						//fmt.Println(rr)
+			label.Records[dnsType] = make([]Record, len(Records[rType]))
 
-						label.Records[dns.TypeA][i] = *record
-					}
+			for i := 0; i < len(Records[rType]); i++ {
+				foo := Records[rType][i].([]interface{})
+				//fmt.Printf("FOO TYPE %T %s\n", foo, foo)
+				record := new(Record)
+				ip := foo[0].(string)
+
+				var err error
+				record.Weight, err = strconv.Atoi(foo[1].(string))
+				if err != nil {
+					panic("Error converting weight to integer")
 				}
-				//fmt.Println(Zones[k])
-			}
-		}
 
-	}
+				var h dns.RR_Header
+				fmt.Println("TTL OPTIONS", Options.Ttl)
+				h.Ttl = uint32(Options.Ttl)
+				h.Class = dns.ClassINET
 
-	//fmt.Printf("ZO T: %T %s\n", Zones["0.us"], Zones["0.us"])
+				h.Rrtype = dnsType
 
-	//fmt.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
+				fmt.Println("H", h)
 
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { serve(w, r, Zone, Options) })
-	// Only listen on UDP
-	go func() {
-		if err := dns.ListenAndServe(*listen, "udp", nil); err != nil {
-			log.Fatalf("geodns: failed to setup %s %s", *listen, "udp")
-		}
-	}()
+				switch dnsType {
+				case dns.TypeA:
+					rr := new(dns.RR_A)
+					rr.Hdr = h
+					rr.A = net.ParseIP(ip)
+					if rr.A == nil {
+						panic("Bad A record")
+					}
+					record.RR = rr
+				case dns.TypeAAAA:
+					rr := new(dns.RR_AAAA)
+					rr.Hdr = h
+					rr.AAAA = net.ParseIP(ip)
+					if rr.AAAA == nil {
+						panic("Bad AAAA record")
+					}
+					record.RR = rr
+				default:
+					fmt.Println("type:", rType)
+					panic("Don't know how to handle this type")
 
-	if *flagrun {
+				}
 
-		sig := make(chan os.Signal)
-		signal.Notify(sig, os.Interrupt)
+				if record.RR == nil {
+					panic("record.RR is nil")
+				}
 
-	forever:
-		for {
-			select {
-			case <-sig:
-				log.Printf("geodns: signal received, stopping")
-				break forever
+				label.Records[dnsType][i] = *record
 			}
 		}
 	}
+	//fmt.Println(Zones[k])
 }

+ 17 - 0
geoip.go

@@ -0,0 +1,17 @@
+package main
+
+import (
+	"fmt"
+	"github.com/abh/geoip"
+)
+
+func setupGeoIP() *geoip.GeoIP {
+	file := "/opt/local/share/GeoIP/GeoIP.dat"
+
+	gi := geoip.GeoIP_Open(file)
+	if gi == nil {
+		fmt.Printf("Could not open GeoIP database\n")
+		return nil
+	}
+	return gi
+}

+ 87 - 0
serve.go

@@ -0,0 +1,87 @@
+package main
+
+import (
+	"dns"
+	"fmt"
+	"log"
+	"os"
+	"os/signal"
+)
+
+func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone, opt *Options) {
+	logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name, dns.Rr_str[req.Question[0].Qtype], req.MsgHdr.Id, w.RemoteAddr())
+
+	fmt.Println("Got request", req)
+
+	raddr := w.RemoteAddr()
+
+	gi := setupGeoIP()
+	country := gi.GetCountry(raddr.String())
+	fmt.Println("Country:", country)
+
+	m := new(dns.Msg)
+	m.SetReply(req)
+	m.MsgHdr.Authoritative = true
+
+	// TODO: Function to find appropriate label with records
+	if region, ok := z.Labels["2"]; ok {
+		if region_rr := region.Records[req.Question[0].Qtype]; region_rr != nil {
+			//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
+			max := len(region_rr)
+			if max > 4 {
+				max = 4
+			}
+			servers := region_rr[0:max]
+			var rrs []dns.RR
+			for _, record := range servers {
+				rr := record.RR
+				fmt.Println("RR", rr)
+				rr.Header().Name = req.Question[0].Name
+				fmt.Println(rr)
+				rrs = append(rrs, rr)
+			}
+			m.Answer = rrs
+		}
+	}
+
+	ednsFromRequest(req, m)
+	w.Write(m)
+	return
+}
+
+func runServe(Zone *Zone, Options *Options) {
+
+	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { serve(w, r, Zone, Options) })
+	// Only listen on UDP
+	go func() {
+		if err := dns.ListenAndServe(*listen, "udp", nil); err != nil {
+			log.Fatalf("geodns: failed to setup %s %s", *listen, "udp")
+		}
+	}()
+
+	if *flagrun {
+
+		sig := make(chan os.Signal)
+		signal.Notify(sig, os.Interrupt)
+
+	forever:
+		for {
+			select {
+			case <-sig:
+				log.Printf("geodns: signal received, stopping")
+				break forever
+			}
+		}
+	}
+
+}
+
+func ednsFromRequest(req, m *dns.Msg) {
+	for _, r := range req.Extra {
+		if r.Header().Rrtype == dns.TypeOPT {
+			m.SetEdns0(4096, r.(*dns.RR_OPT).Do())
+			return
+		}
+	}
+	return
+}