2
0

dns_server.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package nebula
import (
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"

	"github.com/miekg/dns"
	"github.com/sirupsen/logrus"
)
  10. // This whole thing should be rewritten to use context
  11. var dnsR *dnsRecords
  12. var dnsServer *dns.Server
  13. var dnsAddr string
  14. type dnsRecords struct {
  15. sync.RWMutex
  16. dnsMap map[string]string
  17. hostMap *HostMap
  18. }
  19. func newDnsRecords(hostMap *HostMap) *dnsRecords {
  20. return &dnsRecords{
  21. dnsMap: make(map[string]string),
  22. hostMap: hostMap,
  23. }
  24. }
  25. func (d *dnsRecords) Query(data string) string {
  26. d.RLock()
  27. if r, ok := d.dnsMap[data]; ok {
  28. d.RUnlock()
  29. return r
  30. }
  31. d.RUnlock()
  32. return ""
  33. }
  34. func (d *dnsRecords) QueryCert(data string) string {
  35. ip := net.ParseIP(data[:len(data)-1])
  36. if ip == nil {
  37. return ""
  38. }
  39. iip := ip2int(ip)
  40. hostinfo, err := d.hostMap.QueryVpnIP(iip)
  41. if err != nil {
  42. return ""
  43. }
  44. q := hostinfo.GetCert()
  45. if q == nil {
  46. return ""
  47. }
  48. cert := q.Details
  49. c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
  50. return c
  51. }
  52. func (d *dnsRecords) Add(host, data string) {
  53. d.Lock()
  54. d.dnsMap[host] = data
  55. d.Unlock()
  56. }
  57. func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
  58. for _, q := range m.Question {
  59. switch q.Qtype {
  60. case dns.TypeA:
  61. l.Debugf("Query for A %s", q.Name)
  62. ip := dnsR.Query(q.Name)
  63. if ip != "" {
  64. rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
  65. if err == nil {
  66. m.Answer = append(m.Answer, rr)
  67. }
  68. }
  69. case dns.TypeTXT:
  70. a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
  71. b := net.ParseIP(a)
  72. // We don't answer these queries from non nebula nodes or localhost
  73. //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
  74. if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
  75. return
  76. }
  77. l.Debugf("Query for TXT %s", q.Name)
  78. ip := dnsR.QueryCert(q.Name)
  79. if ip != "" {
  80. rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
  81. if err == nil {
  82. m.Answer = append(m.Answer, rr)
  83. }
  84. }
  85. }
  86. }
  87. }
  88. func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
  89. m := new(dns.Msg)
  90. m.SetReply(r)
  91. m.Compress = false
  92. switch r.Opcode {
  93. case dns.OpcodeQuery:
  94. parseQuery(l, m, w)
  95. }
  96. w.WriteMsg(m)
  97. }
  98. func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
  99. dnsR = newDnsRecords(hostMap)
  100. // attach request handler func
  101. dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
  102. handleDnsRequest(l, w, r)
  103. })
  104. c.RegisterReloadCallback(func(c *Config) {
  105. reloadDns(l, c)
  106. })
  107. startDns(l, c)
  108. }
  109. func getDnsServerAddr(c *Config) string {
  110. return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
  111. }
  112. func startDns(l *logrus.Logger, c *Config) {
  113. dnsAddr = getDnsServerAddr(c)
  114. dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
  115. l.Debugf("Starting DNS responder at %s\n", dnsAddr)
  116. err := dnsServer.ListenAndServe()
  117. defer dnsServer.Shutdown()
  118. if err != nil {
  119. l.Errorf("Failed to start server: %s\n ", err.Error())
  120. }
  121. }
  122. func reloadDns(l *logrus.Logger, c *Config) {
  123. if dnsAddr == getDnsServerAddr(c) {
  124. l.Debug("No DNS server config change detected")
  125. return
  126. }
  127. l.Debug("Restarting DNS server")
  128. dnsServer.Shutdown()
  129. go startDns(l, c)
  130. }