serve.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package main
import (
	"encoding/json"
	"log"
	"net"
	"os"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/abh/geodns/countries"
	"github.com/miekg/dns"
)
  13. func getQuestionName(z *Zone, req *dns.Msg) string {
  14. lx := dns.SplitLabels(req.Question[0].Name)
  15. ql := lx[0 : len(lx)-z.LenLabels]
  16. return strings.ToLower(strings.Join(ql, "."))
  17. }
// geoIP is the package-level GeoIP handle, initialised once from
// setupGeoIP() when the package is loaded. serve only performs a country
// lookup when this value is non-nil.
var geoIP = setupGeoIP()
  19. func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
  20. qtype := req.Question[0].Qtype
  21. logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name,
  22. dns.TypeToString[qtype], req.MsgHdr.Id, w.RemoteAddr())
  23. // is this safe/atomic or does it need to go through a channel?
  24. qCounter++
  25. logPrintln("Got request", req)
  26. label := getQuestionName(z, req)
  27. var ip string
  28. var edns *dns.EDNS0_SUBNET
  29. var opt_rr *dns.OPT
  30. for _, extra := range req.Extra {
  31. log.Println("Extra", extra)
  32. for _, o := range extra.(*dns.OPT).Option {
  33. opt_rr = extra.(*dns.OPT)
  34. switch e := o.(type) {
  35. case *dns.EDNS0_NSID:
  36. // do stuff with e.Nsid
  37. case *dns.EDNS0_SUBNET:
  38. log.Println("========== XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
  39. log.Println("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
  40. if e.Address != nil {
  41. log.Println("Setting edns to", e)
  42. edns = e
  43. ip = e.Address.String()
  44. }
  45. }
  46. }
  47. }
  48. var country string
  49. if geoIP != nil {
  50. if len(ip) == 0 { // no edns subnet
  51. ip, _, _ = net.SplitHostPort(w.RemoteAddr().String())
  52. }
  53. country = strings.ToLower(geoIP.GetCountry(ip))
  54. logPrintln("Country:", ip, country)
  55. }
  56. m := new(dns.Msg)
  57. m.SetReply(req)
  58. if e := m.IsEdns0(); e != nil {
  59. m.SetEdns0(4096, e.Do())
  60. }
  61. m.Authoritative = true
  62. // TODO: set scope to 0 if there are no alternate responses
  63. if edns != nil {
  64. log.Println("family", edns.Family)
  65. if edns.Family != 0 {
  66. log.Println("edns response!")
  67. edns.SourceScope = 16
  68. m.Extra = append(m.Extra, opt_rr)
  69. }
  70. }
  71. // TODO(ask) Fix the findLabels API to make this work better
  72. if alias := z.findLabels(label, "", dns.TypeMF); alias != nil &&
  73. alias.Records[dns.TypeMF] != nil {
  74. // We found an alias record, so pretend the question was for that name instead
  75. label = alias.firstRR(dns.TypeMF).(*dns.MF).Mf
  76. }
  77. labels := z.findLabels(label, country, qtype)
  78. if labels == nil {
  79. if label == "_status" && (qtype == dns.TypeANY || qtype == dns.TypeTXT) {
  80. m.Answer = statusRR(z)
  81. m.Authoritative = true
  82. w.WriteMsg(m)
  83. return
  84. }
  85. if label == "_country" && (qtype == dns.TypeANY || qtype == dns.TypeTXT) {
  86. h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
  87. h.Name = "_country." + z.Origin + "."
  88. m.Answer = []dns.RR{&dns.TXT{Hdr: h,
  89. Txt: []string{
  90. w.RemoteAddr().String(),
  91. ip,
  92. string(country),
  93. string(countries.CountryContinent[country]),
  94. },
  95. }}
  96. m.Authoritative = true
  97. w.WriteMsg(m)
  98. return
  99. }
  100. // return NXDOMAIN
  101. m.SetRcode(req, dns.RcodeNameError)
  102. m.Authoritative = true
  103. m.Ns = []dns.RR{z.SoaRR()}
  104. w.WriteMsg(m)
  105. return
  106. }
  107. if servers := labels.Picker(qtype, labels.MaxHosts); servers != nil {
  108. var rrs []dns.RR
  109. for _, record := range servers {
  110. rr := record.RR
  111. rr.Header().Name = req.Question[0].Name
  112. rrs = append(rrs, rr)
  113. }
  114. m.Answer = rrs
  115. }
  116. if len(m.Answer) == 0 {
  117. if labels := z.Labels[label]; labels != nil {
  118. if _, ok := labels.Records[dns.TypeCNAME]; ok {
  119. cname := labels.firstRR(dns.TypeCNAME)
  120. m.Answer = append(m.Answer, cname)
  121. } else {
  122. m.Ns = append(m.Ns, z.SoaRR())
  123. }
  124. } else {
  125. m.Ns = append(m.Ns, z.SoaRR())
  126. }
  127. }
  128. logPrintln(m)
  129. err := w.WriteMsg(m)
  130. if err != nil {
  131. // if Pack'ing fails the Write fails. Return SERVFAIL.
  132. log.Println("Error writing packet", m)
  133. dns.HandleFailed(w, req)
  134. }
  135. return
  136. }
  137. func statusRR(z *Zone) []dns.RR {
  138. h := dns.RR_Header{Ttl: 1, Class: dns.ClassINET, Rrtype: dns.TypeTXT}
  139. h.Name = "_status." + z.Origin + "."
  140. status := map[string]string{"v": VERSION, "id": serverId}
  141. hostname, err := os.Hostname()
  142. if err == nil {
  143. status["h"] = hostname
  144. }
  145. status["up"] = strconv.Itoa(int(time.Since(timeStarted).Seconds()))
  146. status["qs"] = strconv.FormatUint(qCounter, 10)
  147. js, err := json.Marshal(status)
  148. return []dns.RR{&dns.TXT{Hdr: h, Txt: []string{string(js)}}}
  149. }
  150. func setupServerFunc(Zone *Zone) func(dns.ResponseWriter, *dns.Msg) {
  151. return func(w dns.ResponseWriter, r *dns.Msg) {
  152. serve(w, r, Zone)
  153. }
  154. }
  155. func listenAndServe(ip string, Zones *Zones) {
  156. prots := []string{"udp", "tcp"}
  157. for _, prot := range prots {
  158. go func(p string) {
  159. server := &dns.Server{Addr: ip, Net: p}
  160. log.Printf("Opening on %s %s", ip, p)
  161. if err := server.ListenAndServe(); err != nil {
  162. log.Fatalf("geodns: failed to setup %s %s: %s", ip, p, err)
  163. }
  164. log.Fatalf("geodns: ListenAndServe unexpectedly returned")
  165. }(prot)
  166. }
  167. }