server.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "log"
  6. "sync"
  7. "time"
  8. "github.com/abh/geodns/v3/monitor"
  9. "github.com/abh/geodns/v3/querylog"
  10. "github.com/abh/geodns/v3/zones"
  11. "golang.org/x/sync/errgroup"
  12. "github.com/miekg/dns"
  13. "github.com/prometheus/client_golang/prometheus"
  14. )
  15. type serverMetrics struct {
  16. Queries *prometheus.CounterVec
  17. }
  18. // Server ...
  19. type Server struct {
  20. queryLogger querylog.QueryLogger
  21. mux *dns.ServeMux
  22. PublicDebugQueries bool
  23. info *monitor.ServerInfo
  24. metrics *serverMetrics
  25. lock sync.Mutex
  26. dnsServers []*dns.Server
  27. }
  28. // NewServer ...
  29. func NewServer(si *monitor.ServerInfo) *Server {
  30. mux := dns.NewServeMux()
  31. queries := prometheus.NewCounterVec(
  32. prometheus.CounterOpts{
  33. Name: "dns_queries_total",
  34. Help: "Number of served queries",
  35. },
  36. []string{"zone", "qtype", "qname", "rcode"},
  37. )
  38. prometheus.MustRegister(queries)
  39. buildInfo := prometheus.NewGaugeVec(
  40. prometheus.GaugeOpts{
  41. Name: "geodns_build_info",
  42. Help: "GeoDNS build information (in labels)",
  43. },
  44. []string{"Version", "ID", "IP", "Group"},
  45. )
  46. prometheus.MustRegister(buildInfo)
  47. group := ""
  48. if len(si.Groups) > 0 {
  49. group = si.Groups[0]
  50. }
  51. buildInfo.WithLabelValues(si.Version, si.ID, si.IP, group).Set(1)
  52. startTime := prometheus.NewGauge(
  53. prometheus.GaugeOpts{
  54. Name: "geodns_start_time_seconds",
  55. Help: "Unix time process started",
  56. },
  57. )
  58. prometheus.MustRegister(startTime)
  59. nano := si.Started.UnixNano()
  60. startTime.Set(float64(nano) / 1e9)
  61. metrics := &serverMetrics{
  62. Queries: queries,
  63. }
  64. return &Server{mux: mux, info: si, metrics: metrics}
  65. }
  66. // SetQueryLogger configures the query logger. For now it only supports writing to
  67. // a file (and all zones get logged to the same file).
  68. func (srv *Server) SetQueryLogger(logger querylog.QueryLogger) {
  69. srv.queryLogger = logger
  70. }
  71. // Add adds the Zone to be handled under the specified name
  72. func (srv *Server) Add(name string, zone *zones.Zone) {
  73. srv.mux.HandleFunc(name, srv.setupServerFunc(zone))
  74. }
  75. // Remove removes the zone name from being handled by the server
  76. func (srv *Server) Remove(name string) {
  77. srv.mux.HandleRemove(name)
  78. }
  79. func (srv *Server) setupServerFunc(zone *zones.Zone) func(dns.ResponseWriter, *dns.Msg) {
  80. return func(w dns.ResponseWriter, r *dns.Msg) {
  81. srv.serve(w, r, zone)
  82. }
  83. }
  84. // ServeDNS calls ServeDNS in the dns package
  85. func (srv *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
  86. srv.mux.ServeDNS(w, r)
  87. }
  88. func (srv *Server) addDNSServer(dnsServer *dns.Server) {
  89. srv.lock.Lock()
  90. defer srv.lock.Unlock()
  91. srv.dnsServers = append(srv.dnsServers, dnsServer)
  92. }
  93. // ListenAndServe starts the DNS server on the specified IP
  94. // (both tcp and udp). It returns an error if
  95. // something goes wrong.
  96. func (srv *Server) ListenAndServe(ctx context.Context, ip string) error {
  97. prots := []string{"udp", "tcp"}
  98. g, _ := errgroup.WithContext(ctx)
  99. for _, prot := range prots {
  100. p := prot
  101. g.Go(func() error {
  102. server := &dns.Server{
  103. Addr: ip,
  104. Net: p,
  105. Handler: srv,
  106. }
  107. srv.addDNSServer(server)
  108. log.Printf("Opening on %s %s", ip, p)
  109. if err := server.ListenAndServe(); err != nil {
  110. log.Printf("geodns: failed to setup %s %s: %s", ip, p, err)
  111. return err
  112. }
  113. return nil
  114. })
  115. }
  116. // the servers will be shutdown when Shutdown() is called
  117. return g.Wait()
  118. }
  119. // Shutdown gracefully shuts down the server
  120. func (srv *Server) Shutdown() error {
  121. var errs []error
  122. for _, dnsServer := range srv.dnsServers {
  123. timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
  124. defer cancel()
  125. err := dnsServer.ShutdownContext(timeoutCtx)
  126. if err != nil {
  127. errs = append(errs, err)
  128. }
  129. }
  130. if srv.queryLogger != nil {
  131. err := srv.queryLogger.Close()
  132. if err != nil {
  133. errs = append(errs, err)
  134. }
  135. }
  136. err := errors.Join(errs...)
  137. return err
  138. }