Browse Source

Refactor: New packages for targeting and zones data and functions

- Also rename variables that were named as their types
- Remove StatHat features (we have more generic logs/metrics now)
Ask Bjørn Hansen 8 years ago
parent
commit
ec6de4876f

+ 0 - 4
Godeps/Godeps.json

@@ -43,10 +43,6 @@
 			"ImportPath": "github.com/rcrowley/go-metrics",
 			"ImportPath": "github.com/rcrowley/go-metrics",
 			"Rev": "eeba7bd0dd01ace6e690fa833b3f22aaec29af43"
 			"Rev": "eeba7bd0dd01ace6e690fa833b3f22aaec29af43"
 		},
 		},
-		{
-			"ImportPath": "github.com/stathat/go",
-			"Rev": "01d012b9ee2ecc107cb28b6dd32d9019ed5c1d77"
-		},
 		{
 		{
 			"ImportPath": "golang.org/x/net/websocket",
 			"ImportPath": "golang.org/x/net/websocket",
 			"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
 			"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"

+ 10 - 11
geodns.go

@@ -31,6 +31,7 @@ import (
 
 
 	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/querylog"
 	"github.com/abh/geodns/querylog"
+	"github.com/abh/geodns/zones"
 	"github.com/pborman/uuid"
 	"github.com/pborman/uuid"
 )
 )
 
 
@@ -132,7 +133,7 @@ func main() {
 			os.Exit(2)
 			os.Exit(2)
 		}
 		}
 
 
