dns_server.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "strconv"
  6. "sync"
  7. "github.com/miekg/dns"
  8. )
  9. // This whole thing should be rewritten to use context
  10. var dnsR *dnsRecords
  11. type dnsRecords struct {
  12. sync.RWMutex
  13. dnsMap map[string]string
  14. hostMap *HostMap
  15. }
  16. func newDnsRecords(hostMap *HostMap) *dnsRecords {
  17. return &dnsRecords{
  18. dnsMap: make(map[string]string),
  19. hostMap: hostMap,
  20. }
  21. }
  22. func (d *dnsRecords) Query(data string) string {
  23. d.RLock()
  24. if r, ok := d.dnsMap[data]; ok {
  25. d.RUnlock()
  26. return r
  27. }
  28. d.RUnlock()
  29. return ""
  30. }
  31. func (d *dnsRecords) QueryCert(data string) string {
  32. ip := net.ParseIP(data[:len(data)-1])
  33. if ip == nil {
  34. return ""
  35. }
  36. iip := ip2int(ip)
  37. hostinfo, err := d.hostMap.QueryVpnIP(iip)
  38. if err != nil {
  39. return ""
  40. }
  41. q := hostinfo.GetCert()
  42. if q == nil {
  43. return ""
  44. }
  45. cert := q.Details
  46. c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
  47. return c
  48. }
  49. func (d *dnsRecords) Add(host, data string) {
  50. d.Lock()
  51. d.dnsMap[host] = data
  52. d.Unlock()
  53. }
  54. func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
  55. for _, q := range m.Question {
  56. switch q.Qtype {
  57. case dns.TypeA:
  58. l.Debugf("Query for A %s", q.Name)
  59. ip := dnsR.Query(q.Name)
  60. if ip != "" {
  61. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
  62. if err == nil {
  63. m.Answer = append(m.Answer, rr)
  64. }
  65. }
  66. case dns.TypeTXT:
  67. a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
  68. b := net.ParseIP(a)
  69. // We don't answer these queries from non nebula nodes or localhost
  70. //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
  71. if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
  72. return
  73. }
  74. l.Debugf("Query for TXT %s", q.Name)
  75. ip := dnsR.QueryCert(q.Name)
  76. if ip != "" {
  77. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  78. if err == nil {
  79. m.Answer = append(m.Answer, rr)
  80. }
  81. }
  82. }
  83. }
  84. }
  85. func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
  86. m := new(dns.Msg)
  87. m.SetReply(r)
  88. m.Compress = false
  89. switch r.Opcode {
  90. case dns.OpcodeQuery:
  91. parseQuery(m, w)
  92. }
  93. w.WriteMsg(m)
  94. }
  95. func dnsMain(hostMap *HostMap) {
  96. dnsR = newDnsRecords(hostMap)
  97. // attach request handler func
  98. dns.HandleFunc(".", handleDnsRequest)
  99. // start server
  100. port := 53
  101. server := &dns.Server{Addr: ":" + strconv.Itoa(port), Net: "udp"}
  102. l.Debugf("Starting DNS responder at %d\n", port)
  103. err := server.ListenAndServe()
  104. defer server.Shutdown()
  105. if err != nil {
  106. l.Errorf("Failed to start server: %s\n ", err.Error())
  107. }
  108. }