123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- package nebula
- import (
- "fmt"
- "net"
- "strconv"
- "strings"
- "sync"
- "github.com/miekg/dns"
- "github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
- )
- // This whole thing should be rewritten to use context
- var dnsR *dnsRecords
- var dnsServer *dns.Server
- var dnsAddr string
- type dnsRecords struct {
- sync.RWMutex
- dnsMap map[string]string
- hostMap *HostMap
- }
- func newDnsRecords(hostMap *HostMap) *dnsRecords {
- return &dnsRecords{
- dnsMap: make(map[string]string),
- hostMap: hostMap,
- }
- }
- func (d *dnsRecords) Query(data string) string {
- d.RLock()
- defer d.RUnlock()
- if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
- return r
- }
- return ""
- }
- func (d *dnsRecords) QueryCert(data string) string {
- ip := net.ParseIP(data[:len(data)-1])
- if ip == nil {
- return ""
- }
- iip := iputil.Ip2VpnIp(ip)
- hostinfo := d.hostMap.QueryVpnIp(iip)
- if hostinfo == nil {
- return ""
- }
- q := hostinfo.GetCert()
- if q == nil {
- return ""
- }
- cert := q.Details
- c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
- return c
- }
- func (d *dnsRecords) Add(host, data string) {
- d.Lock()
- defer d.Unlock()
- d.dnsMap[strings.ToLower(host)] = data
- }
- func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
- for _, q := range m.Question {
- switch q.Qtype {
- case dns.TypeA:
- l.Debugf("Query for A %s", q.Name)
- ip := dnsR.Query(q.Name)
- if ip != "" {
- rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
- if err == nil {
- m.Answer = append(m.Answer, rr)
- }
- }
- case dns.TypeTXT:
- a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
- b := net.ParseIP(a)
- // We don't answer these queries from non nebula nodes or localhost
- //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
- if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
- return
- }
- l.Debugf("Query for TXT %s", q.Name)
- ip := dnsR.QueryCert(q.Name)
- if ip != "" {
- rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
- if err == nil {
- m.Answer = append(m.Answer, rr)
- }
- }
- }
- }
- if len(m.Answer) == 0 {
- m.Rcode = dns.RcodeNameError
- }
- }
- func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(r)
- m.Compress = false
- switch r.Opcode {
- case dns.OpcodeQuery:
- parseQuery(l, m, w)
- }
- w.WriteMsg(m)
- }
- func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
- dnsR = newDnsRecords(hostMap)
- // attach request handler func
- dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
- handleDnsRequest(l, w, r)
- })
- c.RegisterReloadCallback(func(c *config.C) {
- reloadDns(l, c)
- })
- return func() {
- startDns(l, c)
- }
- }
- func getDnsServerAddr(c *config.C) string {
- dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
- // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
- if dnsHost == "[::]" {
- dnsHost = "::"
- }
- return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
- }
- func startDns(l *logrus.Logger, c *config.C) {
- dnsAddr = getDnsServerAddr(c)
- dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
- l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
- err := dnsServer.ListenAndServe()
- defer dnsServer.Shutdown()
- if err != nil {
- l.Errorf("Failed to start server: %s\n ", err.Error())
- }
- }
- func reloadDns(l *logrus.Logger, c *config.C) {
- if dnsAddr == getDnsServerAddr(c) {
- l.Debug("No DNS server config change detected")
- return
- }
- l.Debug("Restarting DNS server")
- dnsServer.Shutdown()
- go startDns(l, c)
- }
|