-		Zones := make(Zones)
+		Zones := make(zones.Zones)
 		srv.setupPgeodnsZone(Zones)
 		srv.setupPgeodnsZone(Zones)
 		err = srv.zonesReadDir(dirName, Zones)
 		err = srv.zonesReadDir(dirName, Zones)
 		if err != nil {
 		if err != nil {
@@ -202,18 +203,16 @@ func main() {
 
 
 	inter := getInterfaces()
 	inter := getInterfaces()
 
 
-	go statHatPoster()
-
-	Zones := make(Zones)
-
-	go monitor(Zones)
-	go Zones.statHatPoster()
+	if Config.HasStatHat() {
+		log.Println("StatHat integration has been removed in favor of more generic metrics")
+	}
 
 
+	// the global-ish zones 'context' is quite a mess
+	zonelist := make(zones.Zones)
+	go monitor(zonelist)
 	srv.setupRootZone()
 	srv.setupRootZone()
-	srv.setupPgeodnsZone(Zones)
-
-	dirName := *flagconfig
-	go srv.zonesReader(dirName, Zones)
+	srv.setupPgeodnsZone(zonelist)
+	go srv.zonesReader(*flagconfig, zonelist)
 
 
 	for _, host := range inter {
 	for _, host := range inter {
 		go srv.listenAndServe(host)
 		go srv.listenAndServe(host)

+ 7 - 5
monitor.go

@@ -15,6 +15,8 @@ import (
 	"strconv"
 	"strconv"
 	"time"
 	"time"
 
 
+	"github.com/abh/geodns/zones"
+
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"golang.org/x/net/websocket"
 	"golang.org/x/net/websocket"
 )
 )
@@ -159,7 +161,7 @@ func initialStatus() string {
 	return string(message)
 	return string(message)
 }
 }
 
 
-func monitor(zones Zones) {
+func monitor(zones zones.Zones) {
 
 
 	if len(*flaghttp) == 0 {
 	if len(*flaghttp) == 0 {
 		return
 		return
@@ -213,7 +215,7 @@ func MainServer(w http.ResponseWriter, req *http.Request) {
 type rate struct {
 type rate struct {
 	Name    string
 	Name    string
 	Count   int64
 	Count   int64
-	Metrics ZoneMetrics
+	Metrics zones.ZoneMetrics
 }
 }
 type Rates []*rate
 type Rates []*rate
 
 
@@ -269,7 +271,7 @@ func topParam(req *http.Request, def int) int {
 	return topOption
 	return topOption
 }
 }
 
 
-func StatusJSONHandler(zones Zones) func(http.ResponseWriter, *http.Request) {
+func StatusJSONHandler(zones zones.Zones) func(http.ResponseWriter, *http.Request) {
 	return func(w http.ResponseWriter, req *http.Request) {
 	return func(w http.ResponseWriter, req *http.Request) {
 
 
 		zonemetrics := make(map[string]metrics.Registry)
 		zonemetrics := make(map[string]metrics.Registry)
@@ -319,7 +321,7 @@ func StatusJSONHandler(zones Zones) func(http.ResponseWriter, *http.Request) {
 	}
 	}
 }
 }
 
 
-func StatusHandler(zones Zones) func(http.ResponseWriter, *http.Request) {
+func StatusHandler(zones zones.Zones) func(http.ResponseWriter, *http.Request) {
 
 
 	return func(w http.ResponseWriter, req *http.Request) {
 	return func(w http.ResponseWriter, req *http.Request) {
 
 
@@ -422,7 +424,7 @@ func (b *basicauth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	return
 	return
 }
 }
 
 
-func httpHandler(zones Zones) {
+func httpHandler(zones zones.Zones) {
 	http.Handle("/monitor", websocket.Handler(wsHandler))
 	http.Handle("/monitor", websocket.Handler(wsHandler))
 	http.HandleFunc("/status", StatusHandler(zones))
 	http.HandleFunc("/status", StatusHandler(zones))
 	http.HandleFunc("/status.json", StatusJSONHandler(zones))
 	http.HandleFunc("/status.json", StatusJSONHandler(zones))

+ 4 - 2
monitor_test.go

@@ -7,18 +7,20 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"github.com/abh/geodns/zones"
+
 	. "gopkg.in/check.v1"
 	. "gopkg.in/check.v1"
 )
 )
 
 
 type MonitorSuite struct {
 type MonitorSuite struct {
-	zones   Zones
+	zones   zones.Zones
 	metrics *ServerMetrics
 	metrics *ServerMetrics
 }
 }
 
 
 var _ = Suite(&MonitorSuite{})
 var _ = Suite(&MonitorSuite{})
 
 
 func (s *MonitorSuite) SetUpSuite(c *C) {
 func (s *MonitorSuite) SetUpSuite(c *C) {
-	s.zones = make(Zones)
+	s.zones = make(zones.Zones)
 	s.metrics = NewMetrics()
 	s.metrics = NewMetrics()
 	go s.metrics.Updater()
 	go s.metrics.Updater()
 
 

+ 6 - 31
serve.go

@@ -11,20 +11,20 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/applog"
-	"github.com/abh/geodns/health"
 	"github.com/abh/geodns/querylog"
 	"github.com/abh/geodns/querylog"
+	"github.com/abh/geodns/zones"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 )
 )
 
 
-func getQuestionName(z *Zone, req *dns.Msg) string {
+func getQuestionName(z *zones.Zone, req *dns.Msg) string {
 	lx := dns.SplitDomainName(req.Question[0].Name)
 	lx := dns.SplitDomainName(req.Question[0].Name)
 	ql := lx[0 : len(lx)-z.LabelCount]
 	ql := lx[0 : len(lx)-z.LabelCount]
 	return strings.ToLower(strings.Join(ql, "."))
 	return strings.ToLower(strings.Join(ql, "."))
 }
 }
 
 
-func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
+func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
 
 
 	qname := req.Question[0].Name
 	qname := req.Question[0].Name
 	qtype := req.Question[0].Qtype
 	qtype := req.Question[0].Qtype
@@ -141,7 +141,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		}
 		}
 	}
 	}
 
 
-	labels, labelQtype := z.findLabels(label, targets, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
+	labels, labelQtype := z.FindLabels(label, targets, []uint16{dns.TypeMF, dns.TypeCNAME, qtype})
 	if labelQtype == 0 {
 	if labelQtype == 0 {
 		labelQtype = qtype
 		labelQtype = qtype
 	}
 	}
@@ -170,7 +170,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		if permitDebug && firstLabel == "_health" {
 		if permitDebug && firstLabel == "_health" {
 			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
 			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
 				baseLabel := strings.Join((strings.Split(label, "."))[1:], ".")
 				baseLabel := strings.Join((strings.Split(label, "."))[1:], ".")
-				m.Answer = z.healthRR(label+"."+z.Origin+".", baseLabel)
+				m.Answer = z.HealthRR(label+"."+z.Origin+".", baseLabel)
 				m.Authoritative = true
 				m.Authoritative = true
 				w.WriteMsg(m)
 				w.WriteMsg(m)
 				return
 				return
@@ -195,7 +195,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 				txt = append(txt, strings.Join(targets, " "))
 				txt = append(txt, strings.Join(targets, " "))
 				txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP)
 				txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP)
 				if location != nil {
 				if location != nil {
-					txt = append(txt, fmt.Sprintf("(%.3f,%.3f)", location.latitude, location.longitude))
+					txt = append(txt, fmt.Sprintf("(%.3f,%.3f)", location.Latitude, location.Longitude))
 				} else {
 				} else {
 					txt = append(txt, "(?,?)")
 					txt = append(txt, "(?,?)")
 				}
 				}
@@ -278,28 +278,3 @@ func statusRR(label string) []dns.RR {
 
 
 	return []dns.RR{&dns.TXT{Hdr: h, Txt: []string{string(js)}}}
 	return []dns.RR{&dns.TXT{Hdr: h, Txt: []string{string(js)}}}
 }
 }
-
-func (z *Zone) healthRR(label string, baseLabel string) []dns.RR {
-	h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
-	h.Name = label
-
-	healthstatus := make(map[string]map[string]bool)
-
-	if l, ok := z.Labels[baseLabel]; ok {
-		for qt, records := range l.Records {
-			if qts, ok := dns.TypeToString[qt]; ok {
-				hmap := make(map[string]bool)
-				for _, record := range records {
-					if record.Test != nil {
-						hmap[(*record.Test).IP().String()] = health.TestRunner.IsHealthy(record.Test)
-					}
-				}
-				healthstatus[qts] = hmap
-			}
-		}
-	}
-
-	js, _ := json.Marshal(healthstatus)
-
-	return []dns.RR{&dns.TXT{Hdr: h, Txt: []string{string(js)}}}
-}

+ 4 - 3
serve_test.go

@@ -7,6 +7,7 @@ import (
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
+	"github.com/abh/geodns/zones"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	. "gopkg.in/check.v1"
 	. "gopkg.in/check.v1"
 )
 )
@@ -28,10 +29,10 @@ func (s *ServeSuite) SetUpSuite(c *C) {
 
 
 	srv := Server{}
 	srv := Server{}
 
 
-	Zones := make(Zones)
-	srv.setupPgeodnsZone(Zones)
+	zonelist := make(zones.Zones)
+	srv.setupPgeodnsZone(zonelist)
 	srv.setupRootZone()
 	srv.setupRootZone()
-	srv.zonesReadDir("dns", Zones)
+	srv.zonesReadDir("dns", zonelist)
 
 
 	// listenAndServe returns after listening on udp + tcp, so just
 	// listenAndServe returns after listening on udp + tcp, so just
 	// wait for it before continuing
 	// wait for it before continuing

+ 148 - 5
server.go

@@ -1,10 +1,18 @@
 package main
 package main
 
 
 import (
 import (
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"io/ioutil"
 	"log"
 	"log"
+	"path"
+	"strings"
 	"time"
 	"time"
 
 
+	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/querylog"
 	"github.com/abh/geodns/querylog"
+	"github.com/abh/geodns/zones"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 )
 )
@@ -13,6 +21,12 @@ type Server struct {
 	queryLogger querylog.QueryLogger
 	queryLogger querylog.QueryLogger
 }
 }
 
 
+// track when each zone was read last
+type zoneReadRecord struct {
+	time time.Time
+	hash string
+}
+
 func NewServer() *Server {
 func NewServer() *Server {
 	return &Server{}
 	return &Server{}
 }
 }
@@ -23,9 +37,9 @@ func (srv *Server) SetQueryLogger(logger querylog.QueryLogger) {
 	srv.queryLogger = logger
 	srv.queryLogger = logger
 }
 }
 
 
-func (srv *Server) setupServerFunc(Zone *Zone) func(dns.ResponseWriter, *dns.Msg) {
+func (srv *Server) setupServerFunc(zone *zones.Zone) func(dns.ResponseWriter, *dns.Msg) {
 	return func(w dns.ResponseWriter, r *dns.Msg) {
 	return func(w dns.ResponseWriter, r *dns.Msg) {
-		srv.serve(w, r, Zone)
+		srv.serve(w, r, zone)
 	}
 	}
 }
 }
 
 
@@ -46,7 +60,7 @@ func (srv *Server) listenAndServe(ip string) {
 	}
 	}
 }
 }
 
 
-func (srv *Server) addHandler(zones Zones, name string, config *Zone) {
+func (srv *Server) addHandler(zones zones.Zones, name string, config *zones.Zone) {
 	oldZone := zones[name]
 	oldZone := zones[name]
 	// across the recconfiguration keep a reference to all healthchecks to ensure
 	// across the recconfiguration keep a reference to all healthchecks to ensure
 	// the global map doesn't get destroyed
 	// the global map doesn't get destroyed
@@ -61,9 +75,138 @@ func (srv *Server) addHandler(zones Zones, name string, config *Zone) {
 	dns.HandleFunc(name, srv.setupServerFunc(config))
 	dns.HandleFunc(name, srv.setupServerFunc(config))
 }
 }
 
 
-func (srv *Server) zonesReader(dirName string, zones Zones) {
+func (srv *Server) setupPgeodnsZone(zonelist zones.Zones) {
+	zoneName := "pgeodns"
+	zone := zones.NewZone(zoneName)
+	label := new(zones.Label)
+	label.Records = make(map[uint16]zones.Records)
+	label.Weight = make(map[uint16]int)
+	zone.Labels[""] = label
+	zone.AddSOA()
+	srv.addHandler(zonelist, zoneName, zone)
+}
+
+func (srv *Server) setupRootZone() {
+	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
+		m := new(dns.Msg)
+		m.SetRcode(r, dns.RcodeRefused)
+		w.WriteMsg(m)
+	})
+}
+
+var lastRead = map[string]*zoneReadRecord{}
+
+func (srv *Server) zonesReader(dirName string, zones zones.Zones) {
 	for {
 	for {
-		srv.zonesReadDir(dirName, zones)
+		err := srv.zonesReadDir(dirName, zones)
+		if err != nil {
+			log.Printf("error reading zones: %s", err)
+		}
 		time.Sleep(5 * time.Second)
 		time.Sleep(5 * time.Second)
 	}
 	}
 }
 }
+
+func (srv *Server) zonesReadDir(dirName string, zonelist zones.Zones) error {
+	dir, err := ioutil.ReadDir(dirName)
+	if err != nil {
+		return fmt.Errorf("could not read", dirName, ":", err)
+	}
+
+	seenZones := map[string]bool{}
+
+	var parseErr error
+
+	for _, file := range dir {
+		fileName := file.Name()
+		if !strings.HasSuffix(strings.ToLower(fileName), ".json") ||
+			strings.HasPrefix(path.Base(fileName), ".") ||
+			file.IsDir() {
+			continue
+		}
+
+		zoneName := zoneNameFromFile(fileName)
+
+		seenZones[zoneName] = true
+
+		if _, ok := lastRead[zoneName]; !ok || file.ModTime().After(lastRead[zoneName].time) {
+			modTime := file.ModTime()
+			if ok {
+				applog.Printf("Reloading %s\n", fileName)
+				lastRead[zoneName].time = modTime
+			} else {
+				applog.Printf("Reading new file %s\n", fileName)
+				lastRead[zoneName] = &zoneReadRecord{time: modTime}
+			}
+
+			filename := path.Join(dirName, fileName)
+
+			// Check the sha256 of the file has not changed. It's worth an explanation of
+			// why there isn't a TOCTOU race here. Conceivably after checking whether the
+			// SHA has changed, the contents then change again before we actually load
+			// the JSON. This can occur in two situations:
+			//
+			// 1. The SHA has not changed when we read the file for the SHA, but then
+			//    changes before we process the JSON
+			//
+			// 2. The SHA has changed when we read the file for the SHA, but then changes
+			//    again before we process the JSON
+			//
+			// In circumstance (1) we won't reread the file the first time, but the subsequent
+			// change should alter the mtime again, causing us to reread it. This reflects
+			// the fact there were actually two changes.
+			//
+			// In circumstance (2) we have already reread the file once, and then when the
+			// contents are changed the mtime changes again
+			//
+			// Provided files are replaced atomically, this should be OK. If files are not
+			// replaced atomically we have other problems (e.g. partial reads).
+
+			sha256 := sha256File(filename)
+			if lastRead[zoneName].hash == sha256 {
+				applog.Printf("Skipping new file %s as hash is unchanged\n", filename)
+				continue
+			}
+
+			zone, err := zones.ReadZoneFile(zoneName, filename)
+			if zone == nil || err != nil {
+				parseErr = fmt.Errorf("Error reading zone '%s': %s", zoneName, err)
+				log.Println(parseErr.Error())
+				continue
+			}
+
+			(lastRead[zoneName]).hash = sha256
+
+			srv.addHandler(zonelist, zoneName, zone)
+		}
+	}
+
+	for zoneName, zone := range zonelist {
+		if zoneName == "pgeodns" {
+			continue
+		}
+		if ok, _ := seenZones[zoneName]; ok {
+			continue
+		}
+		log.Println("Removing zone", zone.Origin)
+		delete(lastRead, zoneName)
+		zone.Close()
+		dns.HandleRemove(zoneName)
+		delete(zonelist, zoneName)
+	}
+
+	return parseErr
+}
+
+func zoneNameFromFile(fileName string) string {
+	return fileName[0:strings.LastIndex(fileName, ".")]
+}
+
+func sha256File(fn string) string {
+	if data, err := ioutil.ReadFile(fn); err != nil {
+		return ""
+	} else {
+		hasher := sha256.New()
+		hasher.Write(data)
+		return hex.EncodeToString(hasher.Sum(nil))
+	}
+}

+ 0 - 86
stathat.go

@@ -1,86 +0,0 @@
-package main
-
-import (
-	"log"
-	"runtime"
-	"strings"
-	"time"
-
-	"github.com/rcrowley/go-metrics"
-	"github.com/stathat/go"
-)
-
-func (zs *Zones) statHatPoster() {
-
-	if !Config.HasStatHat() {
-		return
-	}
-
-	stathatGroups := append(serverGroups, "total", serverID)
-	suffix := strings.Join(stathatGroups, ",")
-
-	lastCounts := map[string]int64{}
-	lastEdnsCounts := map[string]int64{}
-
-	for name, zone := range *zs {
-		lastCounts[name] = zone.Metrics.Queries.Count()
-		lastEdnsCounts[name] = zone.Metrics.EdnsQueries.Count()
-	}
-
-	for {
-		time.Sleep(60 * time.Second)
-
-		for name, zone := range *zs {
-
-			count := zone.Metrics.Queries.Count()
-			newCount := count - lastCounts[name]
-			lastCounts[name] = count
-
-			if zone.Logging != nil && zone.Logging.StatHat == true {
-
-				apiKey := zone.Logging.StatHatAPI
-				if len(apiKey) == 0 {
-					apiKey = Config.StatHatApiKey()
-				}
-				if len(apiKey) == 0 {
-					continue
-				}
-				stathat.PostEZCount("zone "+name+" queries~"+suffix, Config.StatHatApiKey(), int(newCount))
-
-				ednsCount := zone.Metrics.EdnsQueries.Count()
-				newEdnsCount := ednsCount - lastEdnsCounts[name]
-				lastEdnsCounts[name] = ednsCount
-				stathat.PostEZCount("zone "+name+" edns queries~"+suffix, Config.StatHatApiKey(), int(newEdnsCount))
-
-			}
-		}
-	}
-}
-
-func statHatPoster() {
-
-	qCounter := metrics.Get("queries").(metrics.Meter)
-	lastQueryCount := qCounter.Count()
-	stathatGroups := append(serverGroups, "total", serverID)
-	suffix := strings.Join(stathatGroups, ",")
-	// stathat.Verbose = true
-
-	for {
-		time.Sleep(60 * time.Second)
-
-		if !Config.HasStatHat() {
-			log.Println("No stathat configuration")
-			continue
-		}
-
-		log.Println("Posting to stathat")
-
-		current := qCounter.Count()
-		newQueries := current - lastQueryCount
-		lastQueryCount = current
-
-		stathat.PostEZCount("queries~"+suffix, Config.StatHatApiKey(), int(newQueries))
-		stathat.PostEZValue("goroutines "+serverID, Config.StatHatApiKey(), float64(runtime.NumGoroutine()))
-
-	}
-}

+ 30 - 29
geoip.go → targeting/geoip.go

@@ -1,17 +1,19 @@
-package main
+package targeting
 
 
 import (
 import (
-	"github.com/abh/geodns/countries"
-	"github.com/abh/geoip"
-	"github.com/golang/geo/s2"
 	"log"
 	"log"
 	"math"
 	"math"
 	"net"
 	"net"
 	"strings"
 	"strings"
 	"time"
 	"time"
+
+	"github.com/abh/geodns/countries"
+
+	"github.com/abh/geoip"
+	"github.com/golang/geo/s2"
 )
 )
 
 
-type GeoIP struct {
+type GeoIPData struct {
 	country         *geoip.GeoIP
 	country         *geoip.GeoIP
 	hasCountry      bool
 	hasCountry      bool
 	countryLastLoad time.Time
 	countryLastLoad time.Time
@@ -28,8 +30,8 @@ type GeoIP struct {
 const MAX_DISTANCE = 360
 const MAX_DISTANCE = 360
 
 
 type Location struct {
 type Location struct {
-	latitude  float64
-	longitude float64
+	Latitude  float64
+	Longitude float64
 }
 }
 
 
 func (l *Location) MaxDistance() float64 {
 func (l *Location) MaxDistance() float64 {
@@ -40,20 +42,25 @@ func (l *Location) Distance(to *Location) float64 {
 	if to == nil {
 	if to == nil {
 		return MAX_DISTANCE
 		return MAX_DISTANCE
 	}
 	}
-	ll1 := s2.LatLngFromDegrees(l.latitude, l.longitude)
-	ll2 := s2.LatLngFromDegrees(to.latitude, to.longitude)
+	ll1 := s2.LatLngFromDegrees(l.Latitude, l.Longitude)
+	ll2 := s2.LatLngFromDegrees(to.Latitude, to.Longitude)
 	angle := ll1.Distance(ll2)
 	angle := ll1.Distance(ll2)
 	return math.Abs(angle.Degrees())
 	return math.Abs(angle.Degrees())
 }
 }
 
 
-var geoIP = new(GeoIP)
+var geoIP = &GeoIPData{}
+
+func GeoIP() *GeoIPData {
+	// mutex this and allow it to reload as needed?
+	return geoIP
+}
 
 
-func (g *GeoIP) GetCountry(ip net.IP) (country, continent string, netmask int) {
+func (g *GeoIPData) GetCountry(ip net.IP) (country, continent string, netmask int) {
 	if g.country == nil {
 	if g.country == nil {
 		return "", "", 0
 		return "", "", 0
 	}
 	}
 
 
-	country, netmask = geoIP.country.GetCountry(ip.String())
+	country, netmask = g.country.GetCountry(ip.String())
 	if len(country) > 0 {
 	if len(country) > 0 {
 		country = strings.ToLower(country)
 		country = strings.ToLower(country)
 		continent = countries.CountryContinent[country]
 		continent = countries.CountryContinent[country]
@@ -61,14 +68,14 @@ func (g *GeoIP) GetCountry(ip net.IP) (country, continent string, netmask int) {
 	return
 	return
 }
 }
 
 
-func (g *GeoIP) GetCountryRegion(ip net.IP) (country, continent, regionGroup, region string, netmask int, location *Location) {
+func (g *GeoIPData) GetCountryRegion(ip net.IP) (country, continent, regionGroup, region string, netmask int, location *Location) {
 	if g.city == nil {
 	if g.city == nil {
 		log.Println("No city database available")
 		log.Println("No city database available")
 		country, continent, netmask = g.GetCountry(ip)
 		country, continent, netmask = g.GetCountry(ip)
 		return
 		return
 	}
 	}
 
 
-	record := geoIP.city.GetRecord(ip.String())
+	record := g.city.GetRecord(ip.String())
 	if record == nil {
 	if record == nil {
 		return
 		return
 	}
 	}
@@ -90,7 +97,7 @@ func (g *GeoIP) GetCountryRegion(ip net.IP) (country, continent, regionGroup, re
 	return
 	return
 }
 }
 
 
-func (g *GeoIP) GetASN(ip net.IP) (asn string, netmask int) {
+func (g *GeoIPData) GetASN(ip net.IP) (asn string, netmask int) {
 	if g.asn == nil {
 	if g.asn == nil {
 		log.Println("No asn database available")
 		log.Println("No asn database available")
 		return
 		return
@@ -105,23 +112,21 @@ func (g *GeoIP) GetASN(ip net.IP) (asn string, netmask int) {
 	return
 	return
 }
 }
 
 
-func (g *GeoIP) setDirectory() {
-	directory := Config.GeoIPDirectory()
+func (g *GeoIPData) SetDirectory(directory string) {
+	// directory := Config.GeoIPDataDirectory()
 	if len(directory) > 0 {
 	if len(directory) > 0 {
 		geoip.SetCustomDirectory(directory)
 		geoip.SetCustomDirectory(directory)
 	}
 	}
 }
 }
 
 
-func (g *GeoIP) setupGeoIPCountry() {
+func (g *GeoIPData) SetupGeoIPCountry() {
 	if g.country != nil {
 	if g.country != nil {
 		return
 		return
 	}
 	}
 
 
-	g.setDirectory()
-
 	gi, err := geoip.OpenType(geoip.GEOIP_COUNTRY_EDITION)
 	gi, err := geoip.OpenType(geoip.GEOIP_COUNTRY_EDITION)
 	if gi == nil || err != nil {
 	if gi == nil || err != nil {
-		log.Printf("Could not open country GeoIP database: %s\n", err)
+		log.Printf("Could not open country GeoIPData database: %s\n", err)
 		return
 		return
 	}
 	}
 	g.countryLastLoad = time.Now()
 	g.countryLastLoad = time.Now()
@@ -130,16 +135,14 @@ func (g *GeoIP) setupGeoIPCountry() {
 
 
 }
 }
 
 
-func (g *GeoIP) setupGeoIPCity() {
+func (g *GeoIPData) SetupGeoIPCity() {
 	if g.city != nil {
 	if g.city != nil {
 		return
 		return
 	}
 	}
 
 
-	g.setDirectory()
-
 	gi, err := geoip.OpenType(geoip.GEOIP_CITY_EDITION_REV1)
 	gi, err := geoip.OpenType(geoip.GEOIP_CITY_EDITION_REV1)
 	if gi == nil || err != nil {
 	if gi == nil || err != nil {
-		log.Printf("Could not open city GeoIP database: %s\n", err)
+		log.Printf("Could not open city GeoIPData database: %s\n", err)
 		return
 		return
 	}
 	}
 	g.cityLastLoad = time.Now()
 	g.cityLastLoad = time.Now()
@@ -148,16 +151,14 @@ func (g *GeoIP) setupGeoIPCity() {
 
 
 }
 }
 
 
-func (g *GeoIP) setupGeoIPASN() {
+func (g *GeoIPData) SetupGeoIPASN() {
 	if g.asn != nil {
 	if g.asn != nil {
 		return
 		return
 	}
 	}
 
 
-	g.setDirectory()
-
 	gi, err := geoip.OpenType(geoip.GEOIP_ASNUM_EDITION)
 	gi, err := geoip.OpenType(geoip.GEOIP_ASNUM_EDITION)
 	if gi == nil || err != nil {
 	if gi == nil || err != nil {
-		log.Printf("Could not open ASN GeoIP database: %s\n", err)
+		log.Printf("Could not open ASN GeoIPData database: %s\n", err)
 		return
 		return
 	}
 	}
 	g.asnLastLoad = time.Now()
 	g.asnLastLoad = time.Now()

+ 6 - 4
targeting.go → targeting/targeting.go

@@ -1,4 +1,4 @@
-package main
+package targeting
 
 
 import (
 import (
 	"fmt"
 	"fmt"
@@ -32,13 +32,15 @@ func (t TargetOptions) GetTargets(ip net.IP, hasClosest bool) ([]string, int, *L
 	var netmask int
 	var netmask int
 	var location *Location
 	var location *Location
 
 
+	g := GeoIP()
+
 	if t&TargetASN > 0 {
 	if t&TargetASN > 0 {
-		asn, netmask = geoIP.GetASN(ip)
+		asn, netmask = g.GetASN(ip)
 	}
 	}
 	if t&TargetRegion > 0 || t&TargetRegionGroup > 0 || hasClosest {
 	if t&TargetRegion > 0 || t&TargetRegionGroup > 0 || hasClosest {
 		country, continent, regionGroup, region, netmask, location = geoIP.GetCountryRegion(ip)
 		country, continent, regionGroup, region, netmask, location = geoIP.GetCountryRegion(ip)
 	} else if t&TargetCountry > 0 || t&TargetContinent > 0 {
 	} else if t&TargetCountry > 0 || t&TargetContinent > 0 {
-		country, continent, netmask = geoIP.GetCountry(ip)
+		country, continent, netmask = g.GetCountry(ip)
 	}
 	}
 
 
 	if t&TargetIP > 0 {
 	if t&TargetIP > 0 {
@@ -108,7 +110,7 @@ func (t TargetOptions) String() string {
 	return strings.Join(targets, " ")
 	return strings.Join(targets, " ")
 }
 }
 
 
-func parseTargets(v string) (tgt TargetOptions, err error) {
+func ParseTargets(v string) (tgt TargetOptions, err error) {
 	targets := strings.Split(v, " ")
 	targets := strings.Split(v, " ")
 	for _, t := range targets {
 	for _, t := range targets {
 		var x TargetOptions
 		var x TargetOptions

+ 109 - 0
targeting/targeting_test.go

@@ -0,0 +1,109 @@
+package targeting
+
+import (
+	"net"
+	"reflect"
+	"testing"
+)
+
+func TestTargetString(t *testing.T) {
+	tgt := TargetOptions(TargetGlobal + TargetCountry + TargetContinent)
+
+	str := tgt.String()
+	if str != "@ continent country" {
+		t.Logf("wrong target string '%s'", str)
+		t.Fail()
+	}
+}
+
+func TestTargetParse(t *testing.T) {
+	tgt, err := ParseTargets("@ foo country")
+	str := tgt.String()
+	if str != "@ country" {
+		t.Logf("Expected '@ country', got '%s'", str)
+		t.Fail()
+	}
+	if err.Error() != "Unknown targeting option 'foo'" {
+		t.Log("Failed erroring on an unknown targeting option")
+		t.Fail()
+	}
+
+	tests := [][]string{
+		[]string{"@ continent country asn", "@ continent country asn"},
+		[]string{"asn country", "country asn"},
+		[]string{"continent @ country", "@ continent country"},
+	}
+
+	for _, strs := range tests {
+		tgt, err = ParseTargets(strs[0])
+		if err != nil {
+			t.Fatalf("Parsing '%s': %s", strs[0], err)
+		}
+		if tgt.String() != strs[1] {
+			t.Logf("Unexpected result parsing '%s', got '%s', expected '%s'",
+				strs[0], tgt.String(), strs[1])
+			t.Fail()
+		}
+	}
+}
+
+func TestGetTargets(t *testing.T) {
+	ip := net.ParseIP("207.171.1.1")
+
+	GeoIP().SetupGeoIPCity()
+	GeoIP().SetupGeoIPCountry()
+	GeoIP().SetupGeoIPASN()
+
+	tgt, _ := ParseTargets("@ continent country")
+	targets, _, _ := tgt.GetTargets(ip, false)
+	if !reflect.DeepEqual(targets, []string{"us", "north-america", "@"}) {
+		t.Fatalf("Unexpected parse results of targets")
+	}
+
+	if geoIP.city == nil {
+		t.Log("City GeoIP database requred for these tests")
+		return
+	}
+
+	tests := []struct {
+		Str     string
+		Targets []string
+		IP      string
+	}{
+		{
+			"@ continent country region ",
+			[]string{"us-ca", "us", "north-america", "@"},
+			"",
+		},
+		{
+			"@ continent regiongroup country region ",
+			[]string{"us-ca", "us-west", "us", "north-america", "@"},
+			"",
+		},
+		{
+			"@ continent regiongroup country region asn ip",
+			[]string{"[207.171.1.1]", "[207.171.1.0]", "as7012", "us-ca", "us-west", "us", "north-america", "@"},
+			"",
+		},
+		{
+			"ip",
+			[]string{"[2607:f238:2::ff:4]", "[2607:f238:2::]"},
+			"2607:f238:2:0::ff:4",
+		},
+	}
+
+	for _, test := range tests {
+		if len(test.IP) > 0 {
+			ip = net.ParseIP(test.IP)
+		}
+
+		tgt, _ = ParseTargets(test.Str)
+		targets, _, _ = tgt.GetTargets(ip, false)
+
+		if !reflect.DeepEqual(targets, test.Targets) {
+			t.Logf("For targets '%s' expected '%s', got '%s'", test.Str, test.Targets, targets)
+			t.Fail()
+		}
+
+	}
+}

+ 0 - 69
targeting_test.go

@@ -1,69 +0,0 @@
-package main
-
-import (
-	. "gopkg.in/check.v1"
-	"net"
-)
-
-type TargetingSuite struct {
-}
-
-var _ = Suite(&TargetingSuite{})
-
-func (s *TargetingSuite) SetUpSuite(c *C) {
-	Config.GeoIP.Directory = "db"
-}
-
-func (s *TargetingSuite) TestTargetString(c *C) {
-	tgt := TargetOptions(TargetGlobal + TargetCountry + TargetContinent)
-
-	str := tgt.String()
-	c.Check(str, Equals, "@ continent country")
-}
-
-func (s *TargetingSuite) TestTargetParse(c *C) {
-	tgt, err := parseTargets("@ foo country")
-	str := tgt.String()
-	c.Check(str, Equals, "@ country")
-	c.Check(err.Error(), Equals, "Unknown targeting option 'foo'")
-
-	tgt, err = parseTargets("@ continent country asn")
-	c.Assert(err, IsNil)
-	str = tgt.String()
-	c.Check(str, Equals, "@ continent country asn")
-}
-
-func (s *TargetingSuite) TestGetTargets(c *C) {
-	ip := net.ParseIP("207.171.1.1")
-
-	geoIP.setupGeoIPCity()
-	geoIP.setupGeoIPCountry()
-	geoIP.setupGeoIPASN()
-
-	tgt, _ := parseTargets("@ continent country")
-	targets, _, _ := tgt.GetTargets(ip, false)
-	c.Check(targets, DeepEquals, []string{"us", "north-america", "@"})
-
-	if geoIP.city == nil {
-		c.Log("City GeoIP database requred for these tests")
-		return
-	}
-
-	tgt, _ = parseTargets("@ continent country region ")
-	targets, _, _ = tgt.GetTargets(ip, false)
-	c.Check(targets, DeepEquals, []string{"us-ca", "us", "north-america", "@"})
-
-	tgt, _ = parseTargets("@ continent regiongroup country region ")
-	targets, _, _ = tgt.GetTargets(ip, false)
-	c.Check(targets, DeepEquals, []string{"us-ca", "us-west", "us", "north-america", "@"})
-
-	tgt, _ = parseTargets("@ continent regiongroup country region asn ip")
-	targets, _, _ = tgt.GetTargets(ip, false)
-	c.Check(targets, DeepEquals, []string{"[207.171.1.1]", "[207.171.1.0]", "as7012", "us-ca", "us-west", "us", "north-america", "@"})
-
-	ip = net.ParseIP("2607:f238:2:0::ff:4")
-	tgt, _ = parseTargets("ip")
-	targets, _, _ = tgt.GetTargets(ip, false)
-	c.Check(targets, DeepEquals, []string{"[2607:f238:2::ff:4]", "[2607:f238:2::]"})
-
-}

+ 0 - 22
vendor/github.com/stathat/go/.gitignore

@@ -1,22 +0,0 @@
-# Compiled Object files, Static and Dynamic libs (Shared Objects)
-*.o
-*.a
-*.so
-
-# Folders
-_obj
-_test
-
-# Architecture specific extensions/prefixes
-*.[568vq]
-[568vq].out
-
-*.cgo1.go
-*.cgo2.c
-_cgo_defun.c
-_cgo_gotypes.go
-_cgo_export.*
-
-_testmain.go
-
-*.exe

+ 0 - 19
vendor/github.com/stathat/go/LICENSE

@@ -1,19 +0,0 @@
-Copyright (C) 2012 Numerotron Inc.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.

+ 0 - 67
vendor/github.com/stathat/go/README.md

@@ -1,67 +0,0 @@
-stathat
-=======
-
-This is a Go package for posting stats to your StatHat account.
-
-For more information about StatHat, visit [www.stathat.com](http://www.stathat.com).
-
-Installation
-------------
-
-Use `go get`:
-
-    go get github.com/stathat/go
-
-That's it.
-
-Import it like this:
-
-    import (
-            "github.com/stathat/go"
-    )
-
-Usage
------
-
-The easiest way to use the package is with the EZ API functions.  You can add stats
-directly in your code by just adding a call with a new stat name.  Once StatHat
-receives the call, a new stat will be created for you.
-
-To post a count of 1 to a stat:
-
-    stathat.PostEZCountOne("messages sent - female to male", "[email protected]")
-
-To specify the count:
-
-    stathat.PostEZCount("messages sent - male to male", "[email protected]", 37)
-
-To post a value:
-
-    stathat.PostEZValue("ws0 load average", "[email protected]", 0.372)
-
-There are also functions for the classic API.  The drawback to the classic API is
-that you need to create the stats using the web interface and copy the keys it
-gives you into your code.
-
-To post a count of 1 to a stat using the classic API:
-
-    stathat.PostCountOne("statkey", "userkey")
-
-To specify the count:
-
-    stathat.PostCount("statkey", "userkey", 37)
-
-To post a value:
-
-    stathat.PostValue("statkey", "userkey", 0.372)
-
-Contact us
-----------
-
-We'd love to hear from you if you are using this in your projects!  Please drop us a
-line: [@stat_hat](http://twitter.com/stat_hat) or [contact us here](http://www.stathat.com/docs/contact).
-
-About
------
-
-Written by Patrick Crosby at [StatHat](http://www.stathat.com).  Twitter:  [@stat_hat](http://twitter.com/stat_hat)

+ 0 - 27
vendor/github.com/stathat/go/example_test.go

@@ -1,27 +0,0 @@
-// Copyright (C) 2012 Numerotron Inc.
-// Use of this source code is governed by an MIT-style license
-// that can be found in the LICENSE file.
-
-package stathat_test
-
-import (
-	"fmt"
-	"github.com/stathat/go"
-	"log"
-	"time"
-)
-
-func ExamplePostEZCountOne() {
-	log.Printf("starting example")
-	stathat.Verbose = true
-	err := stathat.PostEZCountOne("go example test run", "[email protected]")
-	if err != nil {
-		log.Printf("error posting ez count one: %v", err)
-		return
-	}
-	ok := stathat.WaitUntilFinished(5 * time.Second)
-	if ok {
-		fmt.Println("ok")
-	}
-	// Output: ok
-}

+ 0 - 420
vendor/github.com/stathat/go/stathat.go

@@ -1,420 +0,0 @@
-// Copyright (C) 2012 Numerotron Inc.
-// Use of this source code is governed by an MIT-style license
-// that can be found in the LICENSE file.
-
-// Copyright 2012 Numerotron Inc.
-// Use of this source code is governed by an MIT-style license
-// that can be found in the LICENSE file.
-//
-// Developed at www.stathat.com by Patrick Crosby
-// Contact us on twitter with any questions:  twitter.com/stat_hat
-
-// The stathat package makes it easy to post any values to your StatHat
-// account.
-package stathat
-
-import (
-	"fmt"
-	"io/ioutil"
-	"log"
-	"net/http"
-	"net/url"
-	"strconv"
-	"sync"
-	"time"
-)
-
-const hostname = "api.stathat.com"
-
-type statKind int
-
-const (
-	_                 = iota
-	kcounter statKind = iota
-	kvalue
-)
-
-func (sk statKind) classicPath() string {
-	switch sk {
-	case kcounter:
-		return "/c"
-	case kvalue:
-		return "/v"
-	}
-	return ""
-}
-
-type apiKind int
-
-const (
-	_               = iota
-	classic apiKind = iota
-	ez
-)
-
-func (ak apiKind) path(sk statKind) string {
-	switch ak {
-	case ez:
-		return "/ez"
-	case classic:
-		return sk.classicPath()
-	}
-	return ""
-}
-
-type statReport struct {
-	StatKey   string
-	UserKey   string
-	Value     float64
-	Timestamp int64
-	statType  statKind
-	apiType   apiKind
-}
-
-// Reporter is a StatHat client that can report stat values/counts to the servers.
-type Reporter struct {
-	reports chan *statReport
-	done    chan bool
-	client  *http.Client
-	wg      *sync.WaitGroup
-}
-
-// NewReporter returns a new Reporter.  You must specify the channel bufferSize and the
-// goroutine poolSize.  You can pass in nil for the transport and it will use the
-// default http transport.
-func NewReporter(bufferSize, poolSize int, transport http.RoundTripper) *Reporter {
-	r := new(Reporter)
-	r.client = &http.Client{Transport: transport}
-	r.reports = make(chan *statReport, bufferSize)
-	r.done = make(chan bool)
-	r.wg = new(sync.WaitGroup)
-	for i := 0; i < poolSize; i++ {
-		r.wg.Add(1)
-		go r.processReports()
-	}
-	return r
-}
-
-// DefaultReporter is the default instance of *Reporter.
-var DefaultReporter = NewReporter(100000, 10, nil)
-
-var testingEnv = false
-
-type testPost struct {
-	url    string
-	values url.Values
-}
-
-var testPostChannel chan *testPost
-
-// The Verbose flag determines if the package should write verbose output to stdout.
-var Verbose = false
-
-func setTesting() {
-	testingEnv = true
-	testPostChannel = make(chan *testPost)
-}
-
-func newEZStatCount(statName, ezkey string, count int) *statReport {
-	return &statReport{StatKey: statName,
-		UserKey:  ezkey,
-		Value:    float64(count),
-		statType: kcounter,
-		apiType:  ez}
-}
-
-func newEZStatValue(statName, ezkey string, value float64) *statReport {
-	return &statReport{StatKey: statName,
-		UserKey:  ezkey,
-		Value:    value,
-		statType: kvalue,
-		apiType:  ez}
-}
-
-func newClassicStatCount(statKey, userKey string, count int) *statReport {
-	return &statReport{StatKey: statKey,
-		UserKey:  userKey,
-		Value:    float64(count),
-		statType: kcounter,
-		apiType:  classic}
-}
-
-func newClassicStatValue(statKey, userKey string, value float64) *statReport {
-	return &statReport{StatKey: statKey,
-		UserKey:  userKey,
-		Value:    value,
-		statType: kvalue,
-		apiType:  classic}
-}
-
-func (sr *statReport) values() url.Values {
-	switch sr.apiType {
-	case ez:
-		return sr.ezValues()
-	case classic:
-		return sr.classicValues()
-	}
-
-	return nil
-}
-
-func (sr *statReport) ezValues() url.Values {
-	switch sr.statType {
-	case kcounter:
-		return sr.ezCounterValues()
-	case kvalue:
-		return sr.ezValueValues()
-	}
-	return nil
-}
-
-func (sr *statReport) classicValues() url.Values {
-	switch sr.statType {
-	case kcounter:
-		return sr.classicCounterValues()
-	case kvalue:
-		return sr.classicValueValues()
-	}
-	return nil
-}
-
-func (sr *statReport) ezCommonValues() url.Values {
-	result := make(url.Values)
-	result.Set("stat", sr.StatKey)
-	result.Set("ezkey", sr.UserKey)
-	if sr.Timestamp > 0 {
-		result.Set("t", sr.timeString())
-	}
-	return result
-}
-
-func (sr *statReport) classicCommonValues() url.Values {
-	result := make(url.Values)
-	result.Set("key", sr.StatKey)
-	result.Set("ukey", sr.UserKey)
-	if sr.Timestamp > 0 {
-		result.Set("t", sr.timeString())
-	}
-	return result
-}
-
-func (sr *statReport) ezCounterValues() url.Values {
-	result := sr.ezCommonValues()
-	result.Set("count", sr.valueString())
-	return result
-}
-
-func (sr *statReport) ezValueValues() url.Values {
-	result := sr.ezCommonValues()
-	result.Set("value", sr.valueString())
-	return result
-}
-
-func (sr *statReport) classicCounterValues() url.Values {
-	result := sr.classicCommonValues()
-	result.Set("count", sr.valueString())
-	return result
-}
-
-func (sr *statReport) classicValueValues() url.Values {
-	result := sr.classicCommonValues()
-	result.Set("value", sr.valueString())
-	return result
-}
-
-func (sr *statReport) valueString() string {
-	return strconv.FormatFloat(sr.Value, 'g', -1, 64)
-}
-
-func (sr *statReport) timeString() string {
-	return strconv.FormatInt(sr.Timestamp, 10)
-}
-
-func (sr *statReport) path() string {
-	return sr.apiType.path(sr.statType)
-}
-
-func (sr *statReport) url() string {
-	return fmt.Sprintf("http://%s%s", hostname, sr.path())
-}
-
-// Using the classic API, posts a count to a stat using DefaultReporter.
-func PostCount(statKey, userKey string, count int) error {
-	return DefaultReporter.PostCount(statKey, userKey, count)
-}
-
-// Using the classic API, posts a count to a stat using DefaultReporter at a specific
-// time.
-func PostCountTime(statKey, userKey string, count int, timestamp int64) error {
-	return DefaultReporter.PostCountTime(statKey, userKey, count, timestamp)
-}
-
-// Using the classic API, posts a count of 1 to a stat using DefaultReporter.
-func PostCountOne(statKey, userKey string) error {
-	return DefaultReporter.PostCountOne(statKey, userKey)
-}
-
-// Using the classic API, posts a value to a stat using DefaultReporter.
-func PostValue(statKey, userKey string, value float64) error {
-	return DefaultReporter.PostValue(statKey, userKey, value)
-}
-
-// Using the classic API, posts a value to a stat at a specific time using DefaultReporter.
-func PostValueTime(statKey, userKey string, value float64, timestamp int64) error {
-	return DefaultReporter.PostValueTime(statKey, userKey, value, timestamp)
-}
-
-// Using the EZ API, posts a count of 1 to a stat using DefaultReporter.
-func PostEZCountOne(statName, ezkey string) error {
-	return DefaultReporter.PostEZCountOne(statName, ezkey)
-}
-
-// Using the EZ API, posts a count to a stat using DefaultReporter.
-func PostEZCount(statName, ezkey string, count int) error {
-	return DefaultReporter.PostEZCount(statName, ezkey, count)
-}
-
-// Using the EZ API, posts a count to a stat at a specific time using DefaultReporter.
-func PostEZCountTime(statName, ezkey string, count int, timestamp int64) error {
-	return DefaultReporter.PostEZCountTime(statName, ezkey, count, timestamp)
-}
-
-// Using the EZ API, posts a value to a stat using DefaultReporter.
-func PostEZValue(statName, ezkey string, value float64) error {
-	return DefaultReporter.PostEZValue(statName, ezkey, value)
-}
-
-// Using the EZ API, posts a value to a stat at a specific time using DefaultReporter.
-func PostEZValueTime(statName, ezkey string, value float64, timestamp int64) error {
-	return DefaultReporter.PostEZValueTime(statName, ezkey, value, timestamp)
-}
-
-// Wait for all stats to be sent, or until timeout. Useful for simple command-
-// line apps to defer a call to this in main()
-func WaitUntilFinished(timeout time.Duration) bool {
-	return DefaultReporter.WaitUntilFinished(timeout)
-}
-
-// Using the classic API, posts a count to a stat.
-func (r *Reporter) PostCount(statKey, userKey string, count int) error {
-	r.reports <- newClassicStatCount(statKey, userKey, count)
-	return nil
-}
-
-// Using the classic API, posts a count to a stat at a specific time.
-func (r *Reporter) PostCountTime(statKey, userKey string, count int, timestamp int64) error {
-	x := newClassicStatCount(statKey, userKey, count)
-	x.Timestamp = timestamp
-	r.reports <- x
-	return nil
-}
-
-// Using the classic API, posts a count of 1 to a stat.
-func (r *Reporter) PostCountOne(statKey, userKey string) error {
-	return r.PostCount(statKey, userKey, 1)
-}
-
-// Using the classic API, posts a value to a stat.
-func (r *Reporter) PostValue(statKey, userKey string, value float64) error {
-	r.reports <- newClassicStatValue(statKey, userKey, value)
-	return nil
-}
-
-// Using the classic API, posts a value to a stat at a specific time.
-func (r *Reporter) PostValueTime(statKey, userKey string, value float64, timestamp int64) error {
-	x := newClassicStatValue(statKey, userKey, value)
-	x.Timestamp = timestamp
-	r.reports <- x
-	return nil
-}
-
-// Using the EZ API, posts a count of 1 to a stat.
-func (r *Reporter) PostEZCountOne(statName, ezkey string) error {
-	return r.PostEZCount(statName, ezkey, 1)
-}
-
-// Using the EZ API, posts a count to a stat.
-func (r *Reporter) PostEZCount(statName, ezkey string, count int) error {
-	r.reports <- newEZStatCount(statName, ezkey, count)
-	return nil
-}
-
-// Using the EZ API, posts a count to a stat at a specific time.
-func (r *Reporter) PostEZCountTime(statName, ezkey string, count int, timestamp int64) error {
-	x := newEZStatCount(statName, ezkey, count)
-	x.Timestamp = timestamp
-	r.reports <- x
-	return nil
-}
-
-// Using the EZ API, posts a value to a stat.
-func (r *Reporter) PostEZValue(statName, ezkey string, value float64) error {
-	r.reports <- newEZStatValue(statName, ezkey, value)
-	return nil
-}
-
-// Using the EZ API, posts a value to a stat at a specific time.
-func (r *Reporter) PostEZValueTime(statName, ezkey string, value float64, timestamp int64) error {
-	x := newEZStatValue(statName, ezkey, value)
-	x.Timestamp = timestamp
-	r.reports <- x
-	return nil
-}
-
-func (r *Reporter) processReports() {
-	for {
-		sr, ok := <-r.reports
-
-		if !ok {
-			if Verbose {
-				log.Printf("channel closed, stopping processReports()")
-			}
-			break
-		}
-
-		if Verbose {
-			log.Printf("posting stat to stathat: %s, %v", sr.url(), sr.values())
-		}
-
-		if testingEnv {
-			if Verbose {
-				log.Printf("in test mode, putting stat on testPostChannel")
-			}
-			testPostChannel <- &testPost{sr.url(), sr.values()}
-			continue
-		}
-
-		resp, err := r.client.PostForm(sr.url(), sr.values())
-		if err != nil {
-			log.Printf("error posting stat to stathat: %s", err)
-			continue
-		}
-
-		if Verbose {
-			body, _ := ioutil.ReadAll(resp.Body)
-			log.Printf("stathat post result: %s", body)
-		}
-
-		resp.Body.Close()
-	}
-	r.wg.Done()
-}
-
-func (r *Reporter) finish() {
-	close(r.reports)
-	r.wg.Wait()
-	r.done <- true
-}
-
-// Wait for all stats to be sent, or until timeout. Useful for simple command-
-// line apps to defer a call to this in main()
-func (r *Reporter) WaitUntilFinished(timeout time.Duration) bool {
-	go r.finish()
-	select {
-	case <-r.done:
-		return true
-	case <-time.After(timeout):
-		return false
-	}
-	return false
-}

+ 0 - 320
vendor/github.com/stathat/go/stathat_test.go

@@ -1,320 +0,0 @@
-// Copyright (C) 2012 Numerotron Inc.
-// Use of this source code is governed by an MIT-style license
-// that can be found in the LICENSE file.
-
-package stathat
-
-import (
-	"testing"
-)
-
-func TestNewEZStatCount(t *testing.T) {
-	setTesting()
-	x := newEZStatCount("abc", "[email protected]", 1)
-	if x == nil {
-		t.Fatalf("expected a StatReport object")
-	}
-	if x.statType != kcounter {
-		t.Errorf("expected counter")
-	}
-	if x.apiType != ez {
-		t.Errorf("expected EZ api")
-	}
-	if x.StatKey != "abc" {
-		t.Errorf("expected abc")
-	}
-	if x.UserKey != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if x.Value != 1.0 {
-		t.Errorf("expected 1.0")
-	}
-	if x.Timestamp != 0 {
-		t.Errorf("expected 0")
-	}
-}
-
-func TestNewEZStatValue(t *testing.T) {
-	setTesting()
-	x := newEZStatValue("abc", "[email protected]", 3.14159)
-	if x == nil {
-		t.Fatalf("expected a StatReport object")
-	}
-	if x.statType != kvalue {
-		t.Errorf("expected value")
-	}
-	if x.apiType != ez {
-		t.Errorf("expected EZ api")
-	}
-	if x.StatKey != "abc" {
-		t.Errorf("expected abc")
-	}
-	if x.UserKey != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if x.Value != 3.14159 {
-		t.Errorf("expected 3.14159")
-	}
-}
-
-func TestNewClassicStatCount(t *testing.T) {
-	setTesting()
-	x := newClassicStatCount("statkey", "userkey", 1)
-	if x == nil {
-		t.Fatalf("expected a StatReport object")
-	}
-	if x.statType != kcounter {
-		t.Errorf("expected counter")
-	}
-	if x.apiType != classic {
-		t.Errorf("expected CLASSIC api")
-	}
-	if x.StatKey != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if x.UserKey != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if x.Value != 1.0 {
-		t.Errorf("expected 1.0")
-	}
-	if x.Timestamp != 0 {
-		t.Errorf("expected 0")
-	}
-}
-
-func TestNewClassicStatValue(t *testing.T) {
-	setTesting()
-	x := newClassicStatValue("statkey", "userkey", 2.28)
-	if x == nil {
-		t.Fatalf("expected a StatReport object")
-	}
-	if x.statType != kvalue {
-		t.Errorf("expected value")
-	}
-	if x.apiType != classic {
-		t.Errorf("expected CLASSIC api")
-	}
-	if x.StatKey != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if x.UserKey != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if x.Value != 2.28 {
-		t.Errorf("expected 2.28")
-	}
-}
-
-func TestURLValues(t *testing.T) {
-	setTesting()
-	x := newEZStatCount("abc", "[email protected]", 1)
-	v := x.values()
-	if v == nil {
-		t.Fatalf("expected url values")
-	}
-	if v.Get("stat") != "abc" {
-		t.Errorf("expected abc")
-	}
-	if v.Get("ezkey") != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if v.Get("count") != "1" {
-		t.Errorf("expected count of 1")
-	}
-
-	y := newEZStatValue("abc", "[email protected]", 3.14159)
-	v = y.values()
-	if v == nil {
-		t.Fatalf("expected url values")
-	}
-	if v.Get("stat") != "abc" {
-		t.Errorf("expected abc")
-	}
-	if v.Get("ezkey") != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if v.Get("value") != "3.14159" {
-		t.Errorf("expected value of 3.14159")
-	}
-
-	a := newClassicStatCount("statkey", "userkey", 1)
-	v = a.values()
-	if v == nil {
-		t.Fatalf("expected url values")
-	}
-	if v.Get("key") != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if v.Get("ukey") != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if v.Get("count") != "1" {
-		t.Errorf("expected count of 1")
-	}
-
-	b := newClassicStatValue("statkey", "userkey", 2.28)
-	v = b.values()
-	if v == nil {
-		t.Fatalf("expected url values")
-	}
-	if v.Get("key") != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if v.Get("ukey") != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if v.Get("value") != "2.28" {
-		t.Errorf("expected value of 2.28")
-	}
-}
-
-func TestPaths(t *testing.T) {
-	if ez.path(kcounter) != "/ez" {
-		t.Errorf("expected /ez")
-	}
-	if ez.path(kvalue) != "/ez" {
-		t.Errorf("expected /ez")
-	}
-	if classic.path(kcounter) != "/c" {
-		t.Errorf("expected /c")
-	}
-	if classic.path(kvalue) != "/v" {
-		t.Errorf("expected /v")
-	}
-
-	x := newEZStatCount("abc", "[email protected]", 1)
-	if x.path() != "/ez" {
-		t.Errorf("expected /ez")
-	}
-	y := newEZStatValue("abc", "[email protected]", 3.14159)
-	if y.path() != "/ez" {
-		t.Errorf("expected /ez")
-	}
-	a := newClassicStatCount("statkey", "userkey", 1)
-	if a.path() != "/c" {
-		t.Errorf("expected /c")
-	}
-	b := newClassicStatValue("statkey", "userkey", 2.28)
-	if b.path() != "/v" {
-		t.Errorf("expected /v")
-	}
-}
-
-func TestPosts(t *testing.T) {
-	setTesting()
-	Verbose = true
-	PostCountOne("statkey", "userkey")
-	p := <-testPostChannel
-	if p.url != "http://api.stathat.com/c" {
-		t.Errorf("expected classic count url")
-	}
-	if p.values.Get("key") != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if p.values.Get("ukey") != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if p.values.Get("count") != "1" {
-		t.Errorf("expected count of 1")
-	}
-
-	PostCount("statkey", "userkey", 13)
-	p = <-testPostChannel
-	if p.url != "http://api.stathat.com/c" {
-		t.Errorf("expected classic count url")
-	}
-	if p.values.Get("key") != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if p.values.Get("ukey") != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if p.values.Get("count") != "13" {
-		t.Errorf("expected count of 13")
-	}
-
-	PostValue("statkey", "userkey", 9.312)
-	p = <-testPostChannel
-	if p.url != "http://api.stathat.com/v" {
-		t.Errorf("expected classic value url")
-	}
-	if p.values.Get("key") != "statkey" {
-		t.Errorf("expected statkey")
-	}
-	if p.values.Get("ukey") != "userkey" {
-		t.Errorf("expected userkey")
-	}
-	if p.values.Get("value") != "9.312" {
-		t.Errorf("expected value of 9.312")
-	}
-
-	PostEZCountOne("a stat", "[email protected]")
-	p = <-testPostChannel
-	if p.url != "http://api.stathat.com/ez" {
-		t.Errorf("expected ez url")
-	}
-	if p.values.Get("stat") != "a stat" {
-		t.Errorf("expected a stat")
-	}
-	if p.values.Get("ezkey") != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if p.values.Get("count") != "1" {
-		t.Errorf("expected count of 1")
-	}
-
-	PostEZCount("a stat", "[email protected]", 213)
-	p = <-testPostChannel
-	if p.url != "http://api.stathat.com/ez" {
-		t.Errorf("expected ez url")
-	}
-	if p.values.Get("stat") != "a stat" {
-		t.Errorf("expected a stat")
-	}
-	if p.values.Get("ezkey") != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if p.values.Get("count") != "213" {
-		t.Errorf("expected count of 213")
-	}
-
-	PostEZValue("a stat", "[email protected]", 2.13)
-	p = <-testPostChannel
-	if p.url != "http://api.stathat.com/ez" {
-		t.Errorf("expected ez url")
-	}
-	if p.values.Get("stat") != "a stat" {
-		t.Errorf("expected a stat")
-	}
-	if p.values.Get("ezkey") != "[email protected]" {
-		t.Errorf("expected [email protected]")
-	}
-	if p.values.Get("value") != "2.13" {
-		t.Errorf("expected value of 2.13")
-	}
-
-	PostCountTime("statkey", "userkey", 13, 100000)
-	p = <-testPostChannel
-	if p.values.Get("t") != "100000" {
-		t.Errorf("expected t value of 100000, got %s", p.values.Get("t"))
-	}
-
-	PostValueTime("statkey", "userkey", 9.312, 200000)
-	p = <-testPostChannel
-	if p.values.Get("t") != "200000" {
-		t.Errorf("expected t value of 200000, got %s", p.values.Get("t"))
-	}
-
-	PostEZCountTime("a stat", "[email protected]", 213, 300000)
-	p = <-testPostChannel
-	if p.values.Get("t") != "300000" {
-		t.Errorf("expected t value of 300000, got %s", p.values.Get("t"))
-	}
-
-	PostEZValueTime("a stat", "[email protected]", 2.13, 400000)
-	p = <-testPostChannel
-	if p.values.Get("t") != "400000" {
-		t.Errorf("expected t value of 400000, got %s", p.values.Get("t"))
-	}
-}

+ 3 - 2
picker.go → zones/picker.go

@@ -1,14 +1,15 @@
-package main
+package zones
 
 
 import (
 import (
 	"math/rand"
 	"math/rand"
 
 
 	"github.com/abh/geodns/health"
 	"github.com/abh/geodns/health"
+	"github.com/abh/geodns/targeting"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 )
 )
 
 
-func (label *Label) Picker(qtype uint16, max int, location *Location) Records {
+func (label *Label) Picker(qtype uint16, max int, location *targeting.Location) Records {
 
 
 	if qtype == dns.TypeANY {
 	if qtype == dns.TypeANY {
 		var result []Record
 		var result []Record

+ 87 - 9
zone.go → zones/zone.go

@@ -1,11 +1,15 @@
-package main
+package zones
 
 
 import (
 import (
+	"encoding/json"
+	"log"
+	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
 	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/applog"
 	"github.com/abh/geodns/health"
 	"github.com/abh/geodns/health"
+	"github.com/abh/geodns/targeting"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
@@ -16,7 +20,7 @@ type ZoneOptions struct {
 	Ttl       int
 	Ttl       int
 	MaxHosts  int
 	MaxHosts  int
 	Contact   string
 	Contact   string
-	Targeting TargetOptions
+	Targeting targeting.TargetOptions
 	Closest   bool
 	Closest   bool
 }
 }
 
 
@@ -28,7 +32,7 @@ type ZoneLogging struct {
 type Record struct {
 type Record struct {
 	RR     dns.RR
 	RR     dns.RR
 	Weight int
 	Weight int
-	Loc    *Location
+	Loc    *targeting.Location
 	Test   *health.HealthTest
 	Test   *health.HealthTest
 }
 }
 
 
@@ -72,8 +76,6 @@ type Zone struct {
 	sync.RWMutex
 	sync.RWMutex
 }
 }
 
 
-type qTypes []uint16
-
 func NewZone(name string) *Zone {
 func NewZone(name string) *Zone {
 	zone := new(Zone)
 	zone := new(Zone)
 	zone.Labels = make(labels)
 	zone.Labels = make(labels)
@@ -84,7 +86,7 @@ func NewZone(name string) *Zone {
 	zone.Options.Ttl = 120
 	zone.Options.Ttl = 120
 	zone.Options.MaxHosts = 2
 	zone.Options.MaxHosts = 2
 	zone.Options.Contact = "hostmaster." + name
 	zone.Options.Contact = "hostmaster." + name
-	zone.Options.Targeting = TargetGlobal + TargetCountry + TargetContinent
+	zone.Options.Targeting = targeting.TargetGlobal + targeting.TargetCountry + targeting.TargetContinent
 
 
 	return zone
 	return zone
 }
 }
@@ -149,11 +151,62 @@ func (z *Zone) SoaRR() dns.RR {
 	return z.Labels[""].firstRR(dns.TypeSOA)
 	return z.Labels[""].firstRR(dns.TypeSOA)
 }
 }
 
 
+func (zone *Zone) AddSOA() {
+	zone.addSOA()
+}
+
+func (zone *Zone) addSOA() {
+	label := zone.Labels[""]
+
+	primaryNs := "ns"
+
+	// log.Println("LABEL", label)
+
+	if label == nil {
+		log.Println(zone.Origin, "doesn't have any 'root' records,",
+			"you should probably add some NS records")
+		label = zone.AddLabel("")
+	}
+
+	if record, ok := label.Records[dns.TypeNS]; ok {
+		primaryNs = record[0].RR.(*dns.NS).Ns
+	}
+
+	ttl := zone.Options.Ttl * 10
+	if ttl > 3600 {
+		ttl = 3600
+	}
+	if ttl == 0 {
+		ttl = 600
+	}
+
+	s := zone.Origin + ". " + strconv.Itoa(ttl) + " IN SOA " +
+		primaryNs + " " + zone.Options.Contact + " " +
+		strconv.Itoa(zone.Options.Serial) +
+		// refresh, retry, expire, minimum are all
+		// meaningless with this implementation
+		" 5400 5400 1209600 3600"
+
+	// log.Println("SOA: ", s)
+
+	rr, err := dns.NewRR(s)
+
+	if err != nil {
+		log.Println("SOA Error", err)
+		panic("Could not setup SOA")
+	}
+
+	record := Record{RR: rr}
+
+	label.Records[dns.TypeSOA] = make([]Record, 1)
+	label.Records[dns.TypeSOA][0] = record
+}
+
 // Find label "s" in country "cc" falling back to the appropriate
 // Find label "s" in country "cc" falling back to the appropriate
 // continent and the global label name as needed. Looks for the
 // continent and the global label name as needed. Looks for the
 // first available qType at each targeting level. Return a Label
 // first available qType at each targeting level. Return a Label
 // and the qtype that was "found"
 // and the qtype that was "found"
-func (z *Zone) findLabels(s string, targets []string, qts qTypes) (*Label, uint16) {
+func (z *Zone) FindLabels(s string, targets []string, qts []uint16) (*Label, uint16) {
 	for _, target := range targets {
 	for _, target := range targets {
 		var name string
 		var name string
 
 
@@ -181,7 +234,7 @@ func (z *Zone) findLabels(s string, targets []string, qts qTypes) (*Label, uint1
 					if label.Records[dns.TypeMF] != nil {
 					if label.Records[dns.TypeMF] != nil {
 						name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
 						name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
 						// TODO: need to avoid loops here somehow
 						// TODO: need to avoid loops here somehow
-						return z.findLabels(name, targets, qts)
+						return z.FindLabels(name, targets, qts)
 					}
 					}
 				default:
 				default:
 					// return the label if it has the right record
 					// return the label if it has the right record
@@ -209,7 +262,7 @@ func (z *Zone) SetLocations() {
 						rr := label.Records[qtype][i].RR
 						rr := label.Records[qtype][i].RR
 						if a, ok := rr.(*dns.A); ok {
 						if a, ok := rr.(*dns.A); ok {
 							ip := a.A
 							ip := a.A
-							_, _, _, _, _, location := geoIP.GetCountryRegion(ip)
+							_, _, _, _, _, location := targeting.GeoIP().GetCountryRegion(ip)
 							label.Records[qtype][i].Loc = location
 							label.Records[qtype][i].Loc = location
 						}
 						}
 					}
 					}
@@ -318,3 +371,28 @@ func (z *Zone) StartStopHealthChecks(start bool, oldZone *Zone) {
 	// 		}
 	// 		}
 	// 	}
 	// 	}
 }
 }
+
+func (z *Zone) HealthRR(label string, baseLabel string) []dns.RR {
+	h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
+	h.Name = label
+
+	healthstatus := make(map[string]map[string]bool)
+
+	if l, ok := z.Labels[baseLabel]; ok {
+		for qt, records := range l.Records {
+			if qts, ok := dns.TypeToString[qt]; ok {
+				hmap := make(map[string]bool)
+				for _, record := range records {
+					if record.Test != nil {
+						hmap[(*record.Test).IP().String()] = health.TestRunner.IsHealthy(record.Test)
+					}
+				}
+				healthstatus[qts] = hmap
+			}
+		}
+	}
+
+	js, _ := json.Marshal(healthstatus)
+
+	return []dns.RR{&dns.TXT{Hdr: h, Txt: []string{string(js)}}}
+}

+ 1 - 1
zone_stats.go → zones/zone_stats.go

@@ -1,4 +1,4 @@
-package main
+package zones
 
 
 import (
 import (
 	"sort"
 	"sort"

+ 1 - 1
zone_stats_test.go → zones/zone_stats_test.go

@@ -1,4 +1,4 @@
-package main
+package zones
 
 
 import (
 import (
 	. "gopkg.in/check.v1"
 	. "gopkg.in/check.v1"

+ 1 - 1
zone_test.go → zones/zone_test.go

@@ -1,4 +1,4 @@
-package main
+package zones
 
 
 import (
 import (
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"

+ 28 - 214
zones.go → zones/zones.go

@@ -1,22 +1,17 @@
-package main
+package zones
 
 
 import (
 import (
-	"crypto/sha256"
-	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"io/ioutil"
 	"log"
 	"log"
 	"net"
 	"net"
 	"os"
 	"os"
-	"path"
 	"runtime/debug"
 	"runtime/debug"
 	"sort"
 	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
-	"time"
 
 
-	"github.com/abh/geodns/applog"
+	"github.com/abh/geodns/targeting"
 	"github.com/abh/geodns/typeutil"
 	"github.com/abh/geodns/typeutil"
 
 
 	"github.com/abh/errorutil"
 	"github.com/abh/errorutil"
@@ -26,125 +21,7 @@ import (
 // Zones maps domain names to zone data
 // Zones maps domain names to zone data
 type Zones map[string]*Zone
 type Zones map[string]*Zone
 
 
-type ZoneReadRecord struct {
-	time time.Time
-	hash string
-}
-
-var lastRead = map[string]*ZoneReadRecord{}
-
-func (srv *Server) zonesReadDir(dirName string, zones Zones) error {
-	dir, err := ioutil.ReadDir(dirName)
-	if err != nil {
-		log.Println("Could not read", dirName, ":", err)
-		return err
-	}
-
-	seenZones := map[string]bool{}
-
-	var parseErr error
-
-	for _, file := range dir {
-		fileName := file.Name()
-		if !strings.HasSuffix(strings.ToLower(fileName), ".json") ||
-			strings.HasPrefix(path.Base(fileName), ".") ||
-			file.IsDir() {
-			continue
-		}
-
-		zoneName := zoneNameFromFile(fileName)
-
-		seenZones[zoneName] = true
-
-		if _, ok := lastRead[zoneName]; !ok || file.ModTime().After(lastRead[zoneName].time) {
-			modTime := file.ModTime()
-			if ok {
-				applog.Printf("Reloading %s\n", fileName)
-				lastRead[zoneName].time = modTime
-			} else {
-				applog.Printf("Reading new file %s\n", fileName)
-				lastRead[zoneName] = &ZoneReadRecord{time: modTime}
-			}
-
-			filename := path.Join(dirName, fileName)
-
-			// Check the sha256 of the file has not changed. It's worth an explanation of
-			// why there isn't a TOCTOU race here. Conceivably after checking whether the
-			// SHA has changed, the contents then change again before we actually load
-			// the JSON. This can occur in two situations:
-			//
-			// 1. The SHA has not changed when we read the file for the SHA, but then
-			//    changes before we process the JSON
-			//
-			// 2. The SHA has changed when we read the file for the SHA, but then changes
-			//    again before we process the JSON
-			//
-			// In circumstance (1) we won't reread the file the first time, but the subsequent
-			// change should alter the mtime again, causing us to reread it. This reflects
-			// the fact there were actually two changes.
-			//
-			// In circumstance (2) we have already reread the file once, and then when the
-			// contents are changed the mtime changes again
-			//
-			// Provided files are replaced atomically, this should be OK. If files are not
-			// replaced atomically we have other problems (e.g. partial reads).
-
-			sha256 := sha256File(filename)
-			if lastRead[zoneName].hash == sha256 {
-				applog.Printf("Skipping new file %s as hash is unchanged\n", filename)
-				continue
-			}
-
-			config, err := readZoneFile(zoneName, filename)
-			if config == nil || err != nil {
-				parseErr = fmt.Errorf("Error reading zone '%s': %s", zoneName, err)
-				log.Println(parseErr.Error())
-				continue
-			}
-
-			(lastRead[zoneName]).hash = sha256
-
-			srv.addHandler(zones, zoneName, config)
-		}
-	}
-
-	for zoneName, zone := range zones {
-		if zoneName == "pgeodns" {
-			continue
-		}
-		if ok, _ := seenZones[zoneName]; ok {
-			continue
-		}
-		log.Println("Removing zone", zone.Origin)
-		delete(lastRead, zoneName)
-		zone.Close()
-		dns.HandleRemove(zoneName)
-		delete(zones, zoneName)
-	}
-
-	return parseErr
-}
-
-func (srv *Server) setupPgeodnsZone(zones Zones) {
-	zoneName := "pgeodns"
-	Zone := NewZone(zoneName)
-	label := new(Label)
-	label.Records = make(map[uint16]Records)
-	label.Weight = make(map[uint16]int)
-	Zone.Labels[""] = label
-	setupSOA(Zone)
-	srv.addHandler(zones, zoneName, Zone)
-}
-
-func (srv *Server) setupRootZone() {
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		m := new(dns.Msg)
-		m.SetRcode(r, dns.RcodeRefused)
-		w.WriteMsg(m)
-	})
-}
-
-func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
+func ReadZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 	defer func() {
 	defer func() {
 		if r := recover(); r != nil {
 		if r := recover(); r != nil {
 			log.Printf("reading %s failed: %s", zoneName, r)
 			log.Printf("reading %s failed: %s", zoneName, r)
@@ -209,7 +86,7 @@ func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 				zone.HasClosest = true
 				zone.HasClosest = true
 			}
 			}
 		case "targeting":
 		case "targeting":
-			zone.Options.Targeting, err = parseTargets(v.(string))
+			zone.Options.Targeting, err = targeting.ParseTargets(v.(string))
 			if err != nil {
 			if err != nil {
 				log.Printf("Could not parse targeting '%s': %s", v, err)
 				log.Printf("Could not parse targeting '%s': %s", v, err)
 				return nil, err
 				return nil, err
@@ -246,13 +123,13 @@ func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 	//log.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
 	//log.Println("IP", string(Zone.Regions["0.us"].IPv4[0].ip))
 
 
 	switch {
 	switch {
-	case zone.Options.Targeting >= TargetRegionGroup || zone.HasClosest:
-		geoIP.setupGeoIPCity()
-	case zone.Options.Targeting >= TargetContinent:
-		geoIP.setupGeoIPCountry()
+	case zone.Options.Targeting >= targeting.TargetRegionGroup || zone.HasClosest:
+		targeting.GeoIP().SetupGeoIPCity()
+	case zone.Options.Targeting >= targeting.TargetContinent:
+		targeting.GeoIP().SetupGeoIPCountry()
 	}
 	}
-	if zone.Options.Targeting&TargetASN > 0 {
-		geoIP.setupGeoIPASN()
+	if zone.Options.Targeting&targeting.TargetASN > 0 {
+		targeting.GeoIP().SetupGeoIPASN()
 	}
 	}
 
 
 	if zone.HasClosest {
 	if zone.HasClosest {
@@ -262,7 +139,7 @@ func readZoneFile(zoneName, fileName string) (zone *Zone, zerr error) {
 	return zone, nil
 	return zone, nil
 }
 }
 
 
-func setupZoneData(data map[string]interface{}, Zone *Zone) {
+func setupZoneData(data map[string]interface{}, zone *Zone) {
 	recordTypes := map[string]uint16{
 	recordTypes := map[string]uint16{
 		"a":     dns.TypeA,
 		"a":     dns.TypeA,
 		"aaaa":  dns.TypeAAAA,
 		"aaaa":  dns.TypeAAAA,
@@ -281,7 +158,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 		//log.Printf("K %s V %s TYPE-V %T\n", dk, dv, dv)
 		//log.Printf("K %s V %s TYPE-V %T\n", dk, dv, dv)
 
 
-		label := Zone.AddLabel(dk)
+		label := zone.AddLabel(dk)
 
 
 		for rType, rdata := range dv {
 		for rType, rdata := range dv {
 			switch rType {
 			switch rType {
@@ -291,14 +168,14 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 			case "closest":
 			case "closest":
 				label.Closest = rdata.(bool)
 				label.Closest = rdata.(bool)
 				if label.Closest {
 				if label.Closest {
-					Zone.HasClosest = true
+					zone.HasClosest = true
 				}
 				}
 				continue
 				continue
 			case "ttl":
 			case "ttl":
 				label.Ttl = typeutil.ToInt(rdata)
 				label.Ttl = typeutil.ToInt(rdata)
 				continue
 				continue
 			case "test":
 			case "test":
-				Zone.newHealthTest(label, rdata)
+				zone.newHealthTest(label, rdata)
 				continue
 				continue
 			}
 			}
 
 
@@ -355,9 +232,9 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 				switch len(label.Label) {
 				switch len(label.Label) {
 				case 0:
 				case 0:
-					h.Name = Zone.Origin + "."
+					h.Name = zone.Origin + "."
 				default:
 				default:
-					h.Name = label.Label + "." + Zone.Origin + "."
+					h.Name = label.Label + "." + zone.Origin + "."
 				}
 				}
 
 
 				switch dnsType {
 				switch dnsType {
@@ -411,7 +288,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 					target := rec["target"].(string)
 					target := rec["target"].(string)
 
 
 					if !dns.IsFqdn(target) {
 					if !dns.IsFqdn(target) {
-						target = target + "." + Zone.Origin
+						target = target + "." + zone.Origin
 					}
 					}
 
 
 					if rec["srv_weight"] != nil {
 					if rec["srv_weight"] != nil {
@@ -441,7 +318,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 						target, weight = getStringWeight(rec.([]interface{}))
 						target, weight = getStringWeight(rec.([]interface{}))
 					}
 					}
 					if !dns.IsFqdn(target) {
 					if !dns.IsFqdn(target) {
-						target = target + "." + Zone.Origin
+						target = target + "." + zone.Origin
 					}
 					}
 					record.Weight = weight
 					record.Weight = weight
 					record.RR = &dns.CNAME{Hdr: h, Target: dns.Fqdn(target)}
 					record.RR = &dns.CNAME{Hdr: h, Target: dns.Fqdn(target)}
@@ -497,7 +374,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 						rr := &dns.TXT{Hdr: h, Txt: []string{txt}}
 						rr := &dns.TXT{Hdr: h, Txt: []string{txt}}
 						record.RR = rr
 						record.RR = rr
 					} else {
 					} else {
-						log.Printf("Zero length txt record for '%s' in '%s'\n", label.Label, Zone.Origin)
+						log.Printf("Zero length txt record for '%s' in '%s'\n", label.Label, zone.Origin)
 						continue
 						continue
 					}
 					}
 					// Initial SPF support added here, cribbed from the TypeTXT case definition - SPF records should be handled identically
 					// Initial SPF support added here, cribbed from the TypeTXT case definition - SPF records should be handled identically
@@ -525,7 +402,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 						rr := &dns.SPF{Hdr: h, Txt: []string{spf}}
 						rr := &dns.SPF{Hdr: h, Txt: []string{spf}}
 						record.RR = rr
 						record.RR = rr
 					} else {
 					} else {
-						log.Printf("Zero length SPF record for '%s' in '%s'\n", label.Label, Zone.Origin)
+						log.Printf("Zero length SPF record for '%s' in '%s'\n", label.Label, zone.Origin)
 						continue
 						continue
 					}
 					}
 
 
@@ -549,26 +426,26 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 	// loop over exisiting labels, create zone records for missing sub-domains
 	// loop over exisiting labels, create zone records for missing sub-domains
 	// and set TTLs
 	// and set TTLs
-	for k := range Zone.Labels {
+	for k := range zone.Labels {
 		if strings.Contains(k, ".") {
 		if strings.Contains(k, ".") {
 			subLabels := strings.Split(k, ".")
 			subLabels := strings.Split(k, ".")
 			for i := 1; i < len(subLabels); i++ {
 			for i := 1; i < len(subLabels); i++ {
 				subSubLabel := strings.Join(subLabels[i:], ".")
 				subSubLabel := strings.Join(subLabels[i:], ".")
-				if _, ok := Zone.Labels[subSubLabel]; !ok {
-					Zone.AddLabel(subSubLabel)
+				if _, ok := zone.Labels[subSubLabel]; !ok {
+					zone.AddLabel(subSubLabel)
 				}
 				}
 			}
 			}
 		}
 		}
-		for _, records := range Zone.Labels[k].Records {
+		for _, records := range zone.Labels[k].Records {
 			for _, r := range records {
 			for _, r := range records {
 				var defaultTtl uint32 = 86400
 				var defaultTtl uint32 = 86400
 				if r.RR.Header().Rrtype != dns.TypeNS {
 				if r.RR.Header().Rrtype != dns.TypeNS {
 					// NS records have special treatment. If they are not specified, they default to 86400 rather than
 					// NS records have special treatment. If they are not specified, they default to 86400 rather than
 					// defaulting to the zone ttl option. The label TTL option always works though
 					// defaulting to the zone ttl option. The label TTL option always works though
-					defaultTtl = uint32(Zone.Options.Ttl)
+					defaultTtl = uint32(zone.Options.Ttl)
 				}
 				}
-				if Zone.Labels[k].Ttl > 0 {
-					defaultTtl = uint32(Zone.Labels[k].Ttl)
+				if zone.Labels[k].Ttl > 0 {
+					defaultTtl = uint32(zone.Labels[k].Ttl)
 				}
 				}
 				if r.RR.Header().Ttl == 0 {
 				if r.RR.Header().Ttl == 0 {
 					r.RR.Header().Ttl = defaultTtl
 					r.RR.Header().Ttl = defaultTtl
@@ -577,9 +454,8 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 		}
 		}
 	}
 	}
 
 
-	setupSOA(Zone)
+	zone.addSOA()
 
 
-	//log.Println(Zones[k])
 }
 }
 
 
 func getStringWeight(rec []interface{}) (string, int) {
 func getStringWeight(rec []interface{}) (string, int) {
@@ -601,65 +477,3 @@ func getStringWeight(rec []interface{}) (string, int) {
 
 
 	return str, weight
 	return str, weight
 }
 }
-
-func setupSOA(Zone *Zone) {
-	label := Zone.Labels[""]
-
-	primaryNs := "ns"
-
-	// log.Println("LABEL", label)
-
-	if label == nil {
-		log.Println(Zone.Origin, "doesn't have any 'root' records,",
-			"you should probably add some NS records")
-		label = Zone.AddLabel("")
-	}
-
-	if record, ok := label.Records[dns.TypeNS]; ok {
-		primaryNs = record[0].RR.(*dns.NS).Ns
-	}
-
-	ttl := Zone.Options.Ttl * 10
-	if ttl > 3600 {
-		ttl = 3600
-	}
-	if ttl == 0 {
-		ttl = 600
-	}
-
-	s := Zone.Origin + ". " + strconv.Itoa(ttl) + " IN SOA " +
-		primaryNs + " " + Zone.Options.Contact + " " +
-		strconv.Itoa(Zone.Options.Serial) +
-		// refresh, retry, expire, minimum are all
-		// meaningless with this implementation
-		" 5400 5400 1209600 3600"
-
-	// log.Println("SOA: ", s)
-
-	rr, err := dns.NewRR(s)
-
-	if err != nil {
-		log.Println("SOA Error", err)
-		panic("Could not setup SOA")
-	}
-
-	record := Record{RR: rr}
-
-	label.Records[dns.TypeSOA] = make([]Record, 1)
-	label.Records[dns.TypeSOA][0] = record
-
-}
-
-func zoneNameFromFile(fileName string) string {
-	return fileName[0:strings.LastIndex(fileName, ".")]
-}
-
-func sha256File(fn string) string {
-	if data, err := ioutil.ReadFile(fn); err != nil {
-		return ""
-	} else {
-		hasher := sha256.New()
-		hasher.Write(data)
-		return hex.EncodeToString(hasher.Sum(nil))
-	}
-}

+ 1 - 1
zones_test.go → zones/zones_test.go

@@ -1,4 +1,4 @@
-package main
+package zones
 
 
 import (
 import (
 	"io"
 	"io"