Parcourir la source

Start refactoring targeting options

Ask Bjørn Hansen il y a 12 ans
Parent
commit
d2d7e364b1
3 fichiers modifiés avec 33 ajouts et 29 suppressions
  1. 11 2
      serve.go
  2. 14 19
      zone.go
  3. 8 8
      zone_test.go

+ 11 - 2
serve.go

@@ -3,8 +3,8 @@ package main
 import (
 	"encoding/json"
 	"fmt"
-	"github.com/abh/geodns/countries"
 	"github.com/abh/dns"
+	"github.com/abh/geodns/countries"
 	"log"
 	"net"
 	"os"
@@ -72,11 +72,20 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		ip = realIp
 	}
 
+	var targets []string
 	var country string
 	var netmask int
 	if geoIP != nil {
 		country, netmask = geoIP.GetCountry(ip)
 		country = strings.ToLower(country)
+		if len(country) > 0 {
+			targets = append(targets, country)
+			continent := countries.CountryContinent[country]
+			if len(continent) > 0 {
+				targets = append(targets, continent)
+			}
+		}
+		targets = append(targets, "@")
 	}
 
 	m := new(dns.Msg)
@@ -97,7 +106,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		}
 	}
 
-	labels, labelQtype := z.findLabels(label, country, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
+	labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
 	if labelQtype == 0 {
 		labelQtype = qtype
 	}

+ 14 - 19
zone.go

@@ -1,9 +1,8 @@
 package main
 
 import (
-	"github.com/abh/geodns/countries"
-	"github.com/abh/go-metrics"
 	"github.com/abh/dns"
+	"github.com/abh/go-metrics"
 	"strings"
 	"time"
 )
@@ -123,26 +122,22 @@ func (z *Zone) SoaRR() dns.RR {
 // continent and the global label name as needed. Looks for the
 // first available qType at each targeting level. Return a Label
 // and the qtype that was "found"
-func (z *Zone) findLabels(s, cc string, qts qTypes) (*Label, uint16) {
+func (z *Zone) findLabels(s string, targets []string, qts qTypes) (*Label, uint16) {
+
+	for _, target := range targets {
 
-	selectors := []string{}
+		var name string
 
-	if len(cc) > 0 {
-		continent := countries.CountryContinent[cc]
-		var s_cc string
-		if len(s) > 0 {
-			s_cc = s + "." + cc
-			if len(continent) > 0 {
-				continent = s + "." + continent
+		switch target {
+		case "@":
+			name = s
+		default:
+			if len(s) > 0 {
+				name = s + "." + target
+			} else {
+				name = target
 			}
-		} else {
-			s_cc = cc
 		}
-		selectors = append(selectors, s_cc, continent)
-	}
-	selectors = append(selectors, s)
-
-	for _, name := range selectors {
 
 		if label, ok := z.Labels[name]; ok {
 
@@ -158,7 +153,7 @@ func (z *Zone) findLabels(s, cc string, qts qTypes) (*Label, uint16) {
 					if label.Records[dns.TypeMF] != nil {
 						name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
 						// TODO: need to avoid loops here somehow
-						return z.findLabels(name, cc, qts)
+						return z.findLabels(name, targets, qts)
 					}
 				default:
 					// return the label if it has the right record

+ 8 - 8
zone_test.go

@@ -17,40 +17,40 @@ func (s *ConfigSuite) TestExampleComZone(c *C) {
 	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)
 
 	// Make sure that the empty "no.bar" zone gets skipped and "bar" is used
-	label, qtype := ex.findLabels("bar", "no", qTypes{dns.TypeA})
+	label, qtype := ex.findLabels("bar", []string{"no", "europe", "@"}, qTypes{dns.TypeA})
 	c.Check(label.Records[dns.TypeA], HasLen, 1)
 	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.2")
 	c.Check(qtype, Equals, dns.TypeA)
 
-	label, qtype = ex.findLabels("", "", qTypes{dns.TypeMX})
+	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeMX})
 	Mxs := label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 2)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx.example.net.")
 	c.Check(Mxs[1].RR.(*dns.MX).Mx, Equals, "mx2.example.net.")
 
-	label, qtype = ex.findLabels("", "dk", qTypes{dns.TypeMX})
+	label, qtype = ex.findLabels("", []string{"dk", "europe", "@"}, qTypes{dns.TypeMX})
 	Mxs = label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 1)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx-eu.example.net.")
 	c.Check(qtype, Equals, dns.TypeMX)
 
 	// look for multiple record types
-	label, qtype = ex.findLabels("www", "", qTypes{dns.TypeCNAME, dns.TypeA})
+	label, qtype = ex.findLabels("www", []string{"@"}, qTypes{dns.TypeCNAME, dns.TypeA})
 	c.Check(label.Records[dns.TypeCNAME], HasLen, 1)
 	c.Check(qtype, Equals, dns.TypeCNAME)
 
-	label, qtype = ex.findLabels("", "", qTypes{dns.TypeNS})
+	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeNS})
 	Ns := label.Records[dns.TypeNS]
 	c.Check(Ns, HasLen, 2)
 	c.Check(Ns[0].RR.(*dns.NS).Ns, Equals, "ns1.example.net.")
 	c.Check(Ns[1].RR.(*dns.NS).Ns, Equals, "ns2.example.net.")
 
-	label, qtype = ex.findLabels("foo", "", qTypes{dns.TypeTXT})
+	label, qtype = ex.findLabels("foo", []string{"@"}, qTypes{dns.TypeTXT})
 	Txt := label.Records[dns.TypeTXT]
 	c.Check(Txt, HasLen, 1)
 	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "this is foo")
 
-	label, qtype = ex.findLabels("weight", "", qTypes{dns.TypeTXT})
+	label, qtype = ex.findLabels("weight", []string{"@"}, qTypes{dns.TypeTXT})
 	Txt = label.Records[dns.TypeTXT]
 	c.Check(Txt, HasLen, 2)
 	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "w1000")
@@ -63,7 +63,7 @@ func (s *ConfigSuite) TestExampleOrgZone(c *C) {
 	// test.example.org was loaded
 	c.Assert(ex.Labels, NotNil)
 
-	label, qtype := ex.findLabels("sub", "", qTypes{dns.TypeNS})
+	label, qtype := ex.findLabels("sub", []string{"@"}, qTypes{dns.TypeNS})
 	c.Assert(qtype, Equals, dns.TypeNS)
 
 	Ns := label.Records[dns.TypeNS]