Browse Source

geoip2: consolidate open, open all DBs

Per feedback in PR #424, change the functions `open` and `New` and the GeoIP2
struct to allow for independent, per-database lock management, reopen a
database in place when `open` is called again and its file has a newer
modification time, and remove the `geoType` enum and the code that used it.
Tyler Davis 4 years ago
parent
commit
749e8ce73e
3 changed files with 93 additions and 155 deletions
  1. 0 1
      go.mod
  2. 0 1
      go.sum
  3. 93 153
      targeting/geoip2/geoip2.go

+ 0 - 1
go.mod

@@ -13,7 +13,6 @@ require (
 	github.com/nxadm/tail v1.4.8
 	github.com/oschwald/geoip2-golang v1.5.0
 	github.com/pborman/uuid v1.2.1
-	github.com/pkg/errors v0.9.1 // indirect
 	github.com/prometheus/client_golang v1.11.0
 	github.com/prometheus/common v0.30.0 // indirect
 	github.com/prometheus/procfs v0.7.2 // indirect

+ 0 - 1
go.sum

@@ -172,7 +172,6 @@ github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw=
 github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
 github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

+ 93 - 153
targeting/geoip2/geoip2.go

@@ -13,45 +13,24 @@ import (
 
 	"github.com/abh/geodns/countries"
 	"github.com/abh/geodns/targeting/geo"
-	geoip2 "github.com/oschwald/geoip2-golang"
-	"github.com/pkg/errors"
+	gdb "github.com/oschwald/geoip2-golang"
 )
 
-type geoType uint8
-
-const (
-	countryDB = iota
-	cityDB
-	asnDB
-)
-
-var dbFiles map[geoType][]string
-
 // GeoIP2 contains the geoip implementation of the GeoDNS geo
 // targeting interface
 type GeoIP2 struct {
-	dir string
-
+	dir     string
 	country geodb
 	city    geodb
 	asn     geodb
-
-	mu sync.RWMutex
 }
 
 type geodb struct {
-	db           *geoip2.Reader // Database reader
-	fp           string         // FilePath
-	lastModified int64          // Epoch time
-	// l            sync.Mutex     // Individual lock for separate DB access and reload -- Future?
-}
-
-func init() {
-	dbFiles = map[geoType][]string{
-		countryDB: {"GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb"},
-		asnDB:     {"GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb"},
-		cityDB:    {"GeoIP2-City.mmdb", "GeoLite2-City.mmdb"},
-	}
+	active       bool
+	lastModified int64        // Epoch time
+	fp           string       // FilePath
+	db           *gdb.Reader  // Database reader
+	l            sync.RWMutex // Individual lock so each DB can be read and reloaded independently
 }
 
 // FindDB returns a guess at a directory path for GeoIP data files
@@ -74,152 +53,119 @@ func FindDB() string {
 	return ""
 }
 
-func (g *GeoIP2) open(t geoType, db string) (*geoip2.Reader, error) {
-	fileName := filepath.Join(g.dir, db)
+// open will create a filehandle for the provided GeoIP2 database. If opened once before and a newer modification time is present, the function will reopen the file with its new contents
+func (g *GeoIP2) open(v *geodb, fns ...string) error {
 	var fi fs.FileInfo
-
-	if len(db) == 0 {
-		found := false
-		for _, f := range dbFiles[t] {
-			var err error
-			fileName = filepath.Join(g.dir, f)
-			if fi, err = os.Stat(fileName); err == nil {
-				found = true
-				break
+	var err error
+	if v.fp == "" {
+		// We're opening this file for the first time
+		for _, i := range fns {
+			fp := filepath.Join(g.dir, i)
+			fi, err = os.Stat(fp)
+			if err != nil {
+				continue
 			}
-		}
-		if !found {
-			return nil, fmt.Errorf("could not find '%s' in '%s'", dbFiles[t], g.dir)
+			v.fp = fp
 		}
 	}
-
-	n, err := geoip2.Open(fileName)
-	if err != nil {
-		return nil, err
+	if v.fp == "" { // Recheck for empty string in case none of the DB files are found
+		return fmt.Errorf("no files found for db")
 	}
-	g.mu.Lock()
-	defer g.mu.Unlock()
-
-	switch t {
-	case countryDB:
-		g.country.db = n
-		g.country.lastModified = fi.ModTime().UTC().Unix()
-		g.country.fp = fileName
-	case cityDB:
-		g.city.db = n
-		g.city.lastModified = fi.ModTime().UTC().Unix()
-		g.city.fp = fileName
-	case asnDB:
-		g.asn.db = n
-		g.asn.lastModified = fi.ModTime().UTC().Unix()
-		g.asn.fp = fileName
+	if fi == nil { // We have not set fileInfo and v.fp is set
+		fi, err = os.Stat(v.fp)
 	}
-	return n, nil
-}
-
-func (g *GeoIP2) get(t geoType, db string) (*geoip2.Reader, error) {
-	g.mu.RLock()
-
-	var r *geoip2.Reader
-
-	switch t {
-	case countryDB:
-		r = g.country.db
-	case cityDB:
-		r = g.city.db
-	case asnDB:
-		r = g.asn.db
+	if err != nil {
+		return err
 	}
+	if v.lastModified >= fi.ModTime().UTC().Unix() { // No update to existing file
+		return nil
+	}
+	// Delay the lock to here because we're only taking it when the file actually has to be (re)opened
+	v.l.Lock()
+	defer v.l.Unlock()
 
-	// unlock so the g.open() call below won't lock
-	g.mu.RUnlock()
-
-	if r != nil {
-		return r, nil
+	o, e := gdb.Open(v.fp)
+	if e != nil {
+		return e
 	}
+	v.db = o
+	v.active = true
+	v.lastModified = fi.ModTime().UTC().Unix()
 
-	return g.open(t, db)
+	return nil
 }
 
 // watchFiles spawns a goroutine to check for new files every minute, reloading if the modtime is newer than the original file's modtime
 func (g *GeoIP2) watchFiles() {
 	// Not worried about goroutines leaking because only one geoip2.New call is made in main (outside of testing)
 	ticker := time.NewTicker(1 * time.Minute)
-	go func() {
-		for {
-			select {
-			case <-ticker.C:
-				// Iterate through each file, check modtime. If new, reload file
-				d := []*geodb{&g.country, &g.city, &g.asn} // Slice of pointers is kinda gross, but want to directly reference struct values (per const type)
-				for _, v := range d {
-					fi, err := os.Stat(v.fp) // Stat is a non-blocking call we can do for (nearly) free
-					if err != nil {
-						log.Printf("unable to stat DB file at %s :: %v", v.fp, err)
-						continue
-					}
-					if fi.ModTime().UTC().Unix() > v.lastModified {
-						e := g.reloadFile(v)
-						if e != nil {
-							log.Printf("failure to update DB: %v", e)
-						}
-					}
-				}
+	for { // We forever-loop here because we only run this function in a separate goroutine
+		select {
+		case <-ticker.C:
+			// Iterate through each db, check modtime. If new, reload file
+			cityErr := g.open(&g.city, "GeoIP2-City.mmdb", "GeoLite2-City.mmdb")
+			if cityErr != nil {
+				log.Printf("Failed to update City: %v\n", cityErr)
+			}
+			countryErr := g.open(&g.country, "GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb")
+			if countryErr != nil {
+				log.Printf("failed to update Country: %v\n", countryErr)
+			}
+			asnErr := g.open(&g.asn, "GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb")
+			if asnErr != nil {
+				log.Printf("failed to update ASN: %v\n", asnErr)
 			}
 		}
-	}()
+	}
 }
 
-// reloadFile wraps the DB update operation with a pointer to the geodb struct
-func (g *GeoIP2) reloadFile(v *geodb) error {
-	// Wrap this sequence of operations
-	g.mu.Lock()
-	defer g.mu.Unlock()
-	e := v.db.Close()
-	if e != nil {
-		return errors.Wrapf(e, "unable to close DB file %s", v.fp)
-	}
-	// Directly call geoip2.Open instead of the open() function because we cannot know the related enum value for the given file.
-	n, e := geoip2.Open(v.fp)
-	if e != nil {
-		return errors.Wrapf(e, "unable to reopen DB file %s", v.fp)
-	}
-	v.db = n
-	return nil
+func (g *GeoIP2) anyActive() bool {
+	return g.country.active || g.city.active || g.asn.active
 }
 
 // New returns a new GeoIP2 provider
-func New(dir string) (*GeoIP2, error) {
-	g := &GeoIP2{
+func New(dir string) (g *GeoIP2, err error) {
+	g = &GeoIP2{
 		dir: dir,
 	}
-	_, err := g.open(countryDB, "")
-	if err != nil {
+	// This routine MUST load the database files at least once.
+	cityErr := g.open(&g.city, "GeoIP2-City.mmdb", "GeoLite2-City.mmdb")
+	if cityErr != nil {
+		log.Printf("failed to load City DB: %v\n", cityErr)
+		err = cityErr
+	}
+	countryErr := g.open(&g.country, "GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb")
+	if countryErr != nil {
+		log.Printf("failed to load Country DB: %v\n", countryErr)
+		err = countryErr
+	}
+	asnErr := g.open(&g.asn, "GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb")
+	if asnErr != nil {
+		log.Printf("failed to load ASN DB: %v\n", asnErr)
+		err = asnErr
+	}
+	if !g.anyActive() {
 		return nil, err
 	}
-
-	go g.watchFiles() // Launch goroutine to monitor
-
-	return g, nil
+	go g.watchFiles() // Launch goroutine to load and monitor
+	return
 }
 
 // HasASN returns if we can do ASN lookups
 func (g *GeoIP2) HasASN() (bool, error) {
-	r, err := g.get(asnDB, "")
-	if r != nil && err == nil {
-		return true, nil
-	}
-	return false, err
+	return g.asn.active, nil
 }
 
 // GetASN returns the ASN for the IP (as a "as123" string) and the netmask
 func (g *GeoIP2) GetASN(ip net.IP) (string, int, error) {
-	r, err := g.get(asnDB, "")
-	log.Printf("GetASN for %s, got DB? %s", ip, err)
-	if err != nil {
-		return "", 0, err
+	g.asn.l.RLock()
+	defer g.asn.l.RUnlock()
+
+	if !g.asn.active {
+		return "", 0, fmt.Errorf("ASN db not active")
 	}
 
-	c, err := r.ASN(ip)
+	c, err := g.asn.db.ASN(ip)
 	if err != nil {
 		return "", 0, fmt.Errorf("lookup ASN for '%s': %s", ip.String(), err)
 	}
@@ -233,17 +179,16 @@ func (g *GeoIP2) GetASN(ip net.IP) (string, int, error) {
 
 // HasCountry checks if the GeoIP country database is available
 func (g *GeoIP2) HasCountry() (bool, error) {
-	r, err := g.get(countryDB, "")
-	if r != nil && err == nil {
-		return true, nil
-	}
-	return false, err
+	return g.country.active, nil
 }
 
 // GetCountry returns the country, continent and netmask for the given IP
 func (g *GeoIP2) GetCountry(ip net.IP) (country, continent string, netmask int) {
-	r, err := g.get(countryDB, "")
-	c, err := r.Country(ip)
+	// Need a read-lock because return value of Country is a pointer, not copy of the struct/object
+	g.country.l.RLock()
+	defer g.country.l.RUnlock()
+
+	c, err := g.country.db.Country(ip)
 	if err != nil {
 		log.Printf("Could not lookup country for '%s': %s", ip.String(), err)
 		return "", "", 0
@@ -259,21 +204,16 @@ func (g *GeoIP2) GetCountry(ip net.IP) (country, continent string, netmask int)
 	return country, continent, 0
 }
 
-// HasLocation returns if the city database is available to
-// return lat/lon information for an IP
+// HasLocation returns if the city database is available to return lat/lon information for an IP
 func (g *GeoIP2) HasLocation() (bool, error) {
-	r, err := g.get(cityDB, "")
-	if r != nil && err == nil {
-		return true, nil
-	}
-	return false, err
+	return g.city.active, nil
 }
 
 // GetLocation returns a geo.Location object for the given IP
 func (g *GeoIP2) GetLocation(ip net.IP) (l *geo.Location, err error) {
 	// Need a read-lock because return value of City is a pointer, not copy of the struct/object
-	g.mu.RLock()
-	defer g.mu.RUnlock()
+	g.city.l.RLock()
+	defer g.city.l.RUnlock()
 
 	c, err := g.city.db.City(ip)
 	if err != nil {