Ver Fonte

v3 refactoring wip

Ask Bjørn Hansen há 8 anos atrás
pai
commit
e40da8e7c6
22 ficheiros alterados com 1372 adições e 1157 exclusões
  1. 2 0
      .gitignore
  2. 1 1
      .travis.yml
  3. 22 2
      Makefile
  4. 2 1
      config.go
  5. 37 33
      geodns.go
  6. 272 0
      http.go
  7. 0 436
      monitor.go
  8. 99 0
      monitor/hub.go
  9. 123 0
      monitor/monitor.go
  10. 44 0
      monitor/monitor_test.go
  11. 0 57
      monitor_test.go
  12. 0 301
      serve_test.go
  13. 3 1
      server/metrics.go
  14. 7 7
      server/serve.go
  15. 325 0
      server/serve_test.go
  16. 74 0
      server/server.go
  17. 0 130
      server_test.go
  18. 2 0
      targeting/targeting_test.go
  19. 6 4
      util.go
  20. 82 85
      zones/muxmanager.go
  21. 152 0
      zones/reader_test.go
  22. 119 99
      zones/zone_test.go

+ 2 - 0
.gitignore

@@ -5,3 +5,5 @@
 /run
 .idea
 /dns/geodns.conf
+geodns-*-*
+geodns-*-*.tar

+ 1 - 1
.travis.yml

@@ -22,5 +22,5 @@ install:
   - go install
 
 script:
-  - cd $TRAVIS_BUILD_DIR && go test -gocheck.v
+  - cd $TRAVIS_BUILD_DIR && make test
   - go test -gocheck.v -gocheck.b -gocheck.btime=2s

+ 22 - 2
Makefile

@@ -4,14 +4,34 @@ all: templates.go
 templates.go: templates/*.html monitor.go
 	go generate
 
+.PHONY: test
 test:
 	go test $(go list ./... | grep -v /vendor/)
 
-testrace:
-	go test -race $(go list ./... | grep -v /vendor/)
+testrace: .PHONY
+	go test -race $(shell go list ./... | grep -v /vendor/)
 
 devel:
 	go build -tags devel
 
 bench:
 	go test -check.b -check.bmem
+
+TARS=$(wildcard geodns-*-*.tar)
+
+push: $(TARS) tmp-install.sh
+	rsync -avz tmp-install.sh $(TARS)  x3.dev:webtmp/2016/07/
+
+builds: linux-build linux-build-i386 freebsd-build push
+
+linux-build:
+	docker run --rm -v `pwd`:/go/src/github.com/abh/geodns geodns-build ./build
+
+linux-build-i386:
+	docker run --rm -v `pwd`:/go/src/github.com/abh/geodns geodns-build-i386 ./build
+
+freebsd-build:
+	ssh 192.168.64.5 'cd go/src/github.com/abh/geodns; GOPATH=~/go ./build'
+	ssh [email protected] 'jexec -U ask fbsd32 /home/ask/build'
+
+.PHONY:

+ 2 - 1
config.go

@@ -34,7 +34,8 @@ type AppConfig struct {
 		Token string
 	}
 	Pingdom struct {
-		Username     string
+		Username string
+
 		Password     string
 		AccountEmail string
 		AppKey       string

+ 37 - 33
geodns.go

@@ -30,7 +30,9 @@ import (
 	"time"
 
 	"github.com/abh/geodns/applog"
+	"github.com/abh/geodns/monitor"
 	"github.com/abh/geodns/querylog"
+	"github.com/abh/geodns/server"
 	"github.com/abh/geodns/zones"
 	"github.com/pborman/uuid"
 )
@@ -45,14 +47,9 @@ var gitVersion string
 var development bool
 
 var (
-	serverID     string
-	serverIP     string
-	serverGroups []string
-	serverUUID   = uuid.New()
+	serverInfo *monitor.ServerInfo
 )
 
-var timeStarted = time.Now()
-
 var (
 	flagconfig       = flag.String("config", "./dns/", "directory of zone files")
 	flagconfigfile   = flag.String("configfile", "geodns.conf", "filename of config file (in 'config' directory)")
@@ -79,6 +76,11 @@ func init() {
 
 	log.SetPrefix("geodns ")
 	log.SetFlags(log.Lmicroseconds | log.Lshortfile)
+
+	serverInfo = &monitor.ServerInfo{}
+	serverInfo.UUID = uuid.New()
+	serverInfo.Started = time.Now()
+
 }
 
 func main() {
@@ -98,8 +100,6 @@ func main() {
 		os.Exit(0)
 	}
 
-	srv := Server{}
-
 	if *flaglog {
 		applog.Enabled = true
 	}
@@ -110,9 +110,9 @@ func main() {
 
 	if len(*flagidentifier) > 0 {
 		ids := strings.Split(*flagidentifier, ",")
-		serverID = ids[0]
+		serverInfo.ID = ids[0]
 		if len(ids) > 1 {
-			serverGroups = ids[1:]
+			serverInfo.Groups = ids[1:]
 		}
 	}
 
@@ -125,17 +125,16 @@ func main() {
 	}
 
 	if *flagcheckconfig {
-		dirName := *flagconfig
-
 		err := configReader(configFileName)
 		if err != nil {
 			log.Println("Errors reading config", err)
 			os.Exit(2)
 		}
 
-		Zones := make(zones.Zones)
-		srv.setupPgeodnsZone(Zones)
-		err = srv.zonesReadDir(dirName, Zones)
+		// dirName := *flagconfig
+		// Zones := make(zones.Zones)
+		// srv.setupPgeodnsZone(Zones)
+		// err = srv.zonesReadDir(dirName, Zones)
 		if err != nil {
 			log.Println("Errors reading zones", err)
 			os.Exit(2)
@@ -174,17 +173,6 @@ func main() {
 	// load (and re-load) zone data
 	go configWatcher(configFileName)
 
-	metrics := NewMetrics()
-	go metrics.Updater()
-
-	if qlc := Config.QueryLog; len(qlc.Path) > 0 {
-		ql, err := querylog.NewFileLogger(qlc.Path, qlc.MaxSize, qlc.Keep)
-		if err != nil {
-			log.Fatalf("Could not start file query logger: %s", err)
-		}
-		srv.SetQueryLogger(ql)
-	}
-
 	if *flaginter == "*" {
 		addrs, _ := net.InterfaceAddrs()
 		ips := make([]string, 0)
@@ -207,17 +195,33 @@ func main() {
 		log.Println("StatHat integration has been removed in favor of more generic metrics")
 	}
 
-	// the global-ish zones 'context' is quite a mess
-	zonelist := make(zones.Zones)
-	go monitor(zonelist)
-	srv.setupRootZone()
-	srv.setupPgeodnsZone(zonelist)
-	go srv.zonesReader(*flagconfig, zonelist)
+	mon := monitor.NewMonitor(serverInfo)
+	go mon.Run()
+
+	srv := server.NewServer(serverInfo)
+
+	if qlc := Config.QueryLog; len(qlc.Path) > 0 {
+		ql, err := querylog.NewFileLogger(qlc.Path, qlc.MaxSize, qlc.Keep)
+		if err != nil {
+			log.Fatalf("Could not start file query logger: %s", err)
+		}
+		srv.SetQueryLogger(ql)
+	}
+
+	muxm, err := zones.NewMuxManager(*flagconfig, srv)
+	if err != nil {
+		log.Printf("error loading zones: %s", err)
+	}
+	go muxm.Run()
 
 	for _, host := range inter {
-		go srv.listenAndServe(host)
+		go srv.ListenAndServe(host)
 	}
 
+	go func() {
+		// setup metrics httpd stuff
+	}()
+
 	terminate := make(chan os.Signal)
 	signal.Notify(terminate, os.Interrupt)
 

+ 272 - 0
http.go

@@ -0,0 +1,272 @@
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"html/template"
+	"io"
+	"log"
+	"net/http"
+	"runtime"
+	"sort"
+	"strconv"
+	"time"
+
+	"github.com/abh/geodns/monitor"
+	"github.com/abh/geodns/zones"
+	metrics "github.com/rcrowley/go-metrics"
+)
+
+type httpServer struct {
+	mux        *http.ServeMux
+	zones      zones.Zones
+	serverInfo monitor.ServerInfo
+}
+
+type rate struct {
+	Name    string
+	Count   int64
+	Metrics zones.ZoneMetrics
+}
+type Rates []*rate
+
+func (s Rates) Len() int      { return len(s) }
+func (s Rates) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type RatesByCount struct{ Rates }
+
+func (s RatesByCount) Less(i, j int) bool {
+	ic := s.Rates[i].Count
+	jc := s.Rates[j].Count
+	if ic == jc {
+		return s.Rates[i].Name < s.Rates[j].Name
+	}
+	return ic > jc
+}
+
+type histogramData struct {
+	Max    int64
+	Min    int64
+	Mean   float64
+	Pct90  float64
+	Pct99  float64
+	Pct999 float64
+	StdDev float64
+}
+
+func setupHistogramData(met metrics.Histogram, dat *histogramData) {
+	dat.Max = met.Max()
+	dat.Min = met.Min()
+	dat.Mean = met.Mean()
+	dat.StdDev = met.StdDev()
+	percentiles := met.Percentiles([]float64{0.90, 0.99, 0.999})
+	dat.Pct90 = percentiles[0]
+	dat.Pct99 = percentiles[1]
+	dat.Pct999 = percentiles[2]
+}
+
+func topParam(req *http.Request, def int) int {
+	req.ParseForm()
+
+	topOption := def
+	topParam := req.Form["top"]
+
+	if len(topParam) > 0 {
+		var err error
+		topOption, err = strconv.Atoi(topParam[0])
+		if err != nil {
+			topOption = def
+		}
+	}
+
+	return topOption
+}
+
+func NewHTTPServer(zones zones.Zones) *httpServer {
+	hs := &httpServer{
+		// todo: zones.MuxManager instead of zones?
+		zones: zones,
+		mux:   &http.ServeMux{},
+	}
+	hs.mux.HandleFunc("/status", hs.StatusHandler())
+	hs.mux.HandleFunc("/status.json", hs.StatusJSONHandler())
+	hs.mux.HandleFunc("/", hs.mainServer)
+
+	return hs
+}
+
+func (hs *httpServer) Mux() *http.ServeMux {
+	return hs.mux
+}
+
+func (hs *httpServer) Run(listen string) {
+	log.Println("Starting HTTP interface on", listen)
+	log.Fatal(http.ListenAndServe(listen, &basicauth{h: hs.mux}))
+}
+
+func (hs *httpServer) StatusJSONHandler() func(http.ResponseWriter, *http.Request) {
+
+	info := serverInfo
+
+	return func(w http.ResponseWriter, req *http.Request) {
+
+		zonemetrics := make(map[string]metrics.Registry)
+
+		for name, zone := range hs.zones {
+			zone.Lock()
+			zonemetrics[name] = zone.Metrics.Registry
+			zone.Unlock()
+		}
+
+		type statusData struct {
+			Version   string
+			GoVersion string
+			Uptime    int64
+			Platform  string
+			Zones     map[string]metrics.Registry
+			Global    metrics.Registry
+			ID        string
+			IP        string
+			UUID      string
+			Groups    []string
+		}
+
+		uptime := int64(time.Since(info.Started).Seconds())
+
+		status := statusData{
+			Version:   info.Version,
+			GoVersion: runtime.Version(),
+			Uptime:    uptime,
+			Platform:  runtime.GOARCH + "-" + runtime.GOOS,
+			Zones:     zonemetrics,
+			Global:    metrics.DefaultRegistry,
+			ID:        hs.serverInfo.ID,
+			IP:        hs.serverInfo.IP,
+			UUID:      hs.serverInfo.UUID,
+			Groups:    hs.serverInfo.Groups,
+		}
+
+		b, err := json.Marshal(status)
+		if err != nil {
+			http.Error(w, "Error encoding JSON", 500)
+			return
+		}
+		w.Header().Set("Content-Type", "application/json")
+		w.Write(b)
+		return
+	}
+}
+
+func (hs *httpServer) StatusHandler() func(http.ResponseWriter, *http.Request) {
+
+	return func(w http.ResponseWriter, req *http.Request) {
+
+		topOption := topParam(req, 10)
+
+		rates := make(Rates, 0)
+
+		for name, zone := range hs.zones {
+			count := zone.Metrics.Queries.Count()
+			rates = append(rates, &rate{
+				Name:    name,
+				Count:   count,
+				Metrics: zone.Metrics,
+			})
+		}
+
+		sort.Sort(RatesByCount{rates})
+
+		type statusData struct {
+			Version  string
+			Zones    Rates
+			Uptime   DayDuration
+			Platform string
+			Global   struct {
+				Queries         metrics.Meter
+				Histogram       histogramData
+				HistogramRecent histogramData
+			}
+			TopOption int
+		}
+
+		uptime := DayDuration{time.Since(hs.serverInfo.Started)}
+
+		status := statusData{
+			Version:   VERSION,
+			Zones:     rates,
+			Uptime:    uptime,
+			Platform:  runtime.GOARCH + "-" + runtime.GOOS,
+			TopOption: topOption,
+		}
+
+		status.Global.Queries = metrics.Get("queries").(*metrics.StandardMeter).Snapshot()
+
+		setupHistogramData(metrics.Get("queries-histogram").(*metrics.StandardHistogram).Snapshot(), &status.Global.Histogram)
+
+		statusTemplate, err := FSString(development, "/templates/status.html")
+		if err != nil {
+			log.Println("Could not read template:", err)
+			w.WriteHeader(500)
+			return
+		}
+		tmpl, err := template.New("status_html").Parse(statusTemplate)
+
+		if err != nil {
+			str := fmt.Sprintf("Could not parse template: %s", err)
+			io.WriteString(w, str)
+			return
+		}
+
+		err = tmpl.Execute(w, status)
+		if err != nil {
+			log.Println("Status template error", err)
+		}
+	}
+}
+
+func (hs *httpServer) mainServer(w http.ResponseWriter, req *http.Request) {
+	if req.RequestURI != "/version" {
+		http.NotFound(w, req)
+		return
+	}
+	io.WriteString(w, `<html><head><title>GeoDNS `+
+		hs.serverInfo.Version+`</title><body>`+
+		`GeoDNS Server`+
+		`</body></html>`)
+}
+
+type basicauth struct {
+	h http.Handler
+}
+
+func (b *basicauth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+
+	// don't request passwords for the websocket interface (for now)
+	// because 'wscat' doesn't support that.
+	if r.RequestURI == "/monitor" {
+		b.h.ServeHTTP(w, r)
+		return
+	}
+
+	cfgMutex.RLock()
+	user := Config.HTTP.User
+	password := Config.HTTP.Password
+	cfgMutex.RUnlock()
+
+	if len(user) == 0 {
+		b.h.ServeHTTP(w, r)
+		return
+	}
+
+	ruser, rpass, ok := r.BasicAuth()
+	if ok {
+		if ruser == user && rpass == password {
+			b.h.ServeHTTP(w, r)
+			return
+		}
+	}
+
+	w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, "GeoDNS Status"))
+	http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+	return
+}

+ 0 - 436
monitor.go

@@ -1,436 +0,0 @@
-package main
-
-//go:generate esc -o templates.go templates/
-
-import (
-	"encoding/json"
-	"fmt"
-	"html/template"
-	"io"
-	"log"
-	"net/http"
-	"os"
-	"runtime"
-	"sort"
-	"strconv"
-	"time"
-
-	"github.com/abh/geodns/zones"
-
-	"github.com/rcrowley/go-metrics"
-	"golang.org/x/net/websocket"
-)
-
-// Initial status message on websocket
-type statusStreamMsgStart struct {
-	Hostname  string   `json:"h,omitemty"`
-	Version   string   `json:"v"`
-	GoVersion string   `json:"gov"`
-	ID        string   `json:"id"`
-	IP        string   `json:"ip"`
-	UUID      string   `json:"uuid"`
-	Uptime    int      `json:"up"`
-	Started   int      `json:"started"`
-	Groups    []string `json:"groups"`
-}
-
-// Update message on websocket
-type statusStreamMsgUpdate struct {
-	Uptime     int     `json:"up"`
-	QueryCount int64   `json:"qs"`
-	Qps        int64   `json:"qps"`
-	Qps1m      float64 `json:"qps1m,omitempty"`
-}
-
-type wsConnection struct {
-	// The websocket connection.
-	ws *websocket.Conn
-
-	// Buffered channel of outbound messages.
-	send chan string
-}
-
-type monitorHub struct {
-	connections map[*wsConnection]bool
-	broadcast   chan string
-	register    chan *wsConnection
-	unregister  chan *wsConnection
-}
-
-var hub = monitorHub{
-	broadcast:   make(chan string),
-	register:    make(chan *wsConnection, 10),
-	unregister:  make(chan *wsConnection, 10),
-	connections: make(map[*wsConnection]bool),
-}
-
-func (h *monitorHub) run() {
-	for {
-		select {
-		case c := <-h.register:
-			h.connections[c] = true
-			log.Println("Queuing initial status")
-			c.send <- initialStatus()
-		case c := <-h.unregister:
-			log.Println("Unregistering connection")
-			delete(h.connections, c)
-		case m := <-h.broadcast:
-			for c := range h.connections {
-				if len(c.send)+5 > cap(c.send) {
-					log.Println("WS connection too close to cap")
-					c.send <- `{"error": "too slow"}`
-					close(c.send)
-					go c.ws.Close()
-					h.unregister <- c
-					continue
-				}
-				select {
-				case c.send <- m:
-				default:
-					close(c.send)
-					delete(h.connections, c)
-					log.Println("Closing channel when sending")
-					go c.ws.Close()
-				}
-			}
-		}
-	}
-}
-
-func (c *wsConnection) reader() {
-	for {
-		var message string
-		err := websocket.Message.Receive(c.ws, &message)
-		if err != nil {
-			if err == io.EOF {
-				log.Println("WS connection closed")
-			} else {
-				log.Println("WS read error:", err)
-			}
-			break
-		}
-		log.Println("WS message", message)
-		// TODO(ask) take configuration options etc
-		//h.broadcast <- message
-	}
-	c.ws.Close()
-}
-
-func (c *wsConnection) writer() {
-	for message := range c.send {
-		err := websocket.Message.Send(c.ws, message)
-		if err != nil {
-			log.Println("WS write error:", err)
-			break
-		}
-	}
-	c.ws.Close()
-}
-
-func wsHandler(ws *websocket.Conn) {
-	log.Println("Starting new WS connection")
-	c := &wsConnection{send: make(chan string, 180), ws: ws}
-	hub.register <- c
-	defer func() {
-		log.Println("sending unregister message")
-		hub.unregister <- c
-	}()
-	go c.writer()
-	c.reader()
-}
-
-func initialStatus() string {
-	status := new(statusStreamMsgStart)
-	status.Version = VERSION
-	status.ID = serverID
-	status.IP = serverIP
-	status.UUID = serverUUID
-	status.GoVersion = runtime.Version()
-	if len(serverGroups) > 0 {
-		status.Groups = serverGroups
-	}
-	hostname, err := os.Hostname()
-	if err == nil {
-		status.Hostname = hostname
-	}
-
-	status.Uptime = int(time.Since(timeStarted).Seconds())
-	status.Started = int(timeStarted.Unix())
-
-	message, err := json.Marshal(status)
-	return string(message)
-}
-
-func monitor(zones zones.Zones) {
-
-	if len(*flaghttp) == 0 {
-		return
-	}
-	go hub.run()
-	go httpHandler(zones)
-
-	qCounter := metrics.Get("queries").(metrics.Meter)
-	lastQueryCount := qCounter.Count()
-
-	status := new(statusStreamMsgUpdate)
-	var lastQps1m float64
-
-	for {
-		current := qCounter.Count()
-		newQueries := current - lastQueryCount
-		lastQueryCount = current
-
-		status.Uptime = int(time.Since(timeStarted).Seconds())
-		status.QueryCount = qCounter.Count()
-		status.Qps = newQueries
-
-		newQps1m := qCounter.Rate1()
-		if newQps1m != lastQps1m {
-			status.Qps1m = newQps1m
-			lastQps1m = newQps1m
-		} else {
-			status.Qps1m = 0
-		}
-
-		message, err := json.Marshal(status)
-
-		if err == nil {
-			hub.broadcast <- string(message)
-		}
-		time.Sleep(1 * time.Second)
-	}
-}
-
-func MainServer(w http.ResponseWriter, req *http.Request) {
-	if req.RequestURI != "/version" {
-		http.NotFound(w, req)
-		return
-	}
-	io.WriteString(w, `<html><head><title>GeoDNS `+
-		VERSION+`</title><body>`+
-		initialStatus()+
-		`</body></html>`)
-}
-
-type rate struct {
-	Name    string
-	Count   int64
-	Metrics zones.ZoneMetrics
-}
-type Rates []*rate
-
-func (s Rates) Len() int      { return len(s) }
-func (s Rates) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-
-type RatesByCount struct{ Rates }
-
-func (s RatesByCount) Less(i, j int) bool {
-	ic := s.Rates[i].Count
-	jc := s.Rates[j].Count
-	if ic == jc {
-		return s.Rates[i].Name < s.Rates[j].Name
-	}
-	return ic > jc
-}
-
-type histogramData struct {
-	Max    int64
-	Min    int64
-	Mean   float64
-	Pct90  float64
-	Pct99  float64
-	Pct999 float64
-	StdDev float64
-}
-
-func setupHistogramData(met metrics.Histogram, dat *histogramData) {
-	dat.Max = met.Max()
-	dat.Min = met.Min()
-	dat.Mean = met.Mean()
-	dat.StdDev = met.StdDev()
-	percentiles := met.Percentiles([]float64{0.90, 0.99, 0.999})
-	dat.Pct90 = percentiles[0]
-	dat.Pct99 = percentiles[1]
-	dat.Pct999 = percentiles[2]
-}
-
-func topParam(req *http.Request, def int) int {
-	req.ParseForm()
-
-	topOption := def
-	topParam := req.Form["top"]
-
-	if len(topParam) > 0 {
-		var err error
-		topOption, err = strconv.Atoi(topParam[0])
-		if err != nil {
-			topOption = def
-		}
-	}
-
-	return topOption
-}
-
-func StatusJSONHandler(zones zones.Zones) func(http.ResponseWriter, *http.Request) {
-	return func(w http.ResponseWriter, req *http.Request) {
-
-		zonemetrics := make(map[string]metrics.Registry)
-
-		for name, zone := range zones {
-			zone.Lock()
-			zonemetrics[name] = zone.Metrics.Registry
-			zone.Unlock()
-		}
-
-		type statusData struct {
-			Version   string
-			GoVersion string
-			Uptime    int64
-			Platform  string
-			Zones     map[string]metrics.Registry
-			Global    metrics.Registry
-			ID        string
-			IP        string
-			UUID      string
-			Groups    []string
-		}
-
-		uptime := int64(time.Since(timeStarted).Seconds())
-
-		status := statusData{
-			Version:   VERSION,
-			GoVersion: runtime.Version(),
-			Uptime:    uptime,
-			Platform:  runtime.GOARCH + "-" + runtime.GOOS,
-			Zones:     zonemetrics,
-			Global:    metrics.DefaultRegistry,
-			ID:        serverID,
-			IP:        serverIP,
-			UUID:      serverUUID,
-			Groups:    serverGroups,
-		}
-
-		b, err := json.Marshal(status)
-		if err != nil {
-			http.Error(w, "Error encoding JSON", 500)
-			return
-		}
-		w.Header().Set("Content-Type", "application/json")
-		w.Write(b)
-		return
-	}
-}
-
-func StatusHandler(zones zones.Zones) func(http.ResponseWriter, *http.Request) {
-
-	return func(w http.ResponseWriter, req *http.Request) {
-
-		topOption := topParam(req, 10)
-
-		rates := make(Rates, 0)
-
-		for name, zone := range zones {
-			count := zone.Metrics.Queries.Count()
-			rates = append(rates, &rate{
-				Name:    name,
-				Count:   count,
-				Metrics: zone.Metrics,
-			})
-		}
-
-		sort.Sort(RatesByCount{rates})
-
-		type statusData struct {
-			Version  string
-			Zones    Rates
-			Uptime   DayDuration
-			Platform string
-			Global   struct {
-				Queries         metrics.Meter
-				Histogram       histogramData
-				HistogramRecent histogramData
-			}
-			TopOption int
-		}
-
-		uptime := DayDuration{time.Since(timeStarted)}
-
-		status := statusData{
-			Version:   VERSION,
-			Zones:     rates,
-			Uptime:    uptime,
-			Platform:  runtime.GOARCH + "-" + runtime.GOOS,
-			TopOption: topOption,
-		}
-
-		status.Global.Queries = metrics.Get("queries").(*metrics.StandardMeter).Snapshot()
-
-		setupHistogramData(metrics.Get("queries-histogram").(*metrics.StandardHistogram).Snapshot(), &status.Global.Histogram)
-
-		statusTemplate, err := FSString(development, "/templates/status.html")
-		if err != nil {
-			log.Println("Could not read template:", err)
-			w.WriteHeader(500)
-			return
-		}
-		tmpl, err := template.New("status_html").Parse(statusTemplate)
-
-		if err != nil {
-			str := fmt.Sprintf("Could not parse template: %s", err)
-			io.WriteString(w, str)
-			return
-		}
-
-		err = tmpl.Execute(w, status)
-		if err != nil {
-			log.Println("Status template error", err)
-		}
-	}
-}
-
-type basicauth struct {
-	h http.Handler
-}
-
-func (b *basicauth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-
-	// don't request passwords for the websocket interface (for now)
-	// because 'wscat' doesn't support that.
-	if r.RequestURI == "/monitor" {
-		b.h.ServeHTTP(w, r)
-		return
-	}
-
-	cfgMutex.RLock()
-	user := Config.HTTP.User
-	password := Config.HTTP.Password
-	cfgMutex.RUnlock()
-
-	if len(user) == 0 {
-		b.h.ServeHTTP(w, r)
-		return
-	}
-
-	ruser, rpass, ok := r.BasicAuth()
-	if ok {
-		if ruser == user && rpass == password {
-			b.h.ServeHTTP(w, r)
-			return
-		}
-	}
-
-	w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, "GeoDNS Status"))
-	http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
-	return
-}
-
-func httpHandler(zones zones.Zones) {
-	http.Handle("/monitor", websocket.Handler(wsHandler))
-	http.HandleFunc("/status", StatusHandler(zones))
-	http.HandleFunc("/status.json", StatusJSONHandler(zones))
-	http.HandleFunc("/", MainServer)
-
-	log.Println("Starting HTTP interface on", *flaghttp)
-
-	log.Fatal(http.ListenAndServe(*flaghttp, &basicauth{h: http.DefaultServeMux}))
-}

+ 99 - 0
monitor/hub.go

@@ -0,0 +1,99 @@
+package monitor
+
+import (
+	"io"
+	"log"
+
+	"golang.org/x/net/websocket"
+)
+
+type monitorHub struct {
+	connections map[*wsConnection]bool
+	broadcast   chan string
+	register    chan *wsConnection
+	unregister  chan *wsConnection
+}
+
+var hub = monitorHub{
+	broadcast:   make(chan string),
+	register:    make(chan *wsConnection, 10),
+	unregister:  make(chan *wsConnection, 10),
+	connections: make(map[*wsConnection]bool),
+}
+
+type initialStatusFn func() string
+
+func (h *monitorHub) run(statusFn initialStatusFn) {
+	for {
+		select {
+		case c := <-h.register:
+			h.connections[c] = true
+			log.Println("Queuing initial status")
+			c.send <- statusFn()
+		case c := <-h.unregister:
+			log.Println("Unregistering connection")
+			delete(h.connections, c)
+		case m := <-h.broadcast:
+			for c := range h.connections {
+				if len(c.send)+5 > cap(c.send) {
+					log.Println("WS connection too close to cap")
+					c.send <- `{"error": "too slow"}`
+					close(c.send)
+					go c.ws.Close()
+					h.unregister <- c
+					continue
+				}
+				select {
+				case c.send <- m:
+				default:
+					close(c.send)
+					delete(h.connections, c)
+					log.Println("Closing channel when sending")
+					go c.ws.Close()
+				}
+			}
+		}
+	}
+}
+
+func (c *wsConnection) reader() {
+	for {
+		var message string
+		err := websocket.Message.Receive(c.ws, &message)
+		if err != nil {
+			if err == io.EOF {
+				log.Println("WS connection closed")
+			} else {
+				log.Println("WS read error:", err)
+			}
+			break
+		}
+		log.Println("WS message", message)
+		// TODO(ask) take configuration options etc
+		//h.broadcast <- message
+	}
+	c.ws.Close()
+}
+
+func (c *wsConnection) writer() {
+	for message := range c.send {
+		err := websocket.Message.Send(c.ws, message)
+		if err != nil {
+			log.Println("WS write error:", err)
+			break
+		}
+	}
+	c.ws.Close()
+}
+
+func wsHandler(ws *websocket.Conn) {
+	log.Println("Starting new WS connection")
+	c := &wsConnection{send: make(chan string, 180), ws: ws}
+	hub.register <- c
+	defer func() {
+		log.Println("sending unregister message")
+		hub.unregister <- c
+	}()
+	go c.writer()
+	c.reader()
+}

+ 123 - 0
monitor/monitor.go

@@ -0,0 +1,123 @@
+package monitor
+
+//go:generate esc -o templates.go templates/
+
+import (
+	"encoding/json"
+	"os"
+	"runtime"
+	"time"
+
+	"github.com/rcrowley/go-metrics"
+	"golang.org/x/net/websocket"
+)
+
+type ServerInfo struct {
+	Version string
+	ID      string
+	IP      string
+	UUID    string
+	Groups  []string
+	Started time.Time
+}
+
+// Initial status message on websocket
+type statusStreamMsgStart struct {
+	Hostname  string   `json:"h,omitemty"`
+	Version   string   `json:"v"`
+	GoVersion string   `json:"gov"`
+	ID        string   `json:"id"`
+	IP        string   `json:"ip"`
+	UUID      string   `json:"uuid"`
+	Uptime    int      `json:"up"`
+	Started   int      `json:"started"`
+	Groups    []string `json:"groups"`
+}
+
+// Update message on websocket
+type statusStreamMsgUpdate struct {
+	Uptime     int     `json:"up"`
+	QueryCount int64   `json:"qs"`
+	Qps        int64   `json:"qps"`
+	Qps1m      float64 `json:"qps1m,omitempty"`
+}
+
+type wsConnection struct {
+	// The websocket connection.
+	ws *websocket.Conn
+
+	// Buffered channel of outbound messages.
+	send chan string
+}
+
+type monitor struct {
+	serverInfo *ServerInfo
+}
+
+func NewMonitor(serverInfo *ServerInfo) *monitor {
+	return &monitor{serverInfo: serverInfo}
+}
+
+func (m *monitor) initialStatus() string {
+	status := new(statusStreamMsgStart)
+	status.Version = m.serverInfo.Version
+	status.ID = m.serverInfo.ID
+	status.IP = m.serverInfo.IP
+	status.UUID = m.serverInfo.UUID
+
+	status.GoVersion = runtime.Version()
+	if len(m.serverInfo.Groups) > 0 {
+		status.Groups = m.serverInfo.Groups
+	}
+	hostname, err := os.Hostname()
+	if err == nil {
+		status.Hostname = hostname
+	}
+
+	started := m.serverInfo.Started
+
+	status.Started = int(started.Unix())
+	status.Uptime = int(time.Since(started).Seconds())
+
+	message, err := json.Marshal(status)
+	return string(message)
+}
+
+func (m *monitor) Run() {
+	go hub.run(m.initialStatus)
+
+	qCounter := metrics.Get("queries").(metrics.Meter)
+	lastQueryCount := qCounter.Count()
+
+	status := new(statusStreamMsgUpdate)
+	var lastQps1m float64
+
+	for {
+		current := qCounter.Count()
+		newQueries := current - lastQueryCount
+		lastQueryCount = current
+
+		status.Uptime = int(time.Since(m.serverInfo.Started).Seconds())
+		status.QueryCount = qCounter.Count()
+		status.Qps = newQueries
+
+		newQps1m := qCounter.Rate1()
+		if newQps1m != lastQps1m {
+			status.Qps1m = newQps1m
+			lastQps1m = newQps1m
+		} else {
+			status.Qps1m = 0
+		}
+
+		message, err := json.Marshal(status)
+
+		if err == nil {
+			hub.broadcast <- string(message)
+		}
+		time.Sleep(1 * time.Second)
+	}
+}
+
+func (m *monitor) Handler() websocket.Handler {
+	return websocket.Handler(wsHandler)
+}

+ 44 - 0
monitor/monitor_test.go

@@ -0,0 +1,44 @@
+package monitor
+
+import "testing"
+
+func TestMonitor(t *testing.T) {
+
+	// mux := dns.NewServeMux()
+	// mm := zones.NewMuxManager("dns", mux)
+
+	// // s.zones = make(zones.Zones)
+	// metrics := NewMetrics()
+	// go metrics.Updater()
+
+	// *flaghttp = ":8881"
+
+	// fmt.Println("Starting http server")
+
+	// // TODO: use httptest
+	// // https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/Jk785WB7F8I
+
+	// srv := Server{}
+	// srv.
+
+	// todo: this isn't right, it should probably just take the mux?
+	// go httpHandler(mm.Zones())
+	// time.Sleep(500 * time.Millisecond)
+
+	// c.Check(true, DeepEquals, true)
+
+	// res, err := http.Get("http://localhost:8881/version")
+	// c.Assert(err, IsNil)
+	// page, _ := ioutil.ReadAll(res.Body)
+	// c.Check(string(page), Matches, ".*<title>GeoDNS [0-9].*")
+
+	// res, err = http.Get("http://localhost:8881/status")
+	// c.Assert(err, IsNil)
+	// page, _ = ioutil.ReadAll(res.Body)
+	// // just check that template basically works
+
+	// isOk := strings.Contains(string(page), "<html>")
+	// // page has <html>
+	// c.Check(isOk, Equals, true)
+
+}

+ 0 - 57
monitor_test.go

@@ -1,57 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"io/ioutil"
-	"net/http"
-	"strings"
-	"time"
-
-	"github.com/abh/geodns/zones"
-
-	. "gopkg.in/check.v1"
-)
-
-type MonitorSuite struct {
-	zones   zones.Zones
-	metrics *ServerMetrics
-}
-
-var _ = Suite(&MonitorSuite{})
-
-func (s *MonitorSuite) SetUpSuite(c *C) {
-	s.zones = make(zones.Zones)
-	s.metrics = NewMetrics()
-	go s.metrics.Updater()
-
-	*flaghttp = ":8881"
-
-	fmt.Println("Starting http server")
-
-	// TODO: use httptest
-	// https://groups.google.com/forum/?fromgroups=#!topic/golang-nuts/Jk785WB7F8I
-
-	srv := Server{}
-	srv.zonesReadDir("dns", s.zones)
-	go httpHandler(s.zones)
-	time.Sleep(500 * time.Millisecond)
-}
-
-func (s *MonitorSuite) TestMonitorVersion(c *C) {
-	c.Check(true, DeepEquals, true)
-
-	res, err := http.Get("http://localhost:8881/version")
-	c.Assert(err, IsNil)
-	page, _ := ioutil.ReadAll(res.Body)
-	c.Check(string(page), Matches, ".*<title>GeoDNS [0-9].*")
-
-	res, err = http.Get("http://localhost:8881/status")
-	c.Assert(err, IsNil)
-	page, _ = ioutil.ReadAll(res.Body)
-	// just check that template basically works
-
-	isOk := strings.Contains(string(page), "<html>")
-	// page has <html>
-	c.Check(isOk, Equals, true)
-
-}

+ 0 - 301
serve_test.go

@@ -1,301 +0,0 @@
-package main
-
-import (
-	"math/rand"
-	"net"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/abh/geodns/zones"
-	"github.com/miekg/dns"
-	. "gopkg.in/check.v1"
-)
-
-const (
-	PORT = ":8853"
-)
-
-type ServeSuite struct {
-}
-
-var _ = Suite(&ServeSuite{})
-
-func (s *ServeSuite) SetUpSuite(c *C) {
-
-	// setup and register metrics
-	metrics := NewMetrics()
-	go metrics.Updater()
-
-	srv := Server{}
-
-	zonelist := make(zones.Zones)
-	srv.setupPgeodnsZone(zonelist)
-	srv.setupRootZone()
-	srv.zonesReadDir("dns", zonelist)
-
-	// listenAndServe returns after listening on udp + tcp, so just
-	// wait for it before continuing
-	srv.listenAndServe(PORT)
-
-	// ensure service has properly started before we query it
-	time.Sleep(200 * time.Millisecond)
-}
-
-func (s *ServeSuite) TestServing(c *C) {
-
-	r := exchange(c, "_status.pgeodns.", dns.TypeTXT)
-	txt := r.Answer[0].(*dns.TXT).Txt[0]
-	if !strings.HasPrefix(txt, "{") {
-		c.Log("Unexpected result for _status.pgeodns", txt)
-		c.Fail()
-	}
-
-	// Allow _country and _status queries as long as the first label is that
-	r = exchange(c, "_country.foo.pgeodns.", dns.TypeTXT)
-	txt = r.Answer[0].(*dns.TXT).Txt[0]
-	// Got appropriate response for _country txt query
-	if !strings.HasPrefix(txt, "127.0.0.1:") {
-		c.Log("Unexpected result for _country.foo.pgeodns", txt)
-		c.Fail()
-	}
-
-	// Make sure A requests for _status doesn't NXDOMAIN
-	r = exchange(c, "_status.pgeodns.", dns.TypeA)
-	c.Check(r.Answer, HasLen, 0)
-	// Got one SOA record
-	c.Check(r.Ns, HasLen, 1)
-	// NOERROR for A request
-	c.Check(r.Rcode, Equals, dns.RcodeSuccess)
-
-	r = exchange(c, "bar.test.example.com.", dns.TypeA)
-	ip := r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.2")
-	c.Check(int(r.Answer[0].Header().Ttl), Equals, 601)
-
-	r = exchange(c, "test.example.com.", dns.TypeSOA)
-	soa := r.Answer[0].(*dns.SOA)
-	serial := soa.Serial
-	c.Check(int(serial), Equals, 3)
-
-	// no AAAA records for 'bar', so check we get a soa record back
-	r = exchange(c, "test.example.com.", dns.TypeAAAA)
-	soa2 := r.Ns[0].(*dns.SOA)
-	c.Check(soa, DeepEquals, soa2)
-
-	// CNAMEs
-	r = exchange(c, "www.test.example.com.", dns.TypeA)
-	c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo.bitnames.com.")
-	c.Check(int(r.Answer[0].Header().Ttl), Equals, 1800)
-
-	//SPF
-	r = exchange(c, "test.example.com.", dns.TypeSPF)
-	c.Check(r.Answer[0].(*dns.SPF).Txt[0], Equals, "v=spf1 ~all")
-
-	//SRV
-	r = exchange(c, "_sip._tcp.test.example.com.", dns.TypeSRV)
-	c.Check(r.Answer[0].(*dns.SRV).Target, Equals, "sipserver.example.com.")
-	c.Check(r.Answer[0].(*dns.SRV).Port, Equals, uint16(5060))
-	c.Check(r.Answer[0].(*dns.SRV).Priority, Equals, uint16(10))
-	c.Check(r.Answer[0].(*dns.SRV).Weight, Equals, uint16(100))
-
-	// MX
-	r = exchange(c, "test.example.com.", dns.TypeMX)
-	c.Check(r.Answer[0].(*dns.MX).Mx, Equals, "mx.example.net.")
-	c.Check(r.Answer[1].(*dns.MX).Mx, Equals, "mx2.example.net.")
-	c.Check(r.Answer[1].(*dns.MX).Preference, Equals, uint16(20))
-
-	// Verify the first A record was created
-	r = exchange(c, "a.b.c.test.example.com.", dns.TypeA)
-	ip = r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.7")
-
-	// Verify sub-labels are created
-	r = exchange(c, "b.c.test.example.com.", dns.TypeA)
-	c.Check(r.Answer, HasLen, 0)
-	c.Check(r.Rcode, Equals, dns.RcodeSuccess)
-
-	r = exchange(c, "c.test.example.com.", dns.TypeA)
-	c.Check(r.Answer, HasLen, 0)
-	c.Check(r.Rcode, Equals, dns.RcodeSuccess)
-
-	// Verify the first A record was created
-	r = exchange(c, "three.two.one.test.example.com.", dns.TypeA)
-	ip = r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.5")
-
-	// Verify single sub-labels is created and no record is returned
-	r = exchange(c, "two.one.test.example.com.", dns.TypeA)
-	c.Check(r.Answer, HasLen, 0)
-	c.Check(r.Rcode, Equals, dns.RcodeSuccess)
-
-	// Verify the A record wasn't over written
-	r = exchange(c, "one.test.example.com.", dns.TypeA)
-	ip = r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.6")
-
-	// PTR
-	r = exchange(c, "2.1.168.192.IN-ADDR.ARPA.", dns.TypePTR)
-	c.Check(r.Answer, HasLen, 1)
-	// NOERROR for PTR request
-	c.Check(r.Rcode, Equals, dns.RcodeSuccess)
-	name := r.Answer[0].(*dns.PTR).Ptr
-	c.Check(name, Equals, "bar.example.com.")
-}
-
-func (s *ServeSuite) TestServingMixedCase(c *C) {
-
-	r := exchange(c, "_sTaTUs.pGEOdns.", dns.TypeTXT)
-	c.Assert(r.Rcode, Equals, dns.RcodeSuccess)
-	txt := r.Answer[0].(*dns.TXT).Txt[0]
-	if !strings.HasPrefix(txt, "{") {
-		c.Log("Unexpected result for _status.pgeodns", txt)
-		c.Fail()
-	}
-
-	n := "baR.test.eXAmPLe.cOM."
-	r = exchange(c, n, dns.TypeA)
-	ip := r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.2")
-	c.Check(r.Answer[0].Header().Name, Equals, n)
-
-}
-
-func (s *ServeSuite) TestCname(c *C) {
-	// Cname, two possible results
-
-	results := make(map[string]int)
-
-	for i := 0; i < 10; i++ {
-		r := exchange(c, "www.se.test.example.com.", dns.TypeA)
-		// only return one CNAME even if there are multiple options
-		c.Check(r.Answer, HasLen, 1)
-		target := r.Answer[0].(*dns.CNAME).Target
-		results[target]++
-	}
-
-	// Two possible results from this cname
-	c.Check(results, HasLen, 2)
-}
-
-func (s *ServeSuite) TestUnknownDomain(c *C) {
-	r := exchange(c, "no.such.domain.", dns.TypeAAAA)
-	c.Assert(r.Rcode, Equals, dns.RcodeRefused)
-}
-
-func (s *ServeSuite) TestServingAliases(c *C) {
-	// Alias, no geo matches
-	r := exchange(c, "bar-alias.test.example.com.", dns.TypeA)
-	ip := r.Answer[0].(*dns.A).A
-	c.Check(ip.String(), Equals, "192.168.1.2")
-
-	// Alias to a cname record
-	r = exchange(c, "www-alias.test.example.com.", dns.TypeA)
-	c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo.bitnames.com.")
-
-	// Alias returning a cname, with geo overrides
-	r = exchangeSubnet(c, "www-alias.test.example.com.", dns.TypeA, "194.239.134.1")
-	c.Check(r.Answer, HasLen, 1)
-	if len(r.Answer) > 0 {
-		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
-	}
-
-	// Alias to Ns records
-	r = exchange(c, "sub-alias.test.example.org.", dns.TypeNS)
-	c.Check(r.Answer[0].(*dns.NS).Ns, Equals, "ns1.example.com.")
-
-}
-
-func (s *ServeSuite) TestServingEDNS(c *C) {
-	// MX test
-	r := exchangeSubnet(c, "test.example.com.", dns.TypeMX, "194.239.134.1")
-	c.Check(r.Answer, HasLen, 1)
-	if len(r.Answer) > 0 {
-		c.Check(r.Answer[0].(*dns.MX).Mx, Equals, "mx-eu.example.net.")
-	}
-
-	c.Log("Testing www.test.example.com from .dk, should match www.europe (a cname)")
-
-	r = exchangeSubnet(c, "www.test.example.com.", dns.TypeA, "194.239.134.0")
-	// www.test from .dk IP address gets at least one answer
-	c.Check(r.Answer, HasLen, 1)
-	if len(r.Answer) > 0 {
-		// EDNS-SUBNET test (request A, respond CNAME)
-		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
-	}
-
-}
-
-func (s *ServeSuite) TestServeRace(c *C) {
-	wg := sync.WaitGroup{}
-	for i := 0; i < 5; i++ {
-		wg.Add(1)
-		go func() {
-			s.TestServing(c)
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-}
-
-func (s *ServeSuite) BenchmarkServingCountryDebug(c *C) {
-	for i := 0; i < c.N; i++ {
-		exchange(c, "_country.foo.pgeodns.", dns.TypeTXT)
-	}
-}
-
-func (s *ServeSuite) BenchmarkServing(c *C) {
-
-	// a deterministic seed is the default anyway, but let's be explicit we want it here.
-	rnd := rand.NewSource(1)
-
-	testNames := []string{"foo.test.example.com.", "one.test.example.com.",
-		"weight.test.example.com.", "three.two.one.test.example.com.",
-		"bar.test.example.com.", "0-alias.test.example.com.",
-	}
-
-	for i := 0; i < c.N; i++ {
-		name := testNames[rnd.Int63()%int64(len(testNames))]
-		exchange(c, name, dns.TypeA)
-	}
-}
-
-func exchangeSubnet(c *C, name string, dnstype uint16, ip string) *dns.Msg {
-	msg := new(dns.Msg)
-
-	msg.SetQuestion(name, dnstype)
-
-	o := new(dns.OPT)
-	o.Hdr.Name = "."
-	o.Hdr.Rrtype = dns.TypeOPT
-	e := new(dns.EDNS0_SUBNET)
-	e.Code = dns.EDNS0SUBNET
-	e.SourceScope = 0
-	e.Address = net.ParseIP(ip)
-	e.Family = 1 // IP4
-	e.SourceNetmask = net.IPv4len * 8
-	o.Option = append(o.Option, e)
-	msg.Extra = append(msg.Extra, o)
-
-	c.Log("msg", msg)
-
-	return dorequest(c, msg)
-}
-
-func exchange(c *C, name string, dnstype uint16) *dns.Msg {
-	msg := new(dns.Msg)
-
-	msg.SetQuestion(name, dnstype)
-	return dorequest(c, msg)
-}
-
-func dorequest(c *C, msg *dns.Msg) *dns.Msg {
-	cli := new(dns.Client)
-	// cli.ReadTimeout = 2 * time.Second
-	r, _, err := cli.Exchange(msg, "127.0.0.1"+PORT)
-	if err != nil {
-		c.Logf("request err '%s': %s", msg.String(), err)
-		c.Fail()
-	}
-	return r
-}

+ 3 - 1
metrics.go → server/metrics.go

@@ -1,4 +1,4 @@
-package main
+package server
 
 import (
 	"runtime"
@@ -7,6 +7,8 @@ import (
 	metrics "github.com/rcrowley/go-metrics"
 )
 
+// todo: make this not have global variables ...
+
 type ServerMetrics struct {
 	qCounter         metrics.Meter
 	lastQueryCount   int64

+ 7 - 7
serve.go → server/serve.go

@@ -1,4 +1,4 @@
-package main
+package server
 
 import (
 	"encoding/json"
@@ -148,7 +148,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
 
 	if labels == nil {
 
-		permitDebug := !*flagPrivateDebug || (realIP != nil && realIP.IsLoopback())
+		permitDebug := srv.PublicDebugQueries || (realIP != nil && realIP.IsLoopback())
 
 		firstLabel := (strings.Split(label, "."))[0]
 
@@ -158,7 +158,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
 
 		if permitDebug && firstLabel == "_status" {
 			if qtype == dns.TypeANY || qtype == dns.TypeTXT {
-				m.Answer = statusRR(label + "." + z.Origin + ".")
+				m.Answer = srv.statusRR(label + "." + z.Origin + ".")
 			} else {
 				m.Ns = append(m.Ns, z.SoaRR())
 			}
@@ -193,7 +193,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
 
 				targets, netmask, location := z.Options.Targeting.GetTargets(ip, z.HasClosest)
 				txt = append(txt, strings.Join(targets, " "))
-				txt = append(txt, fmt.Sprintf("/%d", netmask), serverID, serverIP)
+				txt = append(txt, fmt.Sprintf("/%d", netmask), srv.info.ID, srv.info.IP)
 				if location != nil {
 					txt = append(txt, fmt.Sprintf("(%.3f,%.3f)", location.Latitude, location.Longitude))
 				} else {
@@ -258,11 +258,11 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
 	return
 }
 
-func statusRR(label string) []dns.RR {
+func (srv *Server) statusRR(label string) []dns.RR {
 	h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
 	h.Name = label
 
-	status := map[string]string{"v": VERSION, "id": serverID}
+	status := map[string]string{"v": srv.info.Version, "id": srv.info.ID}
 
 	hostname, err := os.Hostname()
 	if err == nil {
@@ -270,7 +270,7 @@ func statusRR(label string) []dns.RR {
 	}
 
 	qCounter := metrics.Get("queries").(metrics.Meter)
-	status["up"] = strconv.Itoa(int(time.Since(timeStarted).Seconds()))
+	status["up"] = strconv.Itoa(int(time.Since(srv.info.Started).Seconds()))
 	status["qs"] = strconv.FormatInt(qCounter.Count(), 10)
 	status["qps1"] = fmt.Sprintf("%.4f", qCounter.Rate1())
 

+ 325 - 0
server/serve_test.go

@@ -0,0 +1,325 @@
+package server
+
+import (
+	"net"
+	"reflect"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/abh/geodns/monitor"
+	"github.com/abh/geodns/zones"
+	"github.com/miekg/dns"
+)
+
+const (
+	PORT = ":8853"
+)
+
+func TestServe(t *testing.T) {
+	// setup and register global metrics for serve()
+	metrics := NewMetrics()
+	go metrics.Updater()
+
+	serverInfo := &monitor.ServerInfo{}
+
+	srv := NewServer(serverInfo)
+
+	mm, err := zones.NewMuxManager("../dns", srv)
+	if err != nil {
+		t.Fatalf("Loading test zones: %s", err)
+	}
+	go mm.Run()
+
+	// listenAndServe returns after listening on udp + tcp, so just
+	// wait for it before continuing
+	srv.ListenAndServe(PORT)
+
+	// ensure service has properly started before we query it
+	time.Sleep(500 * time.Millisecond)
+
+	t.Run("Serving", testServing)
+
+}
+
+func testServing(t *testing.T) {
+	r := exchange(t, "_status.pgeodns.", dns.TypeTXT)
+	require.Len(t, r.Answer, 1, "1 txt record for _status.pgeodns")
+	txt := r.Answer[0].(*dns.TXT).Txt[0]
+	if !strings.HasPrefix(txt, "{") {
+		t.Log("Unexpected result for _status.pgeodns", txt)
+		t.Fail()
+	}
+
+	// Allow _country and _status queries as long as the first label is that
+	r = exchange(t, "_country.foo.pgeodns.", dns.TypeTXT)
+	txt = r.Answer[0].(*dns.TXT).Txt[0]
+	// Got appropriate response for _country txt query
+	if !strings.HasPrefix(txt, "127.0.0.1:") {
+		t.Log("Unexpected result for _country.foo.pgeodns", txt)
+		t.Fail()
+	}
+
+	// Make sure A requests for _status doesn't NXDOMAIN
+	r = exchange(t, "_status.pgeodns.", dns.TypeA)
+	if len(r.Answer) != 0 {
+		t.Log("got A record for _status.pgeodns")
+		t.Fail()
+	}
+	if len(r.Ns) != 1 {
+		t.Logf("Expected 1 SOA record, got %d", len(r.Ns))
+		t.Fail()
+	}
+	// NOERROR for A request
+	checkRcode(t, r.Rcode, dns.RcodeSuccess, "_status.pgeodns")
+
+	r = exchange(t, "bar.test.example.com.", dns.TypeA)
+	ip := r.Answer[0].(*dns.A).A
+
+	// c.Check(ip.String(), Equals, "192.168.1.2")
+	// c.Check(int(r.Answer[0].Header().Ttl), Equals, 601)
+
+	r = exchange(t, "test.example.com.", dns.TypeSOA)
+	soa := r.Answer[0].(*dns.SOA)
+	serial := soa.Serial
+	assert.Equal(t, 3, int(serial))
+
+	// no AAAA records for 'bar', so check we get a soa record back
+	r = exchange(t, "test.example.com.", dns.TypeAAAA)
+	soa2 := r.Ns[0].(*dns.SOA)
+	if !reflect.DeepEqual(soa, soa2) {
+		t.Logf("AAAA empty NOERROR soa record different from SOA request")
+		t.Fail()
+	}
+
+	// CNAMEs
+	r = exchange(t, "www.test.example.com.", dns.TypeA)
+	// c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo.bitnames.com.")
+	if int(r.Answer[0].Header().Ttl) != 1800 {
+		t.Logf("unexpected ttl '%d' for geo.bitnames.com (expected %d)", int(r.Answer[0].Header().Ttl), 1800)
+		t.Fail()
+	}
+
+	//SPF
+	r = exchange(t, "test.example.com.", dns.TypeSPF)
+	assert.Equal(t, r.Answer[0].(*dns.SPF).Txt[0], "v=spf1 ~all")
+
+	//SRV
+	r = exchange(t, "_sip._tcp.test.example.com.", dns.TypeSRV)
+	assert.Equal(t, r.Answer[0].(*dns.SRV).Target, "sipserver.example.com.")
+	assert.Equal(t, r.Answer[0].(*dns.SRV).Port, uint16(5060))
+	assert.Equal(t, r.Answer[0].(*dns.SRV).Priority, uint16(10))
+	assert.Equal(t, r.Answer[0].(*dns.SRV).Weight, uint16(100))
+
+	// MX
+	r = exchange(t, "test.example.com.", dns.TypeMX)
+	assert.Equal(t, r.Answer[0].(*dns.MX).Mx, "mx.example.net.")
+	assert.Equal(t, r.Answer[1].(*dns.MX).Mx, "mx2.example.net.")
+	assert.Equal(t, r.Answer[1].(*dns.MX).Preference, uint16(20))
+
+	// Verify the first A record was created
+	r = exchange(t, "a.b.c.test.example.com.", dns.TypeA)
+	ip = r.Answer[0].(*dns.A).A
+	assert.Equal(t, ip.String(), "192.168.1.7")
+
+	// Verify sub-labels are created
+	r = exchange(t, "b.c.test.example.com.", dns.TypeA)
+	assert.Len(t, r.Answer, 0, "expect 0 answer records for b.c.test.example.com")
+	checkRcode(t, r.Rcode, dns.RcodeSuccess, "b.c.test.example.com")
+
+	r = exchange(t, "c.test.example.com.", dns.TypeA)
+	assert.Len(t, r.Answer, 0, "expect 0 answer records for c.test.example.com")
+	checkRcode(t, r.Rcode, dns.RcodeSuccess, "c.test.example.com")
+
+	// Verify the first A record was created
+	r = exchange(t, "three.two.one.test.example.com.", dns.TypeA)
+	ip = r.Answer[0].(*dns.A).A
+
+	assert.Equal(t, ip.String(), "192.168.1.5", "three.two.one.test.example.com A record")
+
+	// Verify single sub-labels is created and no record is returned
+	r = exchange(t, "two.one.test.example.com.", dns.TypeA)
+	assert.Len(t, r.Answer, 0, "expect 0 answer records for two.one.test.example.com")
+	checkRcode(t, r.Rcode, dns.RcodeSuccess, "two.one.test.example.com")
+
+	// Verify the A record wasn't over written
+	r = exchange(t, "one.test.example.com.", dns.TypeA)
+	ip = r.Answer[0].(*dns.A).A
+	assert.Equal(t, ip.String(), "192.168.1.6", "one.test.example.com A record")
+
+	// PTR
+	r = exchange(t, "2.1.168.192.IN-ADDR.ARPA.", dns.TypePTR)
+	assert.Len(t, r.Answer, 1, "expect 1 answer records for 2.1.168.192.IN-ADDR.ARPA")
+	checkRcode(t, r.Rcode, dns.RcodeSuccess, "2.1.168.192.IN-ADDR.ARPA")
+
+	name := r.Answer[0].(*dns.PTR).Ptr
+	assert.Equal(t, name, "bar.example.com.", "PTR record")
+}
+
+// func TestServingMixedCase(t *testing.T) {
+
+// 	r := exchange(c, "_sTaTUs.pGEOdns.", dns.TypeTXT)
+// 	checkRcode(t, r.Rcode, dns.RcodeSuccess, "_sTaTUs.pGEOdns.")
+
+// 	txt := r.Answer[0].(*dns.TXT).Txt[0]
+// 	if !strings.HasPrefix(txt, "{") {
+// 		t.Log("Unexpected result for _status.pgeodns", txt)
+// 		t.Fail()
+// 	}
+
+// 	n := "baR.test.eXAmPLe.cOM."
+// 	r = exchange(c, n, dns.TypeA)
+// 	ip := r.Answer[0].(*dns.A).A
+// 	c.Check(ip.String(), Equals, "192.168.1.2")
+// 	c.Check(r.Answer[0].Header().Name, Equals, n)
+
+// }
+
+// func TestCname(t *testing.T) {
+// 	// Cname, two possible results
+
+// 	results := make(map[string]int)
+
+// 	for i := 0; i < 10; i++ {
+// 		r := exchange(c, "www.se.test.example.com.", dns.TypeA)
+// 		// only return one CNAME even if there are multiple options
+// 		c.Check(r.Answer, HasLen, 1)
+// 		target := r.Answer[0].(*dns.CNAME).Target
+// 		results[target]++
+// 	}
+
+// 	// Two possible results from this cname
+// 	c.Check(results, HasLen, 2)
+// }
+
+// func testUnknownDomain(t *testing.T) {
+// 	r := exchange(t, "no.such.domain.", dns.TypeAAAA)
+// 	c.Assert(r.Rcode, Equals, dns.RcodeRefused)
+// }
+
+// func testServingAliases(t *testing.T) {
+// 	// Alias, no geo matches
+// 	r := exchange(c, "bar-alias.test.example.com.", dns.TypeA)
+// 	ip := r.Answer[0].(*dns.A).A
+// 	c.Check(ip.String(), Equals, "192.168.1.2")
+
+// 	// Alias to a cname record
+// 	r = exchange(c, "www-alias.test.example.com.", dns.TypeA)
+// 	c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo.bitnames.com.")
+
+// 	// Alias returning a cname, with geo overrides
+// 	r = exchangeSubnet(c, "www-alias.test.example.com.", dns.TypeA, "194.239.134.1")
+// 	c.Check(r.Answer, HasLen, 1)
+// 	if len(r.Answer) > 0 {
+// 		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
+// 	}
+
+// 	// Alias to Ns records
+// 	r = exchange(c, "sub-alias.test.example.org.", dns.TypeNS)
+// 	c.Check(r.Answer[0].(*dns.NS).Ns, Equals, "ns1.example.com.")
+
+// }
+
+// func testServingEDNS(t *testing.T) {
+// 	// MX test
+// 	r := exchangeSubnet(t, "test.example.com.", dns.TypeMX, "194.239.134.1")
+// 	c.Check(r.Answer, HasLen, 1)
+// 	if len(r.Answer) > 0 {
+// 		c.Check(r.Answer[0].(*dns.MX).Mx, Equals, "mx-eu.example.net.")
+// 	}
+
+// 	c.Log("Testing www.test.example.com from .dk, should match www.europe (a cname)")
+
+// 	r = exchangeSubnet(c, "www.test.example.com.", dns.TypeA, "194.239.134.0")
+// 	// www.test from .dk IP address gets at least one answer
+// 	c.Check(r.Answer, HasLen, 1)
+// 	if len(r.Answer) > 0 {
+// 		// EDNS-SUBNET test (request A, respond CNAME)
+// 		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
+// 	}
+
+// }
+
+// func TestServeRace(t *testing.T) {
+// 	wg := sync.WaitGroup{}
+// 	for i := 0; i < 5; i++ {
+// 		wg.Add(1)
+// 		go func() {
+// 			s.TestServing(t)
+// 			wg.Done()
+// 		}()
+// 	}
+// 	wg.Wait()
+// }
+
+// func BenchmarkServingCountryDebug(b *testing.B) {
+// 	for i := 0; i < b.N; i++ {
+// 		exchange(b, "_country.foo.pgeodns.", dns.TypeTXT)
+// 	}
+// }
+
+// func BenchmarkServing(b *testing.B) {
+
+// 	// a deterministic seed is the default anyway, but let's be explicit we want it here.
+// 	rnd := rand.NewSource(1)
+
+// 	testNames := []string{"foo.test.example.com.", "one.test.example.com.",
+// 		"weight.test.example.com.", "three.two.one.test.example.com.",
+// 		"bar.test.example.com.", "0-alias.test.example.com.",
+// 	}
+
+// 	for i := 0; i < c.N; i++ {
+// 		name := testNames[rnd.Int63()%int64(len(testNames))]
+// 		exchange(t, name, dns.TypeA)
+// 	}
+// }
+
+func checkRcode(t *testing.T, rcode int, expected int, name string) {
+	if rcode != expected {
+		t.Logf("'%s': rcode!=%s: %s", name, dns.RcodeToString[expected], dns.RcodeToString[rcode])
+		t.Fail()
+	}
+}
+
+func exchangeSubnet(t *testing.T, name string, dnstype uint16, ip string) *dns.Msg {
+	msg := new(dns.Msg)
+
+	msg.SetQuestion(name, dnstype)
+
+	o := new(dns.OPT)
+	o.Hdr.Name = "."
+	o.Hdr.Rrtype = dns.TypeOPT
+	e := new(dns.EDNS0_SUBNET)
+	e.Code = dns.EDNS0SUBNET
+	e.SourceScope = 0
+	e.Address = net.ParseIP(ip)
+	e.Family = 1 // IP4
+	e.SourceNetmask = net.IPv4len * 8
+	o.Option = append(o.Option, e)
+	msg.Extra = append(msg.Extra, o)
+
+	t.Log("msg", msg)
+
+	return dorequest(t, msg)
+}
+
+func exchange(t *testing.T, name string, dnstype uint16) *dns.Msg {
+	msg := new(dns.Msg)
+
+	msg.SetQuestion(name, dnstype)
+	return dorequest(t, msg)
+}
+
+func dorequest(t *testing.T, msg *dns.Msg) *dns.Msg {
+	cli := new(dns.Client)
+	// cli.ReadTimeout = 2 * time.Second
+	r, _, err := cli.Exchange(msg, "127.0.0.1"+PORT)
+	if err != nil {
+		t.Logf("request err '%s': %s", msg.String(), err)
+		t.Fail()
+	}
+	return r
+}

+ 74 - 0
server/server.go

@@ -0,0 +1,74 @@
+package server
+
+import (
+	"log"
+
+	"github.com/abh/geodns/monitor"
+	"github.com/abh/geodns/querylog"
+	"github.com/abh/geodns/zones"
+
+	"github.com/miekg/dns"
+)
+
+type Server struct {
+	queryLogger        querylog.QueryLogger
+	mux                *dns.ServeMux
+	PublicDebugQueries bool
+	info               *monitor.ServerInfo
+}
+
+func NewServer(si *monitor.ServerInfo) *Server {
+	mux := dns.NewServeMux()
+
+	// todo: this should be in the monitor package, or somewhere else.
+	// Also if we can stop the server later, need to stop the server too.
+	metrics := NewMetrics()
+	go metrics.Updater()
+
+	return &Server{mux: mux, info: si}
+}
+
+// Setup the QueryLogger. For now it only supports writing to a file (and all
+// zones get logged to the same file).
+func (srv *Server) SetQueryLogger(logger querylog.QueryLogger) {
+	srv.queryLogger = logger
+}
+
+func (srv *Server) Add(name string, zone *zones.Zone) {
+	srv.mux.HandleFunc(name, srv.setupServerFunc(zone))
+}
+
+func (srv *Server) Remove(name string) {
+	srv.mux.HandleRemove(name)
+}
+
+func (srv *Server) setupServerFunc(zone *zones.Zone) func(dns.ResponseWriter, *dns.Msg) {
+	return func(w dns.ResponseWriter, r *dns.Msg) {
+		srv.serve(w, r, zone)
+	}
+}
+
+func (srv *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	srv.mux.ServeDNS(w, r)
+}
+
+func (srv *Server) ListenAndServe(ip string) {
+
+	prots := []string{"udp", "tcp"}
+
+	for _, prot := range prots {
+		go func(p string) {
+			server := &dns.Server{
+				Addr:    ip,
+				Net:     p,
+				Handler: srv,
+			}
+
+			log.Printf("Opening on %s %s", ip, p)
+			if err := server.ListenAndServe(); err != nil {
+				log.Fatalf("geodns: failed to setup %s %s: %s", ip, p, err)
+			}
+			log.Fatalf("geodns: ListenAndServe unexpectedly returned")
+		}(prot)
+	}
+}

+ 0 - 130
server_test.go

@@ -1,130 +0,0 @@
-package main
-
-import (
-	"io"
-	"io/ioutil"
-	"os"
-	"testing"
-
-	"github.com/abh/geodns/zones"
-	"github.com/miekg/dns"
-	. "gopkg.in/check.v1"
-)
-
-// Hook up gocheck into the gotest runner.
-func Test(t *testing.T) { TestingT(t) }
-
-type ConfigSuite struct {
-	srv      *Server
-	zonelist zones.Zones
-}
-
-var _ = Suite(&ConfigSuite{})
-
-func (s *ConfigSuite) SetUpSuite(c *C) {
-	s.zonelist = make(zones.Zones)
-	lastRead = map[string]*zoneReadRecord{}
-	s.srv = &Server{}
-	s.srv.zonesReadDir("dns", s.zonelist)
-}
-
-func (s *ConfigSuite) TestReadConfigs(c *C) {
-	// Just check that example.com and test.example.org loaded, too.
-	c.Check(s.zonelist["example.com"].Origin, Equals, "example.com")
-	c.Check(s.zonelist["test.example.org"].Origin, Equals, "test.example.org")
-	if s.zonelist["test.example.org"].Options.Serial == 0 {
-		c.Log("Serial number is 0, should be set by file timestamp")
-		c.Fail()
-	}
-
-	// The real tests are in test.example.com so we have a place
-	// to make nutty configuration entries
-	tz := s.zonelist["test.example.com"]
-
-	// test.example.com was loaded
-	c.Check(tz.Origin, Equals, "test.example.com")
-
-	c.Check(tz.Options.MaxHosts, Equals, 2)
-	c.Check(tz.Options.Contact, Equals, "support.bitnames.com")
-	c.Check(tz.Options.Targeting.String(), Equals, "@ continent country regiongroup region asn ip")
-
-	// Got logging option
-	c.Check(tz.Logging.StatHat, Equals, true)
-
-	c.Check(tz.Labels["weight"].MaxHosts, Equals, 1)
-
-	/* test different cname targets */
-	c.Check(tz.Labels["www"].
-		FirstRR(dns.TypeCNAME).(*dns.CNAME).
-		Target, Equals, "geo.bitnames.com.")
-
-	c.Check(tz.Labels["www-cname"].
-		FirstRR(dns.TypeCNAME).(*dns.CNAME).
-		Target, Equals, "bar.test.example.com.")
-
-	c.Check(tz.Labels["www-alias"].
-		FirstRR(dns.TypeMF).(*dns.MF).
-		Mf, Equals, "www")
-
-	// The header name should just have a dot-prefix
-	c.Check(tz.Labels[""].Records[dns.TypeNS][0].RR.(*dns.NS).Hdr.Name, Equals, "test.example.com.")
-
-}
-
-func (s *ConfigSuite) TestRemoveConfig(c *C) {
-	// restore the dns.Mux
-	defer s.srv.zonesReadDir("dns", s.zonelist)
-
-	dir, err := ioutil.TempDir("", "geodns-test.")
-	if err != nil {
-		c.Fail()
-	}
-	defer os.RemoveAll(dir)
-
-	_, err = CopyFile(c, "dns/test.example.org.json", dir+"/test.example.org.json")
-	if err != nil {
-		c.Log(err)
-		c.Fail()
-	}
-	_, err = CopyFile(c, "dns/test.example.org.json", dir+"/test2.example.org.json")
-	if err != nil {
-		c.Log(err)
-		c.Fail()
-	}
-
-	err = ioutil.WriteFile(dir+"/invalid.example.org.json", []byte("not-json"), 0644)
-	if err != nil {
-		c.Log(err)
-		c.Fail()
-	}
-
-	s.srv.zonesReadDir(dir, s.zonelist)
-	c.Check(s.zonelist["test.example.org"].Origin, Equals, "test.example.org")
-	c.Check(s.zonelist["test2.example.org"].Origin, Equals, "test2.example.org")
-
-	os.Remove(dir + "/test2.example.org.json")
-	os.Remove(dir + "/invalid.example.org.json")
-
-	s.srv.zonesReadDir(dir, s.zonelist)
-	c.Check(s.zonelist["test.example.org"].Origin, Equals, "test.example.org")
-	_, ok := s.zonelist["test2.example.org"]
-	c.Check(ok, Equals, false)
-}
-
-func CopyFile(c *C, src, dst string) (int64, error) {
-	sf, err := os.Open(src)
-	if err != nil {
-		c.Log("Could not copy", src, "to", dst, "because", err)
-		c.Fail()
-		return 0, err
-	}
-	defer sf.Close()
-	df, err := os.Create(dst)
-	if err != nil {
-		c.Log("Could not copy", src, "to", dst, "because", err)
-		c.Fail()
-		return 0, err
-	}
-	defer df.Close()
-	return io.Copy(df, sf)
-}

+ 2 - 0
targeting/targeting_test.go

@@ -50,6 +50,8 @@ func TestTargetParse(t *testing.T) {
 func TestGetTargets(t *testing.T) {
 	ip := net.ParseIP("207.171.1.1")
 
+	GeoIP().SetDirectory("../db")
+
 	GeoIP().SetupGeoIPCity()
 	GeoIP().SetupGeoIPCountry()
 	GeoIP().SetupGeoIPASN()

+ 6 - 4
util.go

@@ -40,11 +40,13 @@ func getInterfaces() []string {
 		}
 		uniq[host] = true
 
-		if len(serverID) == 0 {
-			serverID = ip
+		// default to the first interfaces
+		// todo: skip 127.0.0.1 and ::1 ?
+		if len(serverInfo.ID) == 0 {
+			serverInfo.ID = ip
 		}
-		if len(serverIP) == 0 {
-			serverIP = ip
+		if len(serverInfo.IP) == 0 {
+			serverInfo.IP = ip
 		}
 		inter = append(inter, host)
 

+ 82 - 85
server.go → zones/muxmanager.go

@@ -1,4 +1,4 @@
-package main
+package zones
 
 import (
 	"crypto/sha256"
@@ -11,14 +11,20 @@ import (
 	"time"
 
 	"github.com/abh/geodns/applog"
-	"github.com/abh/geodns/querylog"
-	"github.com/abh/geodns/zones"
 
 	"github.com/miekg/dns"
 )
 
-type Server struct {
-	queryLogger querylog.QueryLogger
+type RegistrationAPI interface {
+	Add(string, *Zone)
+	Remove(string)
+}
+
+type MuxManager struct {
+	reg      RegistrationAPI
+	zonelist Zones
+	path     string
+	lastRead map[string]*zoneReadRecord
 }
 
 // track when each zone was read last
@@ -27,78 +33,25 @@ type zoneReadRecord struct {
 	hash string
 }
 
-func NewServer() *Server {
-	return &Server{}
-}
-
-// Setup the QueryLogger. For now it only supports writing to a file (and all
-// zones get logged to the same file).
-func (srv *Server) SetQueryLogger(logger querylog.QueryLogger) {
-	srv.queryLogger = logger
-}
-
-func (srv *Server) setupServerFunc(zone *zones.Zone) func(dns.ResponseWriter, *dns.Msg) {
-	return func(w dns.ResponseWriter, r *dns.Msg) {
-		srv.serve(w, r, zone)
+func NewMuxManager(path string, reg RegistrationAPI) (*MuxManager, error) {
+	mm := &MuxManager{
+		reg:      reg,
+		path:     path,
+		zonelist: make(Zones),
+		lastRead: map[string]*zoneReadRecord{},
 	}
-}
-
-func (srv *Server) listenAndServe(ip string) {
 
-	prots := []string{"udp", "tcp"}
+	mm.setupRootZone()
+	mm.setupPgeodnsZone()
 
-	for _, prot := range prots {
-		go func(p string) {
-			server := &dns.Server{Addr: ip, Net: p}
+	err := mm.reload()
 
-			log.Printf("Opening on %s %s", ip, p)
-			if err := server.ListenAndServe(); err != nil {
-				log.Fatalf("geodns: failed to setup %s %s: %s", ip, p, err)
-			}
-			log.Fatalf("geodns: ListenAndServe unexpectedly returned")
-		}(prot)
-	}
+	return mm, err
 }
 
-func (srv *Server) addHandler(zones zones.Zones, name string, config *zones.Zone) {
-	oldZone := zones[name]
-	// across the recconfiguration keep a reference to all healthchecks to ensure
-	// the global map doesn't get destroyed
-	// health.TestRunner.refAllGlobalHealthChecks(name, true)
-	// defer health.TestRunner.refAllGlobalHealthChecks(name, false)
-	// if oldZone != nil {
-	// 	oldZone.StartStopHealthChecks(false, nil)
-	// }
-	config.SetupMetrics(oldZone)
-	zones[name] = config
-	// config.StartStopHealthChecks(true, oldZone)
-	dns.HandleFunc(name, srv.setupServerFunc(config))
-}
-
-func (srv *Server) setupPgeodnsZone(zonelist zones.Zones) {
-	zoneName := "pgeodns"
-	zone := zones.NewZone(zoneName)
-	label := new(zones.Label)
-	label.Records = make(map[uint16]zones.Records)
-	label.Weight = make(map[uint16]int)
-	zone.Labels[""] = label
-	zone.AddSOA()
-	srv.addHandler(zonelist, zoneName, zone)
-}
-
-func (srv *Server) setupRootZone() {
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		m := new(dns.Msg)
-		m.SetRcode(r, dns.RcodeRefused)
-		w.WriteMsg(m)
-	})
-}
-
-var lastRead = map[string]*zoneReadRecord{}
-
-func (srv *Server) zonesReader(dirName string, zones zones.Zones) {
+func (mm *MuxManager) Run() {
 	for {
-		err := srv.zonesReadDir(dirName, zones)
+		err := mm.reload()
 		if err != nil {
 			log.Printf("error reading zones: %s", err)
 		}
@@ -106,10 +59,16 @@ func (srv *Server) zonesReader(dirName string, zones zones.Zones) {
 	}
 }
 
-func (srv *Server) zonesReadDir(dirName string, zonelist zones.Zones) error {
-	dir, err := ioutil.ReadDir(dirName)
+// GetZones returns the list of currently active zones in the mux manager.
+// (todo: rename to Zones() when the Zones struct has been renamed to ZoneList)
+func (mm *MuxManager) Zones() Zones {
+	return mm.zonelist
+}
+
+func (mm *MuxManager) reload() error {
+	dir, err := ioutil.ReadDir(mm.path)
 	if err != nil {
-		return fmt.Errorf("could not read", dirName, ":", err)
+		return fmt.Errorf("could not read '%s': %s", mm.path, err)
 	}
 
 	seenZones := map[string]bool{}
@@ -128,17 +87,17 @@ func (srv *Server) zonesReadDir(dirName string, zonelist zones.Zones) error {
 
 		seenZones[zoneName] = true
 
-		if _, ok := lastRead[zoneName]; !ok || file.ModTime().After(lastRead[zoneName].time) {
+		if _, ok := mm.lastRead[zoneName]; !ok || file.ModTime().After(mm.lastRead[zoneName].time) {
 			modTime := file.ModTime()
 			if ok {
 				applog.Printf("Reloading %s\n", fileName)
-				lastRead[zoneName].time = modTime
+				mm.lastRead[zoneName].time = modTime
 			} else {
 				applog.Printf("Reading new file %s\n", fileName)
-				lastRead[zoneName] = &zoneReadRecord{time: modTime}
+				mm.lastRead[zoneName] = &zoneReadRecord{time: modTime}
 			}
 
-			filename := path.Join(dirName, fileName)
+			filename := path.Join(mm.path, fileName)
 
 			// Check the sha256 of the file has not changed. It's worth an explanation of
 			// why there isn't a TOCTOU race here. Conceivably after checking whether the
@@ -162,25 +121,25 @@ func (srv *Server) zonesReadDir(dirName string, zonelist zones.Zones) error {
 			// replaced atomically we have other problems (e.g. partial reads).
 
 			sha256 := sha256File(filename)
-			if lastRead[zoneName].hash == sha256 {
+			if mm.lastRead[zoneName].hash == sha256 {
 				applog.Printf("Skipping new file %s as hash is unchanged\n", filename)
 				continue
 			}
 
-			zone, err := zones.ReadZoneFile(zoneName, filename)
+			zone, err := ReadZoneFile(zoneName, filename)
 			if zone == nil || err != nil {
 				parseErr = fmt.Errorf("Error reading zone '%s': %s", zoneName, err)
 				log.Println(parseErr.Error())
 				continue
 			}
 
-			(lastRead[zoneName]).hash = sha256
+			(mm.lastRead[zoneName]).hash = sha256
 
-			srv.addHandler(zonelist, zoneName, zone)
+			mm.addHandler(zoneName, zone)
 		}
 	}
 
-	for zoneName, zone := range zonelist {
+	for zoneName, zone := range mm.zonelist {
 		if zoneName == "pgeodns" {
 			continue
 		}
@@ -188,15 +147,53 @@ func (srv *Server) zonesReadDir(dirName string, zonelist zones.Zones) error {
 			continue
 		}
 		log.Println("Removing zone", zone.Origin)
-		delete(lastRead, zoneName)
 		zone.Close()
-		dns.HandleRemove(zoneName)
-		delete(zonelist, zoneName)
+		mm.removeHandler(zoneName)
 	}
 
 	return parseErr
 }
 
+func (mm *MuxManager) addHandler(name string, zone *Zone) {
+	oldZone := mm.zonelist[name]
+	// across the recconfiguration keep a reference to all healthchecks to ensure
+	// the global map doesn't get destroyed
+	// healtmm.TestRunner.refAllGlobalHealthChecks(name, true)
+	// defer healtmm.TestRunner.refAllGlobalHealthChecks(name, false)
+	// if oldZone != nil {
+	// 	oldZone.StartStopHealthChecks(false, nil)
+	// }
+	zone.SetupMetrics(oldZone)
+	mm.zonelist[name] = zone
+	// config.StartStopHealthChecks(true, oldZone)
+	mm.reg.Add(name, zone)
+}
+
+func (mm *MuxManager) removeHandler(name string) {
+	delete(mm.lastRead, name)
+	delete(mm.zonelist, name)
+	mm.reg.Remove(name)
+}
+
+func (mm *MuxManager) setupPgeodnsZone() {
+	zoneName := "pgeodns"
+	zone := NewZone(zoneName)
+	label := new(Label)
+	label.Records = make(map[uint16]Records)
+	label.Weight = make(map[uint16]int)
+	zone.Labels[""] = label
+	zone.AddSOA()
+	mm.addHandler(zoneName, zone)
+}
+
+func (mm *MuxManager) setupRootZone() {
+	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
+		m := new(dns.Msg)
+		m.SetRcode(r, dns.RcodeRefused)
+		w.WriteMsg(m)
+	})
+}
+
 func zoneNameFromFile(fileName string) string {
 	return fileName[0:strings.LastIndex(fileName, ".")]
 }

+ 152 - 0
zones/reader_test.go

@@ -0,0 +1,152 @@
+package zones
+
+import (
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"testing"
+)
+
+type TestReg struct{}
+
+func (r *TestReg) Add(name string, zone *Zone) {}
+
+func (r *TestReg) Remove(name string) {}
+
+func TestReadConfigs(t *testing.T) {
+
+	muxm, err := NewMuxManager("../dns", &TestReg{})
+	if err != nil {
+		t.Logf("loading zones: %s", err)
+		t.Fail()
+	}
+
+	// Just check that example.com and test.example.org loaded, too.
+	for _, zonename := range []string{"example.com", "test.example.com"} {
+
+		if z, ok := muxm.zonelist[zonename]; ok {
+			if z.Origin != zonename {
+				t.Logf("zone '%s' doesn't have that Origin '%s'", zonename, z.Origin)
+				t.Fail()
+			}
+			if z.Options.Serial == 0 {
+				t.Logf("Zone '%s' Serial number is 0, should be set by file timestamp", zonename)
+				t.Fail()
+			}
+		} else {
+			t.Fatalf("Didn't load '%s'", zonename)
+		}
+	}
+
+	// The real tests are in test.example.com so we have a place
+	// to make nutty configuration entries
+	tz := muxm.zonelist["test.example.com"]
+
+	// test.example.com was loaded
+
+	if tz.Options.MaxHosts != 2 {
+		t.Logf("MaxHosts=%d, expected 2", tz.Options.MaxHosts)
+		t.Fail()
+	}
+
+	if tz.Options.Contact != "support.bitnames.com" {
+		t.Logf("Contact='%s', expected support.bitnames.com", tz.Options.Contact)
+		t.Fail()
+	}
+	// c.Check(tz.Options.Targeting.String(), Equals, "@ continent country regiongroup region asn ip")
+
+	// // Got logging option
+	// c.Check(tz.Logging.StatHat, Equals, true)
+
+	// c.Check(tz.Labels["weight"].MaxHosts, Equals, 1)
+
+	// /* test different cname targets */
+	// c.Check(tz.Labels["www"].
+	// 	FirstRR(dns.TypeCNAME).(*dns.CNAME).
+	// 	Target, Equals, "geo.bitnames.com.")
+
+	// c.Check(tz.Labels["www-cname"].
+	// 	FirstRR(dns.TypeCNAME).(*dns.CNAME).
+	// 	Target, Equals, "bar.test.example.com.")
+
+	// c.Check(tz.Labels["www-alias"].
+	// 	FirstRR(dns.TypeMF).(*dns.MF).
+	// 	Mf, Equals, "www")
+
+	// // The header name should just have a dot-prefix
+	// c.Check(tz.Labels[""].Records[dns.TypeNS][0].RR.(*dns.NS).Hdr.Name, Equals, "test.example.com.")
+
+}
+
+func TestRemoveConfig(t *testing.T) {
+	dir, err := ioutil.TempDir("", "geodns-test.")
+	if err != nil {
+		t.Fail()
+	}
+	defer os.RemoveAll(dir)
+
+	muxm, err := NewMuxManager(dir, &TestReg{})
+	if err != nil {
+		t.Logf("loading zones: %s", err)
+		t.Fail()
+	}
+
+	muxm.reload()
+
+	_, err = CopyFile("../dns/test.example.org.json", dir+"/test.example.org.json")
+	if err != nil {
+		t.Log(err)
+		t.Fail()
+	}
+	_, err = CopyFile("../dns/test.example.org.json", dir+"/test2.example.org.json")
+	if err != nil {
+		t.Log(err)
+		t.Fail()
+	}
+
+	err = ioutil.WriteFile(dir+"/invalid.example.org.json", []byte("not-json"), 0644)
+	if err != nil {
+		t.Log(err)
+		t.Fail()
+	}
+
+	muxm.reload()
+	if muxm.zonelist["test.example.org"].Origin != "test.example.org" {
+		t.Log("test.example.org has unexpected Origin: '%s'", muxm.zonelist["test.example.org"].Origin)
+		t.Fail()
+	}
+	if muxm.zonelist["test2.example.org"].Origin != "test2.example.org" {
+		t.Log("test2.example.org has unexpected Origin: '%s'", muxm.zonelist["test2.example.org"].Origin)
+		t.Fail()
+	}
+
+	os.Remove(dir + "/test2.example.org.json")
+	os.Remove(dir + "/invalid.example.org.json")
+
+	muxm.reload()
+
+	if muxm.zonelist["test.example.org"].Origin != "test.example.org" {
+		t.Log("test.example.org has unexpected Origin: '%s'", muxm.zonelist["test.example.org"].Origin)
+		t.Fail()
+	}
+	_, ok := muxm.zonelist["test2.example.org"]
+	if ok != false {
+		t.Log("test2.example.org is still loaded")
+		t.Fail()
+	}
+}
+
+func CopyFile(src, dst string) (int64, error) {
+	sf, err := os.Open(src)
+	if err != nil {
+		return 0, fmt.Errorf("Could not copy '%s' to '%s': %s", src, dst, err)
+	}
+	defer sf.Close()
+	df, err := os.Create(dst)
+	if err != nil {
+		return 0, fmt.Errorf("Could not copy '%s' to '%s': %s", src, dst, err)
+	}
+	defer df.Close()
+	return io.Copy(df, sf)
+}

+ 119 - 99
zones/zone_test.go

@@ -1,113 +1,133 @@
 package zones
 
 import (
+	"testing"
+
 	"github.com/miekg/dns"
-	. "gopkg.in/check.v1"
 )
 
-func (s *ConfigSuite) TestExampleComZone(c *C) {
-	ex, ok := s.zones["test.example.com"]
-
-	c.Check(ok, Equals, true)
-	c.Check(ex, NotNil)
-
-	// test.example.com was loaded
-	c.Assert(ex.Labels, NotNil)
+func TestExampleComZone(t *testing.T) {
+	mm, err := NewMuxManager("../dns", &TestReg{})
+	if err != nil {
+		t.Fatalf("Loading test zones: %s", err)
+	}
 
-	c.Check(ex.Logging.StatHat, Equals, true)
-	c.Check(ex.Logging.StatHatAPI, Equals, "abc-test")
+	ex, ok := mm.zonelist["test.example.com"]
+	if !ok || ex == nil || ex.Labels == nil {
+		t.Fatalf("Did not load 'test.example.com' test zone")
+	}
 
-	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)
+	if mh := ex.Labels["weight"].MaxHosts; mh != 1 {
+		t.Logf("Invalid MaxHosts, expected one got '%d'", mh)
+		t.Fail()
+	}
 
 	// Make sure that the empty "no.bar" zone gets skipped and "bar" is used
-	label, qtype := ex.findLabels("bar", []string{"no", "europe", "@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 1)
-	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.2")
-	c.Check(qtype, Equals, dns.TypeA)
-
-	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeMX})
-	Mxs := label.Records[dns.TypeMX]
-	c.Check(Mxs, HasLen, 2)
-	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx.example.net.")
-	c.Check(Mxs[1].RR.(*dns.MX).Mx, Equals, "mx2.example.net.")
-
-	label, qtype = ex.findLabels("", []string{"dk", "europe", "@"}, qTypes{dns.TypeMX})
-	Mxs = label.Records[dns.TypeMX]
-	c.Check(Mxs, HasLen, 1)
-	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx-eu.example.net.")
-	c.Check(qtype, Equals, dns.TypeMX)
-
-	// look for multiple record types
-	label, qtype = ex.findLabels("www", []string{"@"}, qTypes{dns.TypeCNAME, dns.TypeA})
-	c.Check(label.Records[dns.TypeCNAME], HasLen, 1)
-	c.Check(qtype, Equals, dns.TypeCNAME)
-
-	// pretty.Println(ex.Labels[""].Records[dns.TypeNS])
-
-	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeNS})
-	Ns := label.Records[dns.TypeNS]
-	c.Check(Ns, HasLen, 2)
-	// Test that we get the expected NS records (in any order because
-	// of the configuration format used for this zone)
-	c.Check(Ns[0].RR.(*dns.NS).Ns, Matches, "^ns[12]\\.example\\.net.$")
-	c.Check(Ns[1].RR.(*dns.NS).Ns, Matches, "^ns[12]\\.example\\.net.$")
-
-	label, qtype = ex.findLabels("", []string{"@"}, qTypes{dns.TypeSPF})
-	Spf := label.Records[dns.TypeSPF]
-	c.Check(Spf, HasLen, 1)
-	c.Check(Spf[0].RR.(*dns.SPF).Txt[0], Equals, "v=spf1 ~all")
-
-	label, qtype = ex.findLabels("foo", []string{"@"}, qTypes{dns.TypeTXT})
-	Txt := label.Records[dns.TypeTXT]
-	c.Check(Txt, HasLen, 1)
-	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "this is foo")
-
-	label, qtype = ex.findLabels("weight", []string{"@"}, qTypes{dns.TypeTXT})
-	Txt = label.Records[dns.TypeTXT]
-	c.Check(Txt, HasLen, 2)
-	c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "w1000")
-	c.Check(Txt[1].RR.(*dns.TXT).Txt[0], Equals, "w1")
-
-	//verify empty labels are created
-	label, qtype = ex.findLabels("a.b.c", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 1)
-	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.7")
-
-	label, qtype = ex.findLabels("b.c", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 0)
-	c.Check(label.Label, Equals, "b.c")
-
-	label, qtype = ex.findLabels("c", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 0)
-	c.Check(label.Label, Equals, "c")
-
-	//verify label is created
-	label, qtype = ex.findLabels("three.two.one", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 1)
-	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.5")
-
-	label, qtype = ex.findLabels("two.one", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 0)
-	c.Check(label.Label, Equals, "two.one")
-
-	//verify label isn't overwritten
-	label, qtype = ex.findLabels("one", []string{"@"}, qTypes{dns.TypeA})
-	c.Check(label.Records[dns.TypeA], HasLen, 1)
-	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.6")
+	label, qtype := ex.FindLabels("bar", []string{"no", "europe", "@"}, []uint16{dns.TypeA})
+	if l := len(label.Records[dns.TypeA]); l != 1 {
+		t.Logf("Unexpected number of A records: '%d'", l)
+		t.Fail()
+	}
+	if qtype != dns.TypeA {
+		t.Fatalf("Expected qtype = A record (type %d), got type %d", dns.TypeA, qtype)
+	}
+	if str := label.Records[qtype][0].RR.(*dns.A).A.String(); str != "192.168.1.2" {
+		t.Logf("Got A '%s', expected '%s'", str, "192.168.1.2")
+		t.Fail()
+	}
+
+	// label, qtype = ex.FindLabels("", []string{"@"}, []uint16{dns.TypeMX})
+	// Mxs := label.Records[dns.TypeMX]
+	// c.Check(Mxs, HasLen, 2)
+	// c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx.example.net.")
+	// c.Check(Mxs[1].RR.(*dns.MX).Mx, Equals, "mx2.example.net.")
+
+	// label, qtype = ex.FindLabels("", []string{"dk", "europe", "@"}, []uint16{dns.TypeMX})
+	// Mxs = label.Records[dns.TypeMX]
+	// c.Check(Mxs, HasLen, 1)
+	// c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx-eu.example.net.")
+	// c.Check(qtype, Equals, dns.TypeMX)
+
+	// // look for multiple record types
+	// label, qtype = ex.FindLabels("www", []string{"@"}, []uint16{dns.TypeCNAME, dns.TypeA})
+	// c.Check(label.Records[dns.TypeCNAME], HasLen, 1)
+	// c.Check(qtype, Equals, dns.TypeCNAME)
+
+	// // pretty.Println(ex.Labels[""].Records[dns.TypeNS])
+
+	// label, qtype = ex.FindLabels("", []string{"@"}, []uint16{dns.TypeNS})
+	// Ns := label.Records[dns.TypeNS]
+	// c.Check(Ns, HasLen, 2)
+	// // Test that we get the expected NS records (in any order because
+	// // of the configuration format used for this zone)
+	// c.Check(Ns[0].RR.(*dns.NS).Ns, Matches, "^ns[12]\\.example\\.net.$")
+	// c.Check(Ns[1].RR.(*dns.NS).Ns, Matches, "^ns[12]\\.example\\.net.$")
+
+	// label, qtype = ex.FindLabels("", []string{"@"}, []uint16{dns.TypeSPF})
+	// Spf := label.Records[dns.TypeSPF]
+	// c.Check(Spf, HasLen, 1)
+	// c.Check(Spf[0].RR.(*dns.SPF).Txt[0], Equals, "v=spf1 ~all")
+
+	// label, qtype = ex.FindLabels("foo", []string{"@"}, []uint16{dns.TypeTXT})
+	// Txt := label.Records[dns.TypeTXT]
+	// c.Check(Txt, HasLen, 1)
+	// c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "this is foo")
+
+	// label, qtype = ex.FindLabels("weight", []string{"@"}, []uint16{dns.TypeTXT})
+	// Txt = label.Records[dns.TypeTXT]
+	// c.Check(Txt, HasLen, 2)
+	// c.Check(Txt[0].RR.(*dns.TXT).Txt[0], Equals, "w1000")
+	// c.Check(Txt[1].RR.(*dns.TXT).Txt[0], Equals, "w1")
+
+	// //verify empty labels are created
+	// label, qtype = ex.FindLabels("a.b.c", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 1)
+	// c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.7")
+
+	// label, qtype = ex.FindLabels("b.c", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 0)
+	// c.Check(label.Label, Equals, "b.c")
+
+	// label, qtype = ex.FindLabels("c", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 0)
+	// c.Check(label.Label, Equals, "c")
+
+	// //verify label is created
+	// label, qtype = ex.FindLabels("three.two.one", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 1)
+	// c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.5")
+
+	// label, qtype = ex.FindLabels("two.one", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 0)
+	// c.Check(label.Label, Equals, "two.one")
+
+	// //verify label isn't overwritten
+	// label, qtype = ex.FindLabels("one", []string{"@"}, []uint16{dns.TypeA})
+	// c.Check(label.Records[dns.TypeA], HasLen, 1)
+	// c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.6")
 }
 
-func (s *ConfigSuite) TestExampleOrgZone(c *C) {
-	ex := s.zones["test.example.org"]
-
-	// test.example.org was loaded
-	c.Assert(ex.Labels, NotNil)
-
-	label, qtype := ex.findLabels("sub", []string{"@"}, qTypes{dns.TypeNS})
-	c.Assert(qtype, Equals, dns.TypeNS)
-
-	Ns := label.Records[dns.TypeNS]
-	c.Check(Ns, HasLen, 2)
-	c.Check(Ns[0].RR.(*dns.NS).Ns, Equals, "ns1.example.com.")
-	c.Check(Ns[1].RR.(*dns.NS).Ns, Equals, "ns2.example.com.")
+func TestExampleOrgZone(t *testing.T) {
+	mm, err := NewMuxManager("../dns", &TestReg{})
+	if err != nil {
+		t.Fatalf("Loading test zones: %s", err)
+	}
+
+	ex, ok := mm.zonelist["test.example.org"]
+	if !ok || ex == nil || ex.Labels == nil {
+		t.Fatalf("Did not load 'test.example.org' test zone")
+	}
+
+	label, qtype := ex.FindLabels("sub", []string{"@"}, []uint16{dns.TypeNS})
+	if qtype != dns.TypeNS {
+		t.Fatalf("Expected qtype = NS record (type %d), got type %d", dns.TypeNS, qtype)
+	}
+
+	Ns := label.Records[qtype]
+	if l := len(Ns); l != 2 {
+		t.Fatalf("Expected 2 NS records, got '%d'", l)
+	}
+	// c.Check(Ns[0].RR.(*dns.NS).Ns, Equals, "ns1.example.com.")
+	// c.Check(Ns[1].RR.(*dns.NS).Ns, Equals, "ns2.example.com.")
 
 }