geoip2.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package geoip2
  2. import (
  3. "fmt"
  4. "io/fs"
  5. "log"
  6. "net"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/abh/geodns/v3/countries"
  13. "github.com/abh/geodns/v3/targeting/geo"
  14. gdb "github.com/oschwald/geoip2-golang"
  15. )
  16. // GeoIP2 contains the geoip implementation of the GeoDNS geo
  17. // targeting interface
  18. type GeoIP2 struct {
  19. dir string
  20. country geodb
  21. city geodb
  22. asn geodb
  23. }
  24. type geodb struct {
  25. active bool
  26. lastModified int64 // Epoch time
  27. fp string // FilePath
  28. db *gdb.Reader // Database reader
  29. l sync.RWMutex // Individual lock for separate DB access and reload -- Future?
  30. }
  // FindDB returns a guess at a directory path for GeoIP data files: the first
  // of several common install locations that exists, or "" if none do.
  func FindDB() string {
      return firstExistingDir([]string{
          "/usr/share/GeoIP/",       // Linux default
          "/usr/share/local/GeoIP/", // source install?
          "/usr/local/share/GeoIP/", // FreeBSD
          "/opt/local/share/GeoIP/", // MacPorts
          "/opt/homebrew/var/GeoIP", // Homebrew
      })
  }

  // firstExistingDir returns the first entry of dirs that os.Stat succeeds on,
  // or "" if none does. Stat errors other than "does not exist" (for example a
  // permission error) are logged, and that candidate is skipped.
  func firstExistingDir(dirs []string) string {
      for _, dir := range dirs {
          _, err := os.Stat(dir)
          if err == nil {
              return dir
          }
          // A missing directory is the expected case; anything else is worth reporting.
          // (os.IsExist never matches a Stat error, so the old check never logged.)
          if !os.IsNotExist(err) {
              log.Println(err)
          }
      }
      return ""
  }
  51. // open will create a filehandle for the provided GeoIP2 database. If opened once before and a newer modification time is present, the function will reopen the file with its new contents
  52. func (g *GeoIP2) open(v *geodb, fns ...string) error {
  53. var fi fs.FileInfo
  54. var err error
  55. if v.fp == "" {
  56. // We're opening this file for the first time
  57. for _, i := range fns {
  58. fp := filepath.Join(g.dir, i)
  59. fi, err = os.Stat(fp)
  60. if err != nil {
  61. continue
  62. }
  63. v.fp = fp
  64. }
  65. }
  66. if v.fp == "" { // Recheck for empty string in case none of the DB files are found
  67. return fmt.Errorf("no files found for db")
  68. }
  69. if fi == nil { // We have not set fileInfo and v.fp is set
  70. fi, err = os.Stat(v.fp)
  71. }
  72. if err != nil {
  73. return err
  74. }
  75. if v.lastModified >= fi.ModTime().UTC().Unix() { // No update to existing file
  76. return nil
  77. }
  78. // Delay the lock to here because we're only
  79. v.l.Lock()
  80. defer v.l.Unlock()
  81. o, e := gdb.Open(v.fp)
  82. if e != nil {
  83. return e
  84. }
  85. v.db = o
  86. v.active = true
  87. v.lastModified = fi.ModTime().UTC().Unix()
  88. return nil
  89. }
  90. // watchFiles spawns a goroutine to check for new files every minute, reloading if the modtime is newer than the original file's modtime
  91. func (g *GeoIP2) watchFiles() {
  92. // Not worried about goroutines leaking because only one geoip2.New call is made in main (outside of testing)
  93. ticker := time.NewTicker(1 * time.Minute)
  94. for { // We forever-loop here because we only run this function in a separate goroutine
  95. select {
  96. case <-ticker.C:
  97. // Iterate through each db, check modtime. If new, reload file
  98. cityErr := g.open(&g.city, "GeoIP2-City.mmdb", "GeoLite2-City.mmdb")
  99. if cityErr != nil {
  100. log.Printf("Failed to update City: %v\n", cityErr)
  101. }
  102. countryErr := g.open(&g.country, "GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb")
  103. if countryErr != nil {
  104. log.Printf("failed to update Country: %v\n", countryErr)
  105. }
  106. asnErr := g.open(&g.asn, "GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb")
  107. if asnErr != nil {
  108. log.Printf("failed to update ASN: %v\n", asnErr)
  109. }
  110. }
  111. }
  112. }
  113. func (g *GeoIP2) anyActive() bool {
  114. return g.country.active || g.city.active || g.asn.active
  115. }
  116. // New returns a new GeoIP2 provider
  117. func New(dir string) (g *GeoIP2, err error) {
  118. g = &GeoIP2{
  119. dir: dir,
  120. }
  121. // This routine MUST load the database files at least once.
  122. cityErr := g.open(&g.city, "GeoIP2-City.mmdb", "GeoLite2-City.mmdb")
  123. if cityErr != nil {
  124. log.Printf("failed to load City DB: %v\n", cityErr)
  125. err = cityErr
  126. }
  127. countryErr := g.open(&g.country, "GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb")
  128. if countryErr != nil {
  129. log.Printf("failed to load Country DB: %v\n", countryErr)
  130. err = countryErr
  131. }
  132. asnErr := g.open(&g.asn, "GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb")
  133. if asnErr != nil {
  134. log.Printf("failed to load ASN DB: %v\n", asnErr)
  135. err = asnErr
  136. }
  137. if !g.anyActive() {
  138. return nil, err
  139. }
  140. go g.watchFiles() // Launch goroutine to load and monitor
  141. return
  142. }
  143. // HasASN returns if we can do ASN lookups
  144. func (g *GeoIP2) HasASN() (bool, error) {
  145. return g.asn.active, nil
  146. }
  147. // GetASN returns the ASN for the IP (as a "as123" string) and the netmask
  148. func (g *GeoIP2) GetASN(ip net.IP) (string, int, error) {
  149. g.asn.l.RLock()
  150. defer g.asn.l.RUnlock()
  151. if !g.asn.active {
  152. return "", 0, fmt.Errorf("ASN db not active")
  153. }
  154. c, err := g.asn.db.ASN(ip)
  155. if err != nil {
  156. return "", 0, fmt.Errorf("lookup ASN for '%s': %s", ip.String(), err)
  157. }
  158. asn := c.AutonomousSystemNumber
  159. netmask := 24
  160. if ip.To4() != nil {
  161. netmask = 48
  162. }
  163. return fmt.Sprintf("as%d", asn), netmask, nil
  164. }
  165. // HasCountry checks if the GeoIP country database is available
  166. func (g *GeoIP2) HasCountry() (bool, error) {
  167. return g.country.active, nil
  168. }
  169. // GetCountry returns the country, continent and netmask for the given IP
  170. func (g *GeoIP2) GetCountry(ip net.IP) (country, continent string, netmask int) {
  171. // Need a read-lock because return value of Country is a pointer, not copy of the struct/object
  172. g.country.l.RLock()
  173. defer g.country.l.RUnlock()
  174. if !g.country.active {
  175. return "", "", 0
  176. }
  177. c, err := g.country.db.Country(ip)
  178. if err != nil {
  179. log.Printf("Could not lookup country for '%s': %s", ip.String(), err)
  180. return "", "", 0
  181. }
  182. country = c.Country.IsoCode
  183. if len(country) > 0 {
  184. country = strings.ToLower(country)
  185. continent = countries.CountryContinent[country]
  186. }
  187. return country, continent, 0
  188. }
  189. // HasLocation returns if the city database is available to return lat/lon information for an IP
  190. func (g *GeoIP2) HasLocation() (bool, error) {
  191. return g.city.active, nil
  192. }
  193. // GetLocation returns a geo.Location object for the given IP
  194. func (g *GeoIP2) GetLocation(ip net.IP) (l *geo.Location, err error) {
  195. // Need a read-lock because return value of City is a pointer, not copy of the struct/object
  196. g.city.l.RLock()
  197. defer g.city.l.RUnlock()
  198. if !g.city.active {
  199. return nil, fmt.Errorf("city db not active")
  200. }
  201. c, err := g.city.db.City(ip)
  202. if err != nil {
  203. log.Printf("Could not lookup CountryRegion for '%s': %s", ip.String(), err)
  204. return
  205. }
  206. l = &geo.Location{
  207. Latitude: float64(c.Location.Latitude),
  208. Longitude: float64(c.Location.Longitude),
  209. Country: strings.ToLower(c.Country.IsoCode),
  210. }
  211. if len(c.Subdivisions) > 0 {
  212. l.Region = strings.ToLower(c.Subdivisions[0].IsoCode)
  213. }
  214. if len(l.Country) > 0 {
  215. l.Continent = countries.CountryContinent[l.Country]
  216. if len(l.Region) > 0 {
  217. l.Region = l.Country + "-" + l.Region
  218. l.RegionGroup = countries.CountryRegionGroup(l.Country, l.Region)
  219. }
  220. }
  221. return
  222. }