dns_server.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "net/netip"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "github.com/miekg/dns"
  10. "github.com/sirupsen/logrus"
  11. "github.com/slackhq/nebula/config"
  12. )
  13. // This whole thing should be rewritten to use context
  14. var dnsR *dnsRecords
  15. var dnsServer *dns.Server
  16. var dnsAddr string
  17. type dnsRecords struct {
  18. sync.RWMutex
  19. dnsMap map[string]string
  20. hostMap *HostMap
  21. }
  22. func newDnsRecords(hostMap *HostMap) *dnsRecords {
  23. return &dnsRecords{
  24. dnsMap: make(map[string]string),
  25. hostMap: hostMap,
  26. }
  27. }
  28. func (d *dnsRecords) Query(data string) string {
  29. d.RLock()
  30. defer d.RUnlock()
  31. if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
  32. return r
  33. }
  34. return ""
  35. }
  36. func (d *dnsRecords) QueryCert(data string) string {
  37. ip, err := netip.ParseAddr(data[:len(data)-1])
  38. if err != nil {
  39. return ""
  40. }
  41. hostinfo := d.hostMap.QueryVpnIp(ip)
  42. if hostinfo == nil {
  43. return ""
  44. }
  45. q := hostinfo.GetCert()
  46. if q == nil {
  47. return ""
  48. }
  49. cert := q.Details
  50. c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
  51. return c
  52. }
  53. func (d *dnsRecords) Add(host, data string) {
  54. d.Lock()
  55. defer d.Unlock()
  56. d.dnsMap[strings.ToLower(host)] = data
  57. }
  58. func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
  59. for _, q := range m.Question {
  60. switch q.Qtype {
  61. case dns.TypeA:
  62. l.Debugf("Query for A %s", q.Name)
  63. ip := dnsR.Query(q.Name)
  64. if ip != "" {
  65. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
  66. if err == nil {
  67. m.Answer = append(m.Answer, rr)
  68. }
  69. }
  70. case dns.TypeTXT:
  71. a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
  72. b, err := netip.ParseAddr(a)
  73. if err != nil {
  74. return
  75. }
  76. // We don't answer these queries from non nebula nodes or localhost
  77. //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
  78. if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
  79. return
  80. }
  81. l.Debugf("Query for TXT %s", q.Name)
  82. ip := dnsR.QueryCert(q.Name)
  83. if ip != "" {
  84. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  85. if err == nil {
  86. m.Answer = append(m.Answer, rr)
  87. }
  88. }
  89. }
  90. }
  91. if len(m.Answer) == 0 {
  92. m.Rcode = dns.RcodeNameError
  93. }
  94. }
  95. func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
  96. m := new(dns.Msg)
  97. m.SetReply(r)
  98. m.Compress = false
  99. switch r.Opcode {
  100. case dns.OpcodeQuery:
  101. parseQuery(l, m, w)
  102. }
  103. w.WriteMsg(m)
  104. }
  105. func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
  106. dnsR = newDnsRecords(hostMap)
  107. // attach request handler func
  108. dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
  109. handleDnsRequest(l, w, r)
  110. })
  111. c.RegisterReloadCallback(func(c *config.C) {
  112. reloadDns(l, c)
  113. })
  114. return func() {
  115. startDns(l, c)
  116. }
  117. }
  118. func getDnsServerAddr(c *config.C) string {
  119. dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
  120. // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
  121. if dnsHost == "[::]" {
  122. dnsHost = "::"
  123. }
  124. return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
  125. }
  126. func startDns(l *logrus.Logger, c *config.C) {
  127. dnsAddr = getDnsServerAddr(c)
  128. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  129. l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
  130. err := dnsServer.ListenAndServe()
  131. defer dnsServer.Shutdown()
  132. if err != nil {
  133. l.Errorf("Failed to start server: %s\n ", err.Error())
  134. }
  135. }
  136. func reloadDns(l *logrus.Logger, c *config.C) {
  137. if dnsAddr == getDnsServerAddr(c) {
  138. l.Debug("No DNS server config change detected")
  139. return
  140. }
  141. l.Debug("Restarting DNS server")
  142. dnsServer.Shutdown()
  143. go startDns(l, c)
  144. }