소스 검색

Graceful shutdown (to close log files)

Also various lint cleanups
Ask Bjørn Hansen 2 년 전
부모
커밋
7dfb8a0056
17개의 변경된 파일214개의 추가작업 그리고 108개의 파일을 삭제
  1. 16 35
      config.go
  2. 51 20
      geodns.go
  3. 1 0
      go.mod
  4. 6 0
      go.sum
  5. 33 3
      http.go
  6. 2 2
      http_test.go
  7. 0 2
      querylog/avro.go
  8. 11 4
      server/serve_test.go
  9. 57 7
      server/server.go
  10. 1 1
      targeting/targeting.go
  11. 2 2
      targeting/targeting_test.go
  12. 0 8
      templates_devel.go
  13. 19 9
      zones/muxmanager.go
  14. 9 9
      zones/reader.go
  15. 3 4
      zones/reader_test.go
  16. 1 1
      zones/zone_stats.go
  17. 2 1
      zones/zone_test.go

+ 16 - 35
config.go

@@ -1,7 +1,7 @@
 package main
 
 import (
-	"fmt"
+	"context"
 	"log"
 	"os"
 	"sync"
@@ -14,12 +14,6 @@ import (
 )
 
 type AppConfig struct {
-	StatHat struct {
-		ApiKey string
-	}
-	Flags struct {
-		HasStatHat bool
-	}
 	GeoIP struct {
 		Directory string
 	}
@@ -56,18 +50,6 @@ type AppConfig struct {
 var Config = new(AppConfig)
 var cfgMutex sync.RWMutex
 
-func (conf *AppConfig) HasStatHat() bool {
-	cfgMutex.RLock()
-	defer cfgMutex.RUnlock()
-	return conf.Flags.HasStatHat
-}
-
-func (conf *AppConfig) StatHatApiKey() string {
-	cfgMutex.RLock()
-	defer cfgMutex.RUnlock()
-	return conf.StatHat.ApiKey
-}
-
 func (conf *AppConfig) GeoIPDirectory() string {
 	cfgMutex.RLock()
 	defer cfgMutex.RUnlock()
@@ -77,38 +59,42 @@ func (conf *AppConfig) GeoIPDirectory() string {
 	return geoip2.FindDB()
 }
 
-func configWatcher(fileName string) {
+func configWatcher(ctx context.Context, fileName string) error {
 
 	watcher, err := fsnotify.NewWatcher()
 	if err != nil {
-		fmt.Println(err)
-		return
+		return err
 	}
 
 	if err := watcher.Add(*flagconfig); err != nil {
-		fmt.Println(err)
-		return
+		return err
 	}
 
 	for {
 		select {
+		case <-ctx.Done():
+			return nil
 		case ev := <-watcher.Events:
 			if ev.Name == fileName {
 				// Write = when the file is updated directly
 				// Rename = when it's updated atomicly
 				// Chmod = for `touch`
-				if ev.Op&fsnotify.Write == fsnotify.Write ||
-					ev.Op&fsnotify.Rename == fsnotify.Rename ||
-					ev.Op&fsnotify.Chmod == fsnotify.Chmod {
+				if ev.Has(fsnotify.Write) ||
+					ev.Has(fsnotify.Rename) ||
+					ev.Has(fsnotify.Chmod) {
 					time.Sleep(200 * time.Millisecond)
-					configReader(fileName)
+					err := configReader(fileName)
+					if err != nil {
+						// don't quit because we'll just keep the old config at this
+						// stage and try again next it changes
+						log.Printf("error reading config file: %s", err)
+					}
 				}
 			}
 		case err := <-watcher.Errors:
-			log.Println("fsnotify error:", err)
+			log.Printf("fsnotify error: %s", err)
 		}
 	}
-
 }
 
 var lastReadConfig time.Time
@@ -137,11 +123,6 @@ func configReader(fileName string) error {
 		return err
 	}
 
-	cfg.Flags.HasStatHat = len(cfg.StatHat.ApiKey) > 0
-
-	// log.Println("STATHAT APIKEY:", cfg.StatHat.ApiKey)
-	// log.Println("STATHAT FLAG  :", cfg.Flags.HasStatHat)
-
 	cfgMutex.Lock()
 	*Config = *cfg // shallow copy to prevent race conditions in referring to Config.foo()
 	cfgMutex.Unlock()

+ 51 - 20
geodns.go

@@ -17,6 +17,7 @@ package main
 */
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"log"
@@ -38,6 +39,7 @@ import (
 	"github.com/abh/geodns/v3/targeting/geoip2"
 	"github.com/abh/geodns/v3/zones"
 	"github.com/pborman/uuid"
+	"golang.org/x/sync/errgroup"
 )
 
 // VERSION is the current version of GeoDNS, set by the build process
@@ -45,10 +47,6 @@ var VERSION string = "devel"
 var buildTime string
 var gitVersion string
 
-// Set development with the 'devel' build flag to load
-// templates from disk instead of from the binary.
-var development bool
-
 var (
 	serverInfo *monitor.ServerInfo
 )
@@ -154,6 +152,15 @@ func main() {
 
 	log.Printf("Starting geodns %s (%s)\n", VERSION, runtime.Version())
 
+	ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
+	g, ctx := errgroup.WithContext(ctx)
+
+	g.Go(func() error {
+		<-ctx.Done()
+		log.Printf("server shutting down")
+		return nil
+	})
+
 	if *cpuprofile != "" {
 		prof, err := os.Create(*cpuprofile)
 		if err != nil {
@@ -172,14 +179,25 @@ func main() {
 	}
 
 	// load geodns.conf config
-	configReader(configFileName)
+	err := configReader(configFileName)
+	if err != nil {
+		log.Printf("error reading config file %s: %s", configFileName, err)
+		os.Exit(2)
+	}
 
 	if len(Config.Health.Directory) > 0 {
 		go health.DirectoryReader(Config.Health.Directory)
 	}
 
 	// load (and re-load) zone data
-	go configWatcher(configFileName)
+	g.Go(func() error {
+		err := configWatcher(ctx, configFileName)
+		if err != nil {
+			log.Printf("config watcher error: %s", err)
+			return err
+		}
+		return nil
+	})
 
 	if *flaginter == "*" {
 		addrs, _ := net.InterfaceAddrs()
@@ -199,10 +217,6 @@ func main() {
 
 	inter := getInterfaces()
 
-	if Config.HasStatHat() {
-		log.Println("StatHat integration has been removed in favor of more generic metrics")
-	}
-
 	if len(Config.GeoIPDirectory()) > 0 {
 		geoProvider, err := geoip2.New(Config.GeoIPDirectory())
 		if err != nil {
@@ -247,24 +261,41 @@ func main() {
 	if err != nil {
 		log.Printf("error loading zones: %s", err)
 	}
-	go muxm.Run()
+
+	g.Go(func() error {
+		muxm.Run(ctx)
+		return nil
+	})
 
 	for _, host := range inter {
-		go srv.ListenAndServe(host)
+		host := host
+		g.Go(func() error {
+			return srv.ListenAndServe(ctx, host)
+		})
 	}
 
+	g.Go(func() error {
+		<-ctx.Done()
+		log.Printf("shutting down DNS servers")
+		err = srv.Shutdown()
+		if err != nil {
+			return err
+		}
+		return nil
+	})
+
 	if len(*flaghttp) > 0 {
-		go func() {
+		g.Go(func() error {
 			hs := NewHTTPServer(muxm, serverInfo)
-			hs.Run(*flaghttp)
-		}()
+			err := hs.Run(ctx, *flaghttp)
+			return err
+		})
 	}
 
-	terminate := make(chan os.Signal)
-	signal.Notify(terminate, os.Interrupt)
-
-	<-terminate
-	log.Printf("geodns: signal received, stopping")
+	err = g.Wait()
+	if err != nil {
+		log.Printf("server error: %s", err)
+	}
 
 	if *memprofile != "" {
 		f, err := os.Create(*memprofile)

+ 1 - 0
go.mod

@@ -14,6 +14,7 @@ require (
 	github.com/prometheus/client_golang v1.16.0
 	github.com/stretchr/testify v1.8.4
 	golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
+	golang.org/x/sync v0.3.0
 	gopkg.in/gcfg.v1 v1.2.3
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 )

+ 6 - 0
go.sum

@@ -11,6 +11,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
 github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
 github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
 github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
+github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
 github.com/golang/geo v0.0.0-20230421003525-6adc56603217 h1:HKlyj6in2JV6wVkmQ4XmG/EIm+SCYlPZ+V4GWit7Z+I=
 github.com/golang/geo v0.0.0-20230421003525-6adc56603217/go.mod h1:8wI0hitZ3a1IxZfeH3/5I97CI8i5cLGsYe7xNhQGs9U=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -27,8 +28,10 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/hamba/avro/v2 v2.12.0 h1:QZvbrfOfHQ7kZnlxRdwRU0opSf9ZrqlzpKzJuIUjIjU=
 github.com/hamba/avro/v2 v2.12.0/go.mod h1:Q9YK+qxAhtVrNqOhwlZTATLgLA8qxG2vtvkhK8fJ7Jo=
+github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
@@ -43,6 +46,7 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
 github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
 github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc=
@@ -74,6 +78,7 @@ golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
 golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
@@ -95,5 +100,6 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
 gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
 gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 33 - 3
http.go

@@ -1,15 +1,19 @@
 package main
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log"
 	"net/http"
 	"strconv"
+	"time"
 
 	"github.com/abh/geodns/v3/monitor"
 	"github.com/abh/geodns/v3/zones"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
+	"golang.org/x/sync/errgroup"
 )
 
 type httpServer struct {
@@ -73,9 +77,36 @@ func (hs *httpServer) Mux() *http.ServeMux {
 	return hs.mux
 }
 
-func (hs *httpServer) Run(listen string) {
+func (hs *httpServer) Run(ctx context.Context, listen string) error {
 	log.Println("Starting HTTP interface on", listen)
-	log.Fatal(http.ListenAndServe(listen, &basicauth{h: hs.mux}))
+
+	srv := http.Server{
+		Addr:         listen,
+		Handler:      &basicauth{h: hs.mux},
+		ReadTimeout:  5 * time.Second,
+		IdleTimeout:  10 * time.Second,
+		WriteTimeout: 10 * time.Second,
+	}
+
+	g, ctx := errgroup.WithContext(ctx)
+
+	g.Go(func() error {
+		err := srv.ListenAndServe()
+		if err != nil {
+			if !errors.Is(err, http.ErrServerClosed) {
+				return err
+			}
+		}
+		return nil
+	})
+
+	g.Go(func() error {
+		<-ctx.Done()
+		log.Printf("shutting down http server")
+		return srv.Shutdown(ctx)
+	})
+
+	return g.Wait()
 }
 
 func (hs *httpServer) mainServer(w http.ResponseWriter, req *http.Request) {
@@ -114,5 +145,4 @@ func (b *basicauth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, "GeoDNS Status"))
 	http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
-	return
 }

+ 2 - 2
http_test.go

@@ -2,7 +2,7 @@ package main
 
 import (
 	"bytes"
-	"io/ioutil"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -36,7 +36,7 @@ func TestHTTP(t *testing.T) {
 
 	res, err := http.Get(baseurl + "/version")
 	require.Nil(t, err)
-	page, _ := ioutil.ReadAll(res.Body)
+	page, _ := io.ReadAll(res.Body)
 
 	if !bytes.HasPrefix(page, []byte("GeoDNS ")) {
 		t.Log("/version didn't start with 'GeoDNS '")

+ 0 - 2
querylog/avro.go

@@ -263,10 +263,8 @@ func (l *AvroLogger) writer(ctx context.Context) {
 }
 
 func (l *AvroLogger) Close() error {
-	log.Printf("calling Close()")
 	l.cancel(fmt.Errorf("closing"))
 	<-l.ctx.Done()
 	l.wg.Wait() // wait for all files to be closed
-	log.Printf("Close() returning")
 	return nil
 }

+ 11 - 4
server/serve_test.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"net"
 	"reflect"
 	"strings"
@@ -23,22 +24,28 @@ func TestServe(t *testing.T) {
 	serverInfo := &monitor.ServerInfo{}
 
 	srv := NewServer(serverInfo)
+	ctx, cancel := context.WithCancel(context.Background())
 
 	mm, err := zones.NewMuxManager("../dns", srv)
 	if err != nil {
 		t.Fatalf("Loading test zones: %s", err)
 	}
-	go mm.Run()
+	go mm.Run(ctx)
 
-	// listenAndServe returns after listening on udp + tcp, so just
-	// wait for it before continuing
-	srv.ListenAndServe(PORT)
+	go func() {
+		srv.ListenAndServe(ctx, PORT)
+	}()
 
 	// ensure service has properly started before we query it
 	time.Sleep(500 * time.Millisecond)
 
 	t.Run("Serving", testServing)
 
+	// todo: run test queries?
+
+	cancel()
+
+	srv.Shutdown()
 }
 
 func testServing(t *testing.T) {

+ 57 - 7
server/server.go

@@ -1,11 +1,16 @@
 package server
 
 import (
+	"context"
+	"errors"
 	"log"
+	"sync"
+	"time"
 
 	"github.com/abh/geodns/v3/monitor"
 	"github.com/abh/geodns/v3/querylog"
 	"github.com/abh/geodns/v3/zones"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/miekg/dns"
 	"github.com/prometheus/client_golang/prometheus"
@@ -22,6 +27,9 @@ type Server struct {
 	PublicDebugQueries bool
 	info               *monitor.ServerInfo
 	metrics            *serverMetrics
+
+	lock       sync.Mutex
+	dnsServers []*dns.Server
 }
 
 // NewServer ...
@@ -97,26 +105,68 @@ func (srv *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 	srv.mux.ServeDNS(w, r)
 }
 
+func (srv *Server) addDNSServer(dnsServer *dns.Server) {
+	srv.lock.Lock()
+	defer srv.lock.Unlock()
+	srv.dnsServers = append(srv.dnsServers, dnsServer)
+}
+
 // ListenAndServe starts the DNS server on the specified IP
-// (both tcp and udp) and returns. If something goes wrong
-// it will crash the process with an error message.
-func (srv *Server) ListenAndServe(ip string) {
+// (both tcp and udp). It returns an error if
+// something goes wrong.
+func (srv *Server) ListenAndServe(ctx context.Context, ip string) error {
 
 	prots := []string{"udp", "tcp"}
 
+	g, _ := errgroup.WithContext(ctx)
+
 	for _, prot := range prots {
-		go func(p string) {
+
+		p := prot
+
+		g.Go(func() error {
 			server := &dns.Server{
 				Addr:    ip,
 				Net:     p,
 				Handler: srv,
 			}
 
+			srv.addDNSServer(server)
+
 			log.Printf("Opening on %s %s", ip, p)
 			if err := server.ListenAndServe(); err != nil {
-				log.Fatalf("geodns: failed to setup %s %s: %s", ip, p, err)
+				log.Printf("geodns: failed to setup %s %s: %s", ip, p, err)
+				return err
 			}
-			log.Fatalf("geodns: ListenAndServe unexpectedly returned")
-		}(prot)
+			return nil
+		})
 	}
+
+	// the servers will be shutdown when Shutdown() is called
+	return g.Wait()
+}
+
+// Shutdown gracefully shuts down the server
+func (srv *Server) Shutdown() error {
+	var errs []error
+
+	for _, dnsServer := range srv.dnsServers {
+		timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+		defer cancel()
+		err := dnsServer.ShutdownContext(timeoutCtx)
+		if err != nil {
+			errs = append(errs, err)
+		}
+	}
+
+	if srv.queryLogger != nil {
+		err := srv.queryLogger.Close()
+		if err != nil {
+			errs = append(errs, err)
+		}
+	}
+
+	err := errors.Join(errs...)
+
+	return err
 }

+ 1 - 1
targeting/targeting.go

@@ -173,7 +173,7 @@ func ParseTargets(v string) (tgt TargetOptions, err error) {
 		case "ip":
 			x = TargetIP
 		default:
-			err = fmt.Errorf("Unknown targeting option '%s'", t)
+			err = fmt.Errorf("unknown targeting option '%s'", t)
 		}
 		tgt = tgt | x
 	}

+ 2 - 2
targeting/targeting_test.go

@@ -25,7 +25,7 @@ func TestTargetParse(t *testing.T) {
 		t.Logf("Expected '@ country', got '%s'", str)
 		t.Fail()
 	}
-	if err.Error() != "Unknown targeting option 'foo'" {
+	if err.Error() != "unknown targeting option 'foo'" {
 		t.Log("Failed erroring on an unknown targeting option")
 		t.Fail()
 	}
@@ -54,7 +54,7 @@ func TestGetTargets(t *testing.T) {
 
 	g, err := geoip2.New(geoip2.FindDB())
 	if err != nil {
-		t.Fatalf("opening geoip2: %s", err)
+		t.Skipf("opening geoip2: %s", err)
 	}
 	Setup(g)
 

+ 0 - 8
templates_devel.go

@@ -1,8 +0,0 @@
-// +build devel
-
-package main
-
-func init() {
-	// load templates from disk
-	development = true
-}

+ 19 - 9
zones/muxmanager.go

@@ -1,11 +1,12 @@
 package zones
 
 import (
+	"context"
 	"crypto/sha256"
 	"encoding/hex"
 	"fmt"
-	"io/ioutil"
 	"log"
+	"os"
 	"path"
 	"strings"
 	"time"
@@ -52,13 +53,17 @@ func NewMuxManager(path string, reg RegistrationAPI) (*MuxManager, error) {
 	return mm, err
 }
 
-func (mm *MuxManager) Run() {
+func (mm *MuxManager) Run(ctx context.Context) {
 	for {
 		err := mm.reload()
 		if err != nil {
 			log.Printf("error reading zones: %s", err)
 		}
-		time.Sleep(2 * time.Second)
+		select {
+		case <-time.After(2 * time.Second):
+		case <-ctx.Done():
+			return
+		}
 	}
 }
 
@@ -68,7 +73,7 @@ func (mm *MuxManager) Zones() ZoneList {
 }
 
 func (mm *MuxManager) reload() error {
-	dir, err := ioutil.ReadDir(mm.path)
+	dir, err := os.ReadDir(mm.path)
 	if err != nil {
 		return fmt.Errorf("could not read '%s': %s", mm.path, err)
 	}
@@ -85,12 +90,17 @@ func (mm *MuxManager) reload() error {
 			continue
 		}
 
+		fileInfo, err := file.Info()
+		if err != nil {
+			return err
+		}
+		modTime := fileInfo.ModTime()
+
 		zoneName := fileName[0:strings.LastIndex(fileName, ".")]
 
 		seenZones[zoneName] = true
 
-		if _, ok := mm.lastRead[zoneName]; !ok || file.ModTime().After(mm.lastRead[zoneName].time) {
-			modTime := file.ModTime()
+		if _, ok := mm.lastRead[zoneName]; !ok || modTime.After(mm.lastRead[zoneName].time) {
 			if ok {
 				log.Printf("Reloading %s\n", fileName)
 				mm.lastRead[zoneName].time = modTime
@@ -131,7 +141,7 @@ func (mm *MuxManager) reload() error {
 			zone := NewZone(zoneName)
 			err := zone.ReadZoneFile(filename)
 			if zone == nil || err != nil {
-				parseErr = fmt.Errorf("Error reading zone '%s': %s", zoneName, err)
+				parseErr = fmt.Errorf("error reading zone '%s': %s", zoneName, err)
 				log.Println(parseErr.Error())
 				continue
 			}
@@ -146,7 +156,7 @@ func (mm *MuxManager) reload() error {
 		if zoneName == "pgeodns" {
 			continue
 		}
-		if ok, _ := seenZones[zoneName]; ok {
+		if ok := seenZones[zoneName]; ok {
 			continue
 		}
 		log.Println("Removing zone", zone.Origin)
@@ -191,7 +201,7 @@ func (mm *MuxManager) setupRootZone() {
 }
 
 func sha256File(fn string) string {
-	data, err := ioutil.ReadFile(fn)
+	data, err := os.ReadFile(fn)
 	if err != nil {
 		return ""
 	}

+ 9 - 9
zones/reader.go

@@ -3,6 +3,7 @@ package zones
 import (
 	"encoding/json"
 	"fmt"
+	"io"
 	"log"
 	"net"
 	"os"
@@ -48,7 +49,7 @@ func (zone *Zone) ReadZoneFile(fileName string) (zerr error) {
 	if err = decoder.Decode(&objmap); err != nil {
 		extra := ""
 		if serr, ok := err.(*json.SyntaxError); ok {
-			if _, serr := fh.Seek(0, os.SEEK_SET); serr != nil {
+			if _, serr := fh.Seek(0, io.SeekStart); serr != nil {
 				log.Fatalf("seek error: %v", serr)
 			}
 			line, col, highlight := errorutil.HighlightBytePosition(fh, serr.Offset)
@@ -204,11 +205,11 @@ func setupZoneData(data map[string]interface{}, zone *Zone) {
 
 			records := make(map[string][]interface{})
 
-			switch rdata.(type) {
+			switch rd := rdata.(type) {
 			case map[string]interface{}:
 				// Handle NS map syntax, map[ns2.example.net:<nil> ns1.example.net:<nil>]
 				tmp := make([]interface{}, 0)
-				for rdataK, rdataV := range rdata.(map[string]interface{}) {
+				for rdataK, rdataV := range rd {
 					if rdataV == nil {
 						rdataV = ""
 					}
@@ -218,7 +219,7 @@ func setupZoneData(data map[string]interface{}, zone *Zone) {
 			case string:
 				// CNAME and alias
 				tmp := make([]interface{}, 1)
-				tmp[0] = rdata.(string)
+				tmp[0] = rd
 				records[rType] = tmp
 			default:
 				records[rType] = rdata.([]interface{})
@@ -299,19 +300,18 @@ func setupZoneData(data map[string]interface{}, zone *Zone) {
 					switch dnsType {
 					case dns.TypePTR:
 						record.RR = &dns.PTR{Hdr: h, Ptr: ip}
-						break
 					case dns.TypeA:
 						if x := net.ParseIP(ip); x != nil {
 							record.RR = &dns.A{Hdr: h, A: x}
-							break
+						} else {
+							panic(fmt.Errorf("bad A record %q for %q", ip, dk))
 						}
-						panic(fmt.Errorf("Bad A record %q for %q", ip, dk))
 					case dns.TypeAAAA:
 						if x := net.ParseIP(ip); x != nil {
 							record.RR = &dns.AAAA{Hdr: h, AAAA: x}
-							break
+						} else {
+							panic(fmt.Errorf("bad AAAA record %q for %q", ip, dk))
 						}
-						panic(fmt.Errorf("Bad AAAA record %q for %q", ip, dk))
 					}
 
 				case dns.TypeMX:

+ 3 - 4
zones/reader_test.go

@@ -3,7 +3,6 @@ package zones
 import (
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"testing"
 
@@ -18,7 +17,7 @@ func loadZones(t *testing.T) *MuxManager {
 		t.Logf("Setting up geo provider")
 		dbDir := geoip2.FindDB()
 		if len(dbDir) == 0 {
-			t.Fatalf("Could not find geoip directory")
+			t.Skip("Could not find geoip directory")
 		}
 		geoprovider, err := geoip2.New(dbDir)
 		if err == nil {
@@ -95,7 +94,7 @@ func TestReadConfigs(t *testing.T) {
 }
 
 func TestRemoveConfig(t *testing.T) {
-	dir, err := ioutil.TempDir("", "geodns-test.")
+	dir, err := os.MkdirTemp("", "geodns-test.")
 	if err != nil {
 		t.Fail()
 	}
@@ -120,7 +119,7 @@ func TestRemoveConfig(t *testing.T) {
 		t.Fail()
 	}
 
-	err = ioutil.WriteFile(dir+"/invalid.example.org.json", []byte("not-json"), 0644)
+	err = os.WriteFile(dir+"/invalid.example.org.json", []byte("not-json"), 0644)
 	if err != nil {
 		t.Log(err)
 		t.Fail()

+ 1 - 1
zones/zone_stats.go

@@ -86,7 +86,7 @@ func (zs *zoneLabelStats) Counts() map[string]int {
 
 	counts := make(map[string]int)
 	for i, l := range zs.log {
-		if zs.rotated == false && i >= zs.pos {
+		if !zs.rotated && i >= zs.pos {
 			break
 		}
 		counts[l]++

+ 2 - 1
zones/zone_test.go

@@ -70,8 +70,9 @@ func TestExampleComZone(t *testing.T) {
 
 	// Test that we get the expected NS records (in any order because
 	// of the configuration format used for this zone)
+	re := regexp.MustCompile(`^ns[12]\.example\.net.$`)
 	for i := 0; i < 2; i++ {
-		if matched, err := regexp.MatchString("^ns[12]\\.example\\.net.$", Ns[i].RR.(*dns.NS).Ns); err != nil || !matched {
+		if matched := re.MatchString(Ns[i].RR.(*dns.NS).Ns); !matched {
 			if err != nil {
 				t.Fatal(err)
 			}