Browse Source

De-configure zones when the .json file is removed

Ask Bjørn Hansen 12 years ago
parent
commit
f218d4540e
6 changed files with 80 additions and 11 deletions
  1. 2 0
      CHANGES.md
  2. 18 8
      config.go
  3. 57 1
      config_test.go
  4. 1 1
      serve_test.go
  5. 2 0
      types.go
  6. 0 1
      zone_test.go

+ 2 - 0
CHANGES.md

@@ -1,5 +1,7 @@
 # GeoDNS Changelog
 # GeoDNS Changelog
 
 
+* De-configure zones when the .json file is removed
+
 ## 2.2.3, March 1st 2013
 ## 2.2.3, March 1st 2013
 
 
 * Always log when zones are re-read
 * Always log when zones are re-read

+ 18 - 8
config.go

@@ -18,8 +18,6 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-var zonesLastRead = map[string]time.Time{}
-
 func zonesReader(dirName string, zones Zones) {
 func zonesReader(dirName string, zones Zones) {
 	for {
 	for {
 		zonesReadDir(dirName, zones)
 		zonesReadDir(dirName, zones)
@@ -39,7 +37,7 @@ func zonesReadDir(dirName string, zones Zones) error {
 		return err
 		return err
 	}
 	}
 
 
-	seenFiles := map[string]bool{}
+	seenZones := map[string]bool{}
 
 
 	var parse_err error
 	var parse_err error
 
 
@@ -49,30 +47,42 @@ func zonesReadDir(dirName string, zones Zones) error {
 			continue
 			continue
 		}
 		}
 
 
-		seenFiles[fileName] = true
+		zoneName := zoneNameFromFile(fileName)
+
+		seenZones[zoneName] = true
 
 
-		if lastRead, ok := zonesLastRead[fileName]; !ok || file.ModTime().After(lastRead) {
+		if zone, ok := zones[zoneName]; !ok || file.ModTime().After(zone.LastRead) {
 			if ok {
 			if ok {
 				log.Printf("Reloading %s\n", fileName)
 				log.Printf("Reloading %s\n", fileName)
 			} else {
 			} else {
 				logPrintf("Reading new file %s\n", fileName)
 				logPrintf("Reading new file %s\n", fileName)
 			}
 			}
-			zonesLastRead[fileName] = file.ModTime()
 
 
-			zoneName := zoneNameFromFile(fileName)
 			//log.Println("FILE:", i, file, zoneName)
 			//log.Println("FILE:", i, file, zoneName)
 			config, err := readZoneFile(zoneName, path.Join(dirName, fileName))
 			config, err := readZoneFile(zoneName, path.Join(dirName, fileName))
 			if config == nil || err != nil {
 			if config == nil || err != nil {
+				config.LastRead = file.ModTime()
 				log.Println(err)
 				log.Println(err)
 				parse_err = err
 				parse_err = err
 				continue
 				continue
 			}
 			}
+			config.LastRead = file.ModTime()
 
 
 			addHandler(zones, zoneName, config)
 			addHandler(zones, zoneName, config)
 			runtime.GC()
 			runtime.GC()
 		}
 		}
+	}
 
 
-		// TODO(ask) Disable zones not seen in two subsequent runs
+	for zoneName, zone := range zones {
+		if zoneName == "pgeodns" {
+			continue
+		}
+		if ok, _ := seenZones[zoneName]; ok {
+			continue
+		}
+		log.Println("Removing zone", zone.Origin)
+		dns.HandleRemove(zoneName)
+		delete(zones, zoneName)
 	}
 	}
 
 
 	return parse_err
 	return parse_err

+ 57 - 1
config_test.go

@@ -2,7 +2,10 @@ package main
 
 
 import (
 import (
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
+	"io"
+	"io/ioutil"
 	. "launchpad.net/gocheck"
 	. "launchpad.net/gocheck"
+	"os"
 	"testing"
 	"testing"
 )
 )
 
 
@@ -15,10 +18,12 @@ type ConfigSuite struct {
 
 
 var _ = Suite(&ConfigSuite{})
 var _ = Suite(&ConfigSuite{})
 
 
-func (s *ConfigSuite) TestReadConfigs(c *C) {
+func (s *ConfigSuite) SetUpSuite(c *C) {
 	s.zones = make(Zones)
 	s.zones = make(Zones)
 	zonesReadDir("dns", s.zones)
 	zonesReadDir("dns", s.zones)
+}
 
 
+func (s *ConfigSuite) TestReadConfigs(c *C) {
 	// Just check that example.com and test.example.org loaded, too.
 	// Just check that example.com and test.example.org loaded, too.
 	c.Check(s.zones["example.com"].Origin, Equals, "example.com")
 	c.Check(s.zones["example.com"].Origin, Equals, "example.com")
 	c.Check(s.zones["test.example.org"].Origin, Equals, "test.example.org")
 	c.Check(s.zones["test.example.org"].Origin, Equals, "test.example.org")
@@ -46,3 +51,54 @@ func (s *ConfigSuite) TestReadConfigs(c *C) {
 		Mf, Equals, "www")
 		Mf, Equals, "www")
 
 
 }
 }
+
+func (s *ConfigSuite) TestRemoveConfig(c *C) {
+	// restore the dns.Mux
+	defer zonesReadDir("dns", s.zones)
+
+	dir, err := ioutil.TempDir("", "geodns-test.")
+	if err != nil {
+		c.Fail()
+	}
+	defer os.RemoveAll(dir)
+
+	_, err = CopyFile(c, "dns/test.example.org.json", dir+"/test.example.org.json")
+	if err != nil {
+		c.Log(err)
+		c.Fail()
+	}
+	_, err = CopyFile(c, "dns/test.example.org.json", dir+"/test2.example.org.json")
+	if err != nil {
+		c.Log(err)
+		c.Fail()
+	}
+
+	zonesReadDir(dir, s.zones)
+	c.Check(s.zones["test.example.org"].Origin, Equals, "test.example.org")
+	c.Check(s.zones["test2.example.org"].Origin, Equals, "test2.example.org")
+
+	os.Remove(dir + "/test2.example.org.json")
+
+	zonesReadDir(dir, s.zones)
+	c.Check(s.zones["test.example.org"].Origin, Equals, "test.example.org")
+	_, ok := s.zones["test2.example.org"]
+	c.Check(ok, Equals, false)
+}
+
+func CopyFile(c *C, src, dst string) (int64, error) {
+	sf, err := os.Open(src)
+	if err != nil {
+		c.Log("Could not copy", src, "to", dst, "because", err)
+		c.Fail()
+		return 0, err
+	}
+	defer sf.Close()
+	df, err := os.Create(dst)
+	if err != nil {
+		c.Log("Could not copy", src, "to", dst, "because", err)
+		c.Fail()
+		return 0, err
+	}
+	defer df.Close()
+	return io.Copy(df, sf)
+}

+ 1 - 1
serve_test.go

@@ -19,7 +19,7 @@ var _ = Suite(&ServeSuite{})
 
 
 func (s *ServeSuite) SetUpSuite(c *C) {
 func (s *ServeSuite) SetUpSuite(c *C) {
 
 
-	c.Log("Setting up test suite")
+	// log.Println("Setting up serve test suite")
 
 
 	Zones := make(Zones)
 	Zones := make(Zones)
 	setupPgeodnsZone(Zones)
 	setupPgeodnsZone(Zones)

+ 2 - 0
types.go

@@ -4,6 +4,7 @@ import (
 	"github.com/abh/geodns/countries"
 	"github.com/abh/geodns/countries"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"strings"
 	"strings"
+	"time"
 )
 )
 
 
 type Options struct {
 type Options struct {
@@ -44,6 +45,7 @@ type Zone struct {
 	Labels    labels
 	Labels    labels
 	LenLabels int
 	LenLabels int
 	Options   Options
 	Options   Options
+	LastRead  time.Time
 }
 }
 
 
 type qTypes []uint16
 type qTypes []uint16

+ 0 - 1
zone_test.go

@@ -6,7 +6,6 @@ import (
 )
 )
 
 
 func (s *ConfigSuite) TestZone(c *C) {
 func (s *ConfigSuite) TestZone(c *C) {
-
 	ex := s.zones["test.example.com"]
 	ex := s.zones["test.example.com"]
 	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)
 	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)