serve.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package main
  2. import (
  3. "fmt"
  4. "github.com/miekg/dns"
  5. "log"
  6. "os"
  7. "os/signal"
  8. "strings"
  9. )
  10. func getQuestionName(z *Zone, req *dns.Msg) string {
  11. lx := dns.SplitLabels(req.Question[0].Name)
  12. ql := lx[0 : len(lx)-z.LenLabels-1]
  13. return strings.Join(ql, ".")
  14. }
  15. var geoIP = setupGeoIP()
  16. func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
  17. qtype := req.Question[0].Qtype
  18. logPrintf("[zone %s] incoming %s %s %d from %s\n", z.Origin, req.Question[0].Name,
  19. dns.Rr_str[qtype], req.MsgHdr.Id, w.RemoteAddr())
  20. fmt.Printf("ZONE DATA %#v\n", z)
  21. fmt.Println("Got request", req)
  22. label := getQuestionName(z, req)
  23. raddr := w.RemoteAddr()
  24. var country *string
  25. if geoIP != nil {
  26. country = geoIP.GetCountry(raddr.String())
  27. fmt.Println("Country:", country)
  28. }
  29. m := new(dns.Msg)
  30. m.SetReply(req)
  31. ednsFromRequest(req, m)
  32. m.MsgHdr.Authoritative = true
  33. m.Authoritative = true
  34. labels := z.findLabels(label, *country, qtype)
  35. if labels == nil {
  36. // return NXDOMAIN
  37. m.SetRcode(req, dns.RcodeNameError)
  38. m.Authoritative = true
  39. w.Write(m)
  40. return
  41. }
  42. fmt.Println("Has the label, looking for records")
  43. if region_rr := labels.Records[qtype]; region_rr != nil {
  44. //fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
  45. max := len(region_rr)
  46. if max > 4 {
  47. max = 4
  48. }
  49. // TODO(ask) Pick random servers based on weight, not just the first 'max' entries
  50. servers := region_rr[0:max]
  51. var rrs []dns.RR
  52. for _, record := range servers {
  53. rr := record.RR
  54. fmt.Println("RR", rr)
  55. rr.Header().Name = req.Question[0].Name
  56. fmt.Println(rr)
  57. rrs = append(rrs, rr)
  58. }
  59. m.Answer = rrs
  60. }
  61. fmt.Println("Writing reply")
  62. w.Write(m)
  63. return
  64. }
  65. func setupServer(Zone Zone) func(dns.ResponseWriter, *dns.Msg) {
  66. return func(w dns.ResponseWriter, r *dns.Msg) {
  67. serve(w, r, &Zone)
  68. }
  69. }
  70. func runServe(Zones *Zones) {
  71. for zoneName, Zone := range *Zones {
  72. dns.HandleFunc(zoneName, setupServer(*Zone))
  73. }
  74. // Only listen on UDP for now
  75. go func() {
  76. if err := dns.ListenAndServe(*listen, "udp", nil); err != nil {
  77. log.Fatalf("geodns: failed to setup %s %s", *listen, "udp")
  78. }
  79. }()
  80. if *flagrun {
  81. sig := make(chan os.Signal)
  82. signal.Notify(sig, os.Interrupt)
  83. forever:
  84. for {
  85. select {
  86. case <-sig:
  87. log.Printf("geodns: signal received, stopping")
  88. break forever
  89. }
  90. }
  91. }
  92. }
  93. func ednsFromRequest(req, m *dns.Msg) {
  94. for _, r := range req.Extra {
  95. if r.Header().Rrtype == dns.TypeOPT {
  96. m.SetEdns0(4096, r.(*dns.RR_OPT).Do())
  97. return
  98. }
  99. }
  100. return
  101. }