server.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "log"
  6. "sync"
  7. "time"
  8. "github.com/abh/geodns/v3/appconfig"
  9. "github.com/abh/geodns/v3/monitor"
  10. "github.com/abh/geodns/v3/querylog"
  11. "github.com/abh/geodns/v3/zones"
  12. "go.ntppool.org/common/version"
  13. "golang.org/x/sync/errgroup"
  14. "github.com/miekg/dns"
  15. "github.com/prometheus/client_golang/prometheus"
  16. )
  17. type serverMetrics struct {
  18. Queries *prometheus.CounterVec
  19. }
  20. // Server ...
  21. type Server struct {
  22. PublicDebugQueries bool
  23. DetailedMetrics bool
  24. queryLogger querylog.QueryLogger
  25. mux *dns.ServeMux
  26. info *monitor.ServerInfo
  27. metrics *serverMetrics
  28. lock sync.Mutex
  29. dnsServers []*dns.Server
  30. }
  31. // NewServer ...
  32. func NewServer(config *appconfig.AppConfig, si *monitor.ServerInfo) *Server {
  33. mux := dns.NewServeMux()
  34. queries := prometheus.NewCounterVec(
  35. prometheus.CounterOpts{
  36. Name: "dns_queries_total",
  37. Help: "Number of served queries",
  38. },
  39. []string{"zone", "qtype", "qname", "rcode"},
  40. )
  41. prometheus.MustRegister(queries)
  42. version.RegisterMetric("geodns", prometheus.DefaultRegisterer)
  43. instanceInfo := prometheus.NewGaugeVec(
  44. prometheus.GaugeOpts{
  45. Name: "geodns_instance_info",
  46. Help: "GeoDNS instance information",
  47. },
  48. []string{"ID", "IP", "Group"},
  49. )
  50. prometheus.MustRegister(instanceInfo)
  51. group := ""
  52. if len(si.Groups) > 0 {
  53. group = si.Groups[0]
  54. }
  55. instanceInfo.WithLabelValues(si.ID, si.IP, group).Set(1)
  56. startTime := prometheus.NewGauge(
  57. prometheus.GaugeOpts{
  58. Name: "geodns_start_time_seconds",
  59. Help: "Unix time process started",
  60. },
  61. )
  62. prometheus.MustRegister(startTime)
  63. nano := si.Started.UnixNano()
  64. startTime.Set(float64(nano) / 1e9)
  65. metrics := &serverMetrics{
  66. Queries: queries,
  67. }
  68. return &Server{
  69. PublicDebugQueries: appconfig.Config.DNS.PublicDebugQueries,
  70. DetailedMetrics: appconfig.Config.DNS.DetailedMetrics,
  71. mux: mux,
  72. info: si,
  73. metrics: metrics,
  74. }
  75. }
  76. // SetQueryLogger configures the query logger. For now it only supports writing to
  77. // a file (and all zones get logged to the same file).
  78. func (srv *Server) SetQueryLogger(logger querylog.QueryLogger) {
  79. srv.queryLogger = logger
  80. }
  81. // Add adds the Zone to be handled under the specified name
  82. func (srv *Server) Add(name string, zone *zones.Zone) {
  83. srv.mux.HandleFunc(name, srv.setupServerFunc(zone))
  84. }
  85. // Remove removes the zone name from being handled by the server
  86. func (srv *Server) Remove(name string) {
  87. srv.mux.HandleRemove(name)
  88. }
  89. func (srv *Server) setupServerFunc(zone *zones.Zone) func(dns.ResponseWriter, *dns.Msg) {
  90. return func(w dns.ResponseWriter, r *dns.Msg) {
  91. srv.serve(w, r, zone)
  92. }
  93. }
  94. // ServeDNS calls ServeDNS in the dns package
  95. func (srv *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
  96. srv.mux.ServeDNS(w, r)
  97. }
  98. func (srv *Server) addDNSServer(dnsServer *dns.Server) {
  99. srv.lock.Lock()
  100. defer srv.lock.Unlock()
  101. srv.dnsServers = append(srv.dnsServers, dnsServer)
  102. }
  103. // ListenAndServe starts the DNS server on the specified IP
  104. // (both tcp and udp). It returns an error if
  105. // something goes wrong.
  106. func (srv *Server) ListenAndServe(ctx context.Context, ip string) error {
  107. prots := []string{"udp", "tcp"}
  108. g, _ := errgroup.WithContext(ctx)
  109. for _, prot := range prots {
  110. p := prot
  111. g.Go(func() error {
  112. server := &dns.Server{
  113. Addr: ip,
  114. Net: p,
  115. Handler: srv,
  116. }
  117. srv.addDNSServer(server)
  118. log.Printf("Opening on %s %s", ip, p)
  119. if err := server.ListenAndServe(); err != nil {
  120. log.Printf("geodns: failed to setup %s %s: %s", ip, p, err)
  121. return err
  122. }
  123. return nil
  124. })
  125. }
  126. // the servers will be shutdown when Shutdown() is called
  127. return g.Wait()
  128. }
  129. // Shutdown gracefully shuts down the server
  130. func (srv *Server) Shutdown() error {
  131. var errs []error
  132. for _, dnsServer := range srv.dnsServers {
  133. timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
  134. defer cancel()
  135. err := dnsServer.ShutdownContext(timeoutCtx)
  136. if err != nil {
  137. errs = append(errs, err)
  138. }
  139. }
  140. if srv.queryLogger != nil {
  141. err := srv.queryLogger.Close()
  142. if err != nil {
  143. errs = append(errs, err)
  144. }
  145. }
  146. err := errors.Join(errs...)
  147. return err
  148. }