dns_server.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "net/netip"
  6. "strconv"
  7. "strings"
  8. "sync"
  9. "github.com/gaissmai/bart"
  10. "github.com/miekg/dns"
  11. "github.com/sirupsen/logrus"
  12. "github.com/slackhq/nebula/config"
  13. )
  14. // This whole thing should be rewritten to use context
  15. var dnsR *dnsRecords
  16. var dnsServer *dns.Server
  17. var dnsAddr string
  18. type dnsRecords struct {
  19. sync.RWMutex
  20. l *logrus.Logger
  21. dnsMap4 map[string]netip.Addr
  22. dnsMap6 map[string]netip.Addr
  23. hostMap *HostMap
  24. myVpnAddrsTable *bart.Lite
  25. }
  26. func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
  27. return &dnsRecords{
  28. l: l,
  29. dnsMap4: make(map[string]netip.Addr),
  30. dnsMap6: make(map[string]netip.Addr),
  31. hostMap: hostMap,
  32. myVpnAddrsTable: cs.myVpnAddrsTable,
  33. }
  34. }
  35. func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
  36. data = strings.ToLower(data)
  37. d.RLock()
  38. defer d.RUnlock()
  39. switch q {
  40. case dns.TypeA:
  41. if r, ok := d.dnsMap4[data]; ok {
  42. return r
  43. }
  44. case dns.TypeAAAA:
  45. if r, ok := d.dnsMap6[data]; ok {
  46. return r
  47. }
  48. }
  49. return netip.Addr{}
  50. }
  51. func (d *dnsRecords) QueryCert(data string) string {
  52. ip, err := netip.ParseAddr(data[:len(data)-1])
  53. if err != nil {
  54. return ""
  55. }
  56. hostinfo := d.hostMap.QueryVpnAddr(ip)
  57. if hostinfo == nil {
  58. return ""
  59. }
  60. q := hostinfo.GetCert()
  61. if q == nil {
  62. return ""
  63. }
  64. b, err := q.Certificate.MarshalJSON()
  65. if err != nil {
  66. return ""
  67. }
  68. return string(b)
  69. }
  70. // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
  71. func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
  72. host = strings.ToLower(host)
  73. d.Lock()
  74. defer d.Unlock()
  75. haveV4 := false
  76. haveV6 := false
  77. for _, addr := range addresses {
  78. if addr.Is4() && !haveV4 {
  79. d.dnsMap4[host] = addr
  80. haveV4 = true
  81. } else if addr.Is6() && !haveV6 {
  82. d.dnsMap6[host] = addr
  83. haveV6 = true
  84. }
  85. if haveV4 && haveV6 {
  86. break
  87. }
  88. }
  89. }
  90. func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
  91. a, _, _ := net.SplitHostPort(addr)
  92. b, err := netip.ParseAddr(a)
  93. if err != nil {
  94. return false
  95. }
  96. if b.IsLoopback() {
  97. return true
  98. }
  99. //if we found it in this table, it's good
  100. return d.myVpnAddrsTable.Contains(b)
  101. }
  102. func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
  103. for _, q := range m.Question {
  104. switch q.Qtype {
  105. case dns.TypeA, dns.TypeAAAA:
  106. qType := dns.TypeToString[q.Qtype]
  107. d.l.Debugf("Query for %s %s", qType, q.Name)
  108. ip := d.Query(q.Qtype, q.Name)
  109. if ip.IsValid() {
  110. rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
  111. if err == nil {
  112. m.Answer = append(m.Answer, rr)
  113. }
  114. }
  115. case dns.TypeTXT:
  116. // We only answer these queries from nebula nodes or localhost
  117. if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
  118. return
  119. }
  120. d.l.Debugf("Query for TXT %s", q.Name)
  121. ip := d.QueryCert(q.Name)
  122. if ip != "" {
  123. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  124. if err == nil {
  125. m.Answer = append(m.Answer, rr)
  126. }
  127. }
  128. }
  129. }
  130. if len(m.Answer) == 0 {
  131. m.Rcode = dns.RcodeNameError
  132. }
  133. }
  134. func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
  135. m := new(dns.Msg)
  136. m.SetReply(r)
  137. m.Compress = false
  138. switch r.Opcode {
  139. case dns.OpcodeQuery:
  140. d.parseQuery(m, w)
  141. }
  142. w.WriteMsg(m)
  143. }
  144. func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
  145. dnsR = newDnsRecords(l, cs, hostMap)
  146. // attach request handler func
  147. dns.HandleFunc(".", dnsR.handleDnsRequest)
  148. c.RegisterReloadCallback(func(c *config.C) {
  149. reloadDns(l, c)
  150. })
  151. return func() {
  152. startDns(l, c)
  153. }
  154. }
  155. func getDnsServerAddr(c *config.C) string {
  156. dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
  157. // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
  158. if dnsHost == "[::]" {
  159. dnsHost = "::"
  160. }
  161. return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
  162. }
  163. func startDns(l *logrus.Logger, c *config.C) {
  164. dnsAddr = getDnsServerAddr(c)
  165. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  166. l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
  167. err := dnsServer.ListenAndServe()
  168. defer dnsServer.Shutdown()
  169. if err != nil {
  170. l.Errorf("Failed to start server: %s\n ", err.Error())
  171. }
  172. }
  173. func reloadDns(l *logrus.Logger, c *config.C) {
  174. if dnsAddr == getDnsServerAddr(c) {
  175. l.Debug("No DNS server config change detected")
  176. return
  177. }
  178. l.Debug("Restarting DNS server")
  179. dnsServer.Shutdown()
  180. go startDns(l, c)
  181. }