dns_server.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package nebula
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/netip"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "github.com/gaissmai/bart"
  11. "github.com/miekg/dns"
  12. "github.com/sirupsen/logrus"
  13. "github.com/slackhq/nebula/config"
  14. )
  15. // This whole thing should be rewritten to use context
  16. var dnsR *dnsRecords
  17. var dnsServer *dns.Server
  18. var dnsAddr string
  19. type dnsRecords struct {
  20. sync.RWMutex
  21. l *logrus.Logger
  22. dnsMap4 map[string]netip.Addr
  23. dnsMap6 map[string]netip.Addr
  24. hostMap *HostMap
  25. myVpnAddrsTable *bart.Lite
  26. }
  27. func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
  28. return &dnsRecords{
  29. l: l,
  30. dnsMap4: make(map[string]netip.Addr),
  31. dnsMap6: make(map[string]netip.Addr),
  32. hostMap: hostMap,
  33. myVpnAddrsTable: cs.myVpnAddrsTable,
  34. }
  35. }
  36. func (d *dnsRecords) query(q uint16, data string) netip.Addr {
  37. data = strings.ToLower(data)
  38. d.RLock()
  39. defer d.RUnlock()
  40. switch q {
  41. case dns.TypeA:
  42. if r, ok := d.dnsMap4[data]; ok {
  43. return r
  44. }
  45. case dns.TypeAAAA:
  46. if r, ok := d.dnsMap6[data]; ok {
  47. return r
  48. }
  49. }
  50. return netip.Addr{}
  51. }
  52. func (d *dnsRecords) queryCert(data string) string {
  53. ip, err := netip.ParseAddr(data[:len(data)-1])
  54. if err != nil {
  55. return ""
  56. }
  57. hostinfo := d.hostMap.QueryVpnAddr(ip)
  58. if hostinfo == nil {
  59. return ""
  60. }
  61. q := hostinfo.GetCert()
  62. if q == nil {
  63. return ""
  64. }
  65. b, err := q.Certificate.MarshalJSON()
  66. if err != nil {
  67. return ""
  68. }
  69. return string(b)
  70. }
  71. // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
  72. func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
  73. host = strings.ToLower(host)
  74. d.Lock()
  75. defer d.Unlock()
  76. haveV4 := false
  77. haveV6 := false
  78. for _, addr := range addresses {
  79. if addr.Is4() && !haveV4 {
  80. d.dnsMap4[host] = addr
  81. haveV4 = true
  82. } else if addr.Is6() && !haveV6 {
  83. d.dnsMap6[host] = addr
  84. haveV6 = true
  85. }
  86. if haveV4 && haveV6 {
  87. break
  88. }
  89. }
  90. }
  91. func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
  92. a, _, _ := net.SplitHostPort(addr)
  93. b, err := netip.ParseAddr(a)
  94. if err != nil {
  95. return false
  96. }
  97. if b.IsLoopback() {
  98. return true
  99. }
  100. //if we found it in this table, it's good
  101. return d.myVpnAddrsTable.Contains(b)
  102. }
  103. func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
  104. for _, q := range m.Question {
  105. switch q.Qtype {
  106. case dns.TypeA, dns.TypeAAAA:
  107. qType := dns.TypeToString[q.Qtype]
  108. d.l.Debugf("Query for %s %s", qType, q.Name)
  109. ip := d.query(q.Qtype, q.Name)
  110. if ip.IsValid() {
  111. rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
  112. if err == nil {
  113. m.Answer = append(m.Answer, rr)
  114. }
  115. }
  116. case dns.TypeTXT:
  117. // We only answer these queries from nebula nodes or localhost
  118. if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
  119. return
  120. }
  121. d.l.Debugf("Query for TXT %s", q.Name)
  122. ip := d.queryCert(q.Name)
  123. if ip != "" {
  124. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  125. if err == nil {
  126. m.Answer = append(m.Answer, rr)
  127. }
  128. }
  129. }
  130. }
  131. if len(m.Answer) == 0 {
  132. m.Rcode = dns.RcodeNameError
  133. }
  134. }
  135. func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
  136. m := new(dns.Msg)
  137. m.SetReply(r)
  138. m.Compress = false
  139. switch r.Opcode {
  140. case dns.OpcodeQuery:
  141. d.parseQuery(m, w)
  142. }
  143. w.WriteMsg(m)
  144. }
  145. func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
  146. dnsR = newDnsRecords(l, cs, hostMap)
  147. // attach request handler func
  148. dns.HandleFunc(".", dnsR.handleDnsRequest)
  149. c.RegisterReloadCallback(func(c *config.C) {
  150. reloadDns(ctx, l, c)
  151. })
  152. return func() {
  153. startDns(ctx, l, c)
  154. }
  155. }
  156. func getDnsServerAddr(c *config.C) string {
  157. dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
  158. // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
  159. if dnsHost == "[::]" {
  160. dnsHost = "::"
  161. }
  162. return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
  163. }
  164. func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
  165. dnsAddr = getDnsServerAddr(c)
  166. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  167. l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
  168. err := dnsServer.ListenAndServe()
  169. defer dnsServer.ShutdownContext(ctx)
  170. if err != nil {
  171. l.Errorf("Failed to start server: %s\n ", err.Error())
  172. }
  173. }
  174. func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
  175. if dnsAddr == getDnsServerAddr(c) {
  176. l.Debug("No DNS server config change detected")
  177. return
  178. }
  179. l.Debug("Restarting DNS server")
  180. dnsServer.ShutdownContext(ctx)
  181. go startDns(ctx, l, c)
  182. }