Browse Source

Add more targeting options (state, region) (Merge branch 'targeting')

Ask Bjørn Hansen 12 năm trước cách đây
mục cha
commit
7388d9cc1c
12 tập tin đã thay đổi với 396 bổ sung57 xóa
  1. 9 1
      .travis.yml
  2. 3 0
      config.go
  3. 74 0
      countries/regiongroups.go
  4. 4 0
      dns/geodns.conf.sample
  5. 1 1
      geodns.go
  6. 90 5
      geoip.go
  7. 16 19
      serve.go
  8. 97 0
      targeting.go
  9. 61 0
      targeting_test.go
  10. 20 23
      zone.go
  11. 8 8
      zone_test.go
  12. 13 0
      zones.go

+ 9 - 1
.travis.yml

@@ -1,11 +1,19 @@
 language: go
+go:
+  - 1.1
+  - tip
 before_install:
   - sudo apt-get install libgeoip-dev bzr
 install:
-  - go get github.com/miekg/dns
+  - mkdir -p $TRAVIS_BUILD_DIR/db
+  - curl -s http://geodns.bitnames.com/geoip/GeoLiteCity.dat.gz  | gzip -cd > $TRAVIS_BUILD_DIR/db/GeoIPCity.dat
+  - go get github.com/abh/dns
   - go get github.com/abh/geoip
   - go get launchpad.net/gocheck
   - go get -v
   - go build -v
   - go install
 
+script:
+  - cd $TRAVIS_BUILD_DIR && go test -gocheck.v
+  - go test -gocheck.v -gocheck.b -gocheck.btime=2s

+ 3 - 0
config.go

@@ -16,6 +16,9 @@ type AppConfig struct {
 	Flags struct {
 		HasStatHat bool
 	}
+	GeoIP struct {
+		Directory string
+	}
 }
 
 var Config = new(AppConfig)

+ 74 - 0
countries/regiongroups.go

@@ -0,0 +1,74 @@
+package countries
+
+import (
+	"log"
+)
+
+func CountryRegionGroup(country, region string) string {
+
+	if country != "us" {
+		return ""
+	}
+
+	regions := map[string]string{
+		"us-ak": "us-west",
+		"us-az": "us-west",
+		"us-ca": "us-west",
+		"us-co": "us-west",
+		"us-hi": "us-west",
+		"us-id": "us-west",
+		"us-mt": "us-west",
+		"us-nm": "us-west",
+		"us-nv": "us-west",
+		"us-or": "us-west",
+		"us-ut": "us-west",
+		"us-wa": "us-west",
+		"us-wy": "us-west",
+
+		"us-ar": "us-central",
+		"us-ia": "us-central",
+		"us-in": "us-central",
+		"us-ks": "us-central",
+		"us-la": "us-central",
+		"us-mn": "us-central",
+		"us-mo": "us-central",
+		"us-nd": "us-central",
+		"us-ne": "us-central",
+		"us-ok": "us-central",
+		"us-sd": "us-central",
+		"us-tx": "us-central",
+		"us-wi": "us-central",
+
+		"us-al": "us-east",
+		"us-ct": "us-east",
+		"us-dc": "us-east",
+		"us-de": "us-east",
+		"us-fl": "us-east",
+		"us-ga": "us-east",
+		"us-ky": "us-east",
+		"us-ma": "us-east",
+		"us-md": "us-east",
+		"us-me": "us-east",
+		"us-mi": "us-east",
+		"us-ms": "us-east",
+		"us-nc": "us-east",
+		"us-nh": "us-east",
+		"us-nj": "us-east",
+		"us-ny": "us-east",
+		"us-oh": "us-east",
+		"us-pa": "us-east",
+		"us-ri": "us-east",
+		"us-sc": "us-east",
+		"us-tn": "us-east",
+		"us-va": "us-east",
+		"us-vt": "us-east",
+		"us-wv": "us-east",
+	}
+
+	if group, ok := regions[region]; ok {
+		return group
+	}
+
+	log.Printf("Did not find a region group for '%s'/'%s'", country, region)
+	return ""
+}

+ 4 - 0
dns/geodns.conf.sample

@@ -3,6 +3,10 @@
 ; It is recommended to distribute the configuration file globally
 ; with your .json zone files.
 
+[geoip]
+;; Directory containing the GeoIP .dat database files
+;directory=/usr/local/share/GeoIP/
+
 [stathat]
 ;; Add an API key to send query counts and other metrics to stathat
 ;apikey=abc123

+ 1 - 1
geodns.go

@@ -31,7 +31,7 @@ import (
 )
 
 // VERSION is the current version of GeoDNS
-var VERSION string = "2.3.0"
+var VERSION string = "2.4.0"
 var buildTime string
 var gitVersion string
 

+ 90 - 5
geoip.go

@@ -1,16 +1,101 @@
 package main
 
 import (
+	"github.com/abh/geodns/countries"
 	"github.com/abh/geoip"
 	"log"
+	"net"
+	"strings"
+	"time"
 )
 
-func setupGeoIP() *geoip.GeoIP {
+type GeoIP struct {
+	country         *geoip.GeoIP
+	hasCountry      bool
+	countryLastLoad time.Time
 
-	gi, err := geoip.Open()
+	city         *geoip.GeoIP
+	cityLastLoad time.Time
+	hasCity      bool
+}
+
+var geoIP = new(GeoIP)
+
+func (g *GeoIP) GetCountry(ip net.IP) (country, continent string, netmask int) {
+	if g.country == nil {
+		return "", "", 0
+	}
+
+	country, netmask = geoIP.country.GetCountry(ip.String())
+	if len(country) > 0 {
+		country = strings.ToLower(country)
+		continent = countries.CountryContinent[country]
+	}
+	return
+}
+
+func (g *GeoIP) GetCountryRegion(ip net.IP) (country, continent, regionGroup, region string, netmask int) {
+	if g.city == nil {
+		log.Println("No city database available")
+		country, continent, netmask = g.GetCountry(ip)
+		return
+	}
+
+	record := geoIP.city.GetRecord(ip.String())
+
+	country = record.CountryCode
+	region = record.Region
+	if len(country) > 0 {
+		country = strings.ToLower(country)
+		continent = countries.CountryContinent[country]
+
+		if len(region) > 0 {
+			region = country + "-" + strings.ToLower(region)
+			regionGroup = countries.CountryRegionGroup(country, region)
+		}
+
+	}
+	return
+}
+
+func (g *GeoIP) setDirectory() {
+	if len(Config.GeoIP.Directory) > 0 {
+		geoip.SetCustomDirectory(Config.GeoIP.Directory)
+	}
+}
+
+func (g *GeoIP) setupGeoIPCountry() {
+	if g.country != nil {
+		return
+	}
+
+	g.setDirectory()
+
+	gi, err := geoip.OpenType(geoip.GEOIP_COUNTRY_EDITION)
 	if gi == nil || err != nil {
-		log.Printf("Could not open GeoIP database: %s\n", err)
-		return nil
+		log.Printf("Could not open country GeoIP database: %s\n", err)
+		return
 	}
-	return gi
+	g.countryLastLoad = time.Now()
+	g.hasCity = true
+	g.country = gi
+
+}
+
+func (g *GeoIP) setupGeoIPCity() {
+	if g.city != nil {
+		return
+	}
+
+	g.setDirectory()
+
+	gi, err := geoip.OpenType(geoip.GEOIP_CITY_EDITION_REV1)
+	if gi == nil || err != nil {
+		log.Printf("Could not open city GeoIP database: %s\n", err)
+		return
+	}
+	g.countryLastLoad = time.Now()
+	g.hasCity = true
+	g.city = gi
+
 }

+ 16 - 19
serve.go

@@ -3,7 +3,6 @@ package main
 import (
 	"encoding/json"
 	"fmt"
-	"github.com/abh/geodns/countries"
 	"github.com/abh/dns"
 	"log"
 	"net"
@@ -19,8 +18,6 @@ func getQuestionName(z *Zone, req *dns.Msg) string {
 	return strings.ToLower(strings.Join(ql, "."))
 }
 
-var geoIP = setupGeoIP()
-
 func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 
 	qtype := req.Question[0].Qtype
@@ -41,9 +38,10 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 	z.Metrics.LabelStats.Add(label)
 
 	realIp, _, _ := net.SplitHostPort(w.RemoteAddr().String())
+
 	z.Metrics.ClientStats.Add(realIp)
 
-	var ip string // EDNS or real IP
+	var ip net.IP // EDNS or real IP
 	var edns *dns.EDNS0_SUBNET
 	var opt_rr *dns.OPT
 
@@ -61,7 +59,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 					logPrintln("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
 					if e.Address != nil {
 						edns = e
-						ip = e.Address.String()
+						ip = e.Address
 					}
 				}
 			}
@@ -69,15 +67,10 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 	}
 
 	if len(ip) == 0 { // no edns subnet
-		ip = realIp
+		ip = net.ParseIP(realIp)
 	}
 
-	var country string
-	var netmask int
-	if geoIP != nil {
-		country, netmask = geoIP.GetCountry(ip)
-		country = strings.ToLower(country)
-	}
+	targets, netmask := z.Options.Targeting.GetTargets(ip)
 
 	m := new(dns.Msg)
 	m.SetReply(req)
@@ -97,7 +90,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		}
 	}
 
-	labels, labelQtype := z.findLabels(label, country, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
+	labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
 	if labelQtype == 0 {
 		labelQtype = qtype
 	}
@@ -122,13 +115,17 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 				h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
 				h.Name = label + "." + z.Origin + "."
 
+				txt := []string{
+					w.RemoteAddr().String(),
+					ip.String(),
+				}
+
+				targets, netmask := z.Options.Targeting.GetTargets(ip)
+				txt = append(txt, strings.Join(targets, " "))
+				txt = append(txt, fmt.Sprintf("/%d", netmask))
+
 				m.Answer = []dns.RR{&dns.TXT{Hdr: h,
-					Txt: []string{
-						w.RemoteAddr().String(),
-						ip,
-						string(country),
-						string(countries.CountryContinent[country]),
-					},
+					Txt: txt,
 				}}
 			} else {
 				m.Ns = append(m.Ns, z.SoaRR())

+ 97 - 0
targeting.go

@@ -0,0 +1,97 @@
+package main
+
+import (
+	"fmt"
+	"net"
+	"strings"
+)
+
+type TargetOptions int
+
+const (
+	TargetGlobal = 1 << iota
+	TargetContinent
+	TargetCountry
+	TargetRegionGroup
+	TargetRegion
+)
+
+func (t TargetOptions) GetTargets(ip net.IP) ([]string, int) {
+
+	targets := make([]string, 0)
+
+	var country, continent string
+	var netmask int
+
+	switch {
+	case t >= TargetRegionGroup:
+		var region, regionGroup string
+		country, continent, regionGroup, region, netmask = geoIP.GetCountryRegion(ip)
+		if t&TargetRegion > 0 && len(region) > 0 {
+			targets = append(targets, region)
+		}
+		if t&TargetRegionGroup > 0 && len(regionGroup) > 0 {
+			targets = append(targets, regionGroup)
+		}
+
+	case t >= TargetContinent:
+		country, continent, netmask = geoIP.GetCountry(ip)
+	}
+
+	if len(country) > 0 {
+		if t&TargetCountry > 0 {
+			targets = append(targets, country)
+		}
+		if t&TargetContinent > 0 && len(continent) > 0 {
+			targets = append(targets, continent)
+		}
+	}
+
+	if t&TargetGlobal > 0 {
+		targets = append(targets, "@")
+	}
+	return targets, netmask
+}
+
+func (t TargetOptions) String() string {
+	targets := make([]string, 0)
+	if t&TargetGlobal > 0 {
+		targets = append(targets, "@")
+	}
+	if t&TargetContinent > 0 {
+		targets = append(targets, "continent")
+	}
+	if t&TargetCountry > 0 {
+		targets = append(targets, "country")
+	}
+	if t&TargetRegionGroup > 0 {
+		targets = append(targets, "regiongroup")
+	}
+	if t&TargetRegion > 0 {
+		targets = append(targets, "region")
+	}
+	return strings.Join(targets, " ")
+}
+
+func parseTargets(v string) (tgt TargetOptions, err error) {
+	targets := strings.Split(v, " ")
+	for _, t := range targets {
+		var x TargetOptions
+		switch t {
+		case "@":
+			x = TargetGlobal
+		case "country":
+			x = TargetCountry
+		case "continent":
+			x = TargetContinent
+		case "regiongroup":
+			x = TargetRegionGroup
+		case "region":
+			x = TargetRegion
+		default:
+			err = fmt.Errorf("Unknown targeting option '%s'", t)
+		}
+		tgt = tgt | x
+	}
+	return
+}

+ 61 - 0
targeting_test.go

@@ -0,0 +1,61 @@
+package main
+
+import (
+	. "launchpad.net/gocheck"
+	"net"
+)
+
+type TargetingSuite struct {
+}
+
+var _ = Suite(&TargetingSuite{})
+
+func (s *TargetingSuite) SetUpSuite(c *C) {
+	Config.GeoIP.Directory = "db"
+}
+
+func (s *TargetingSuite) TestTargetString(c *C) {
+	var tgt TargetOptions
+	tgt = TargetGlobal + TargetCountry + TargetContinent
+
+	str := tgt.String()
+	c.Check(str, Equals, "@ continent country")
+}
+
+func (s *TargetingSuite) TestTargetParse(c *C) {
+
+	tgt, err := parseTargets("@ foo country")
+	str := tgt.String()
+	c.Check(str, Equals, "@ country")
+	c.Check(err.Error(), Equals, "Unknown targeting option 'foo'")
+
+	tgt, err = parseTargets("@ continent country")
+	c.Assert(err, IsNil)
+	str = tgt.String()
+	c.Check(str, Equals, "@ continent country")
+}
+func (s *TargetingSuite) TestGetTargets(c *C) {
+
+	ip := net.ParseIP("207.171.7.51")
+
+	geoIP.setupGeoIPCity()
+	geoIP.setupGeoIPCountry()
+
+	tgt, _ := parseTargets("@ continent country")
+	targets, _ := tgt.GetTargets(ip)
+	c.Check(targets, DeepEquals, []string{"us", "north-america", "@"})
+
+	if geoIP.city == nil {
+		c.Log("City GeoIP database requred for these tests")
+		return
+	}
+
+	tgt, _ = parseTargets("@ continent country region ")
+	targets, _ = tgt.GetTargets(ip)
+	c.Check(targets, DeepEquals, []string{"us-ca", "us", "north-america", "@"})
+
+	tgt, _ = parseTargets("@ continent regiongroup country region ")
+	targets, _ = tgt.GetTargets(ip)
+	c.Check(targets, DeepEquals, []string{"us-ca", "us-west", "us", "north-america", "@"})
+
+}

+ 20 - 23
zone.go

@@ -1,18 +1,18 @@
 package main
 
 import (
-	"github.com/abh/geodns/countries"
-	"github.com/abh/go-metrics"
 	"github.com/abh/dns"
+	"github.com/abh/go-metrics"
 	"strings"
 	"time"
 )
 
 type ZoneOptions struct {
-	Serial   int
-	Ttl      int
-	MaxHosts int
-	Contact  string
+	Serial    int
+	Ttl       int
+	MaxHosts  int
+	Contact   string
+	Targeting TargetOptions
 }
 
 type ZoneLogging struct {
@@ -73,6 +73,7 @@ func NewZone(name string) *Zone {
 	zone.Options.Ttl = 120
 	zone.Options.MaxHosts = 2
 	zone.Options.Contact = "support.bitnames.com"
+	zone.Options.Targeting = TargetGlobal + TargetCountry + TargetContinent
 
 	return zone
 }
@@ -123,26 +124,22 @@ func (z *Zone) SoaRR() dns.RR {
 // continent and the global label name as needed. Looks for the
 // first available qType at each targeting level. Return a Label
 // and the qtype that was "found"
-func (z *Zone) findLabels(s, cc string, qts qTypes) (*Label, uint16) {
+func (z *Zone) findLabels(s string, targets []string, qts qTypes) (*Label, uint16) {
 
-	selectors := []string{}
+	for _, target := range targets {
 
-	if len(cc) > 0 {
-		continent := countries.CountryContinent[cc]
-		var s_cc string
-		if len(s) > 0 {
-			s_cc = s + "." + cc
-			if len(continent) > 0 {
-				continent = s + "." + continent
+		var name string
+
+		switch target {
+		case "@":
+			name = s
+		default:
+			if len(s) > 0 {
+				name = s + "." + target
+			} else {
+				name = target
 			}
-		} else {
-			s_cc = cc
 		}
-		selectors = append(selectors, s_cc, continent)
-	}
-	selectors = append(selectors, s)
-
-	for _, name := range selectors {
 
 		if label, ok := z.Labels[name]; ok {
 
@@ -158,7 +155,7 @@ func (z *Zone) findLabels(s, cc string, qts qTypes) (*Label, uint16) {
 					if label.Records[dns.TypeMF] != nil {
 						name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
 						// TODO: need to avoid loops here somehow
-						return z.findLabels(name, cc, qts)
+						return z.findLabels(name, targets, qts)
 					}
 				default:
 					// return the label if it has the right record

+ 8 - 8
zone_test.go

@@ -17,40 +17,40 @@ func (s *ConfigSuite) TestExampleComZone(c *C) {
 	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)
 
 	// Make sure that the empty "no.bar" zone gets skipped and "bar" is used
-	label, qtype := ex.findLabels("bar", "no", qTypes{dns.TypeA})
+	label, qtype := ex.findLabels("bar", []string{"no", "europe", "@"}, qTypes{dns.TypeA})
 	c.Check(label.Records[dns.TypeA], HasLen, 1)
 	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.2")
 	c.Check(qtype, Equals, dns.TypeA)
 
-	label, qtype = ex.findLabels("", "", qTypes{dns.TypeMX})
+	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeMX})
 	Mxs := label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 2)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx.example.net.")
 	c.Check(Mxs[1].RR.(*dns.MX).Mx, Equals, "mx2.example.net.")
 
-	label, qtype = ex.findLabels("", "dk", qTypes{dns.TypeMX})
+	label, qtype = ex.findLabels("", []string{"dk", "europe", "@"}, qTypes{dns.TypeMX})
 	Mxs = label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 1)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx-eu.example.net.")
 	c.Check(qtype, Equals, dns.TypeMX)
 
 	// look for multiple record types
-	label, qtype = ex.findLabels("www", "", qTypes{dns.TypeCNAME, dns.TypeA})
+	label, qtype = ex.findLabels("www", []string{"@"}, qTypes{dns.TypeCNAME, dns.TypeA})
 	c.Check(label.Records[dns.TypeCNAME], HasLen, 1)
 	c.Check(qtype, Equals, dns.TypeCNAME)
 
-	label, qtype = ex.findLabels("", "", qTypes{dns.TypeNS})
+	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeNS})
 	Ns := label.Records[dns.TypeNS]
 	c.Check(Ns, HasLen, 2)
 	c.Check(Ns[0].RR.(*dns.NS).Ns, Equals, "ns1.example.net.")
 	c.Check(Ns[1].RR.(*dns.NS).Ns, Equals, "ns2.example.net.")
 
-	label, qtype = ex.findLabels("foo", "", qTypes{dns.TypeTXT})
+	label, qtype = ex.findLabels("foo", []string{"@"}, qTypes{dns.TypeTXT})
 	Txt := label.Records[dns.TypeTXT]
 	c.Check(Txt, HasLen, 1)
 	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "this is foo")
 
-	label, qtype = ex.findLabels("weight", "", qTypes{dns.TypeTXT})
+	label, qtype = ex.findLabels("weight", []string{"@"}, qTypes{dns.TypeTXT})
 	Txt = label.Records[dns.TypeTXT]
 	c.Check(Txt, HasLen, 2)
 	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "w1000")
@@ -63,7 +63,7 @@ func (s *ConfigSuite) TestExampleOrgZone(c *C) {
 	// test.example.org was loaded
 	c.Assert(ex.Labels, NotNil)
 
-	label, qtype := ex.findLabels("sub", "", qTypes{dns.TypeNS})
+	label, qtype := ex.findLabels("sub", []string{"@"}, qTypes{dns.TypeNS})
 	c.Assert(qtype, Equals, dns.TypeNS)
 
 	Ns := label.Records[dns.TypeNS]

+ 13 - 0
zones.go

@@ -168,6 +168,12 @@ func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 				zone.Options.Contact = v.(string)
 			case "max_hosts":
 				zone.Options.MaxHosts = valueToInt(v)
+			case "targeting":
+				zone.Options.Targeting, err = parseTargets(v.(string))
+				if err != nil {
+					log.Printf("Could not parse targeting '%s': %s", v, err)
+					return nil, err
+				}
 			}
 		case "logging":
 			{
@@ -199,6 +205,13 @@ func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 
 	//log.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
 
+	switch {
+	case zone.Options.Targeting >= TargetRegionGroup:
+		geoIP.setupGeoIPCity()
+	case zone.Options.Targeting >= TargetContinent:
+		geoIP.setupGeoIPCountry()
+	}
+
 	return zone, nil
 }