dns_server.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "github.com/miekg/dns"
  9. "github.com/sirupsen/logrus"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/iputil"
  12. )
  13. // This whole thing should be rewritten to use context
  14. var dnsR *dnsRecords
  15. var dnsServer *dns.Server
  16. var dnsAddr string
  17. type dnsRecords struct {
  18. sync.RWMutex
  19. dnsMap map[string]string
  20. hostMap *HostMap
  21. }
  22. func newDnsRecords(hostMap *HostMap) *dnsRecords {
  23. return &dnsRecords{
  24. dnsMap: make(map[string]string),
  25. hostMap: hostMap,
  26. }
  27. }
  28. func (d *dnsRecords) Query(data string) string {
  29. d.RLock()
  30. defer d.RUnlock()
  31. if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
  32. return r
  33. }
  34. return ""
  35. }
  36. func (d *dnsRecords) QueryCert(data string) string {
  37. ip := net.ParseIP(data[:len(data)-1])
  38. if ip == nil {
  39. return ""
  40. }
  41. iip := iputil.Ip2VpnIp(ip)
  42. hostinfo := d.hostMap.QueryVpnIp(iip)
  43. if hostinfo == nil {
  44. return ""
  45. }
  46. q := hostinfo.GetCert()
  47. if q == nil {
  48. return ""
  49. }
  50. cert := q.Details
  51. c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
  52. return c
  53. }
  54. func (d *dnsRecords) Add(host, data string) {
  55. d.Lock()
  56. defer d.Unlock()
  57. d.dnsMap[strings.ToLower(host)] = data
  58. }
  59. func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
  60. for _, q := range m.Question {
  61. switch q.Qtype {
  62. case dns.TypeA:
  63. l.Debugf("Query for A %s", q.Name)
  64. ip := dnsR.Query(q.Name)
  65. if ip != "" {
  66. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
  67. if err == nil {
  68. m.Answer = append(m.Answer, rr)
  69. }
  70. }
  71. case dns.TypeTXT:
  72. a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
  73. b := net.ParseIP(a)
  74. // We don't answer these queries from non nebula nodes or localhost
  75. //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
  76. if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
  77. return
  78. }
  79. l.Debugf("Query for TXT %s", q.Name)
  80. ip := dnsR.QueryCert(q.Name)
  81. if ip != "" {
  82. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  83. if err == nil {
  84. m.Answer = append(m.Answer, rr)
  85. }
  86. }
  87. }
  88. }
  89. if len(m.Answer) == 0 {
  90. m.Rcode = dns.RcodeNameError
  91. }
  92. }
  93. func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
  94. m := new(dns.Msg)
  95. m.SetReply(r)
  96. m.Compress = false
  97. switch r.Opcode {
  98. case dns.OpcodeQuery:
  99. parseQuery(l, m, w)
  100. }
  101. w.WriteMsg(m)
  102. }
  103. func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
  104. dnsR = newDnsRecords(hostMap)
  105. // attach request handler func
  106. dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
  107. handleDnsRequest(l, w, r)
  108. })
  109. c.RegisterReloadCallback(func(c *config.C) {
  110. reloadDns(l, c)
  111. })
  112. return func() {
  113. startDns(l, c)
  114. }
  115. }
  116. func getDnsServerAddr(c *config.C) string {
  117. dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
  118. // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
  119. if dnsHost == "[::]" {
  120. dnsHost = "::"
  121. }
  122. return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
  123. }
  124. func startDns(l *logrus.Logger, c *config.C) {
  125. dnsAddr = getDnsServerAddr(c)
  126. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  127. l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
  128. err := dnsServer.ListenAndServe()
  129. defer dnsServer.Shutdown()
  130. if err != nil {
  131. l.Errorf("Failed to start server: %s\n ", err.Error())
  132. }
  133. }
  134. func reloadDns(l *logrus.Logger, c *config.C) {
  135. if dnsAddr == getDnsServerAddr(c) {
  136. l.Debug("No DNS server config change detected")
  137. return
  138. }
  139. l.Debug("Restarting DNS server")
  140. dnsServer.Shutdown()
  141. go startDns(l, c)
  142. }