Browse Source

Move CNAME and alias logic into findLabels function

This makes alias'es and cname's go through the proper targeting logic
the same as other record types.

Closes #25
Ask Bjørn Hansen 12 years ago
parent
commit
3fe4f6c1a0
5 changed files with 63 additions and 61 deletions
  1. 3 0
      dns/test.example.com.json
  2. 5 17
      serve.go
  3. 13 16
      serve_test.go
  4. 30 24
      types.go
  5. 12 4
      zone_test.go

+ 3 - 0
dns/test.example.com.json

@@ -38,6 +38,9 @@
     "www": {
       "cname": "geo.bitnames.com."
     },
+    "www.europe": {
+      "cname": "geo-europe.bitnames.com."
+    },
     "www-cname": {
       "cname": "bar"
     },

+ 5 - 17
serve.go

@@ -81,14 +81,11 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		}
 	}
 
-	// TODO(ask) Fix the findLabels API to make this work better
-	if alias := z.findLabels(label, "", dns.TypeMF); alias != nil &&
-		alias.Records[dns.TypeMF] != nil {
-		// We found an alias record, so pretend the question was for that name instead
-		label = alias.firstRR(dns.TypeMF).(*dns.MF).Mf
+	labels, labelQtype := z.findLabels(label, country, qTypes{dns.TypeMF, dns.TypeCNAME, qtype})
+	if labelQtype == 0 {
+		labelQtype = qtype
 	}
 
-	labels := z.findLabels(label, country, qtype)
 	if labels == nil {
 
 		if label == "_status" && (qtype == dns.TypeANY || qtype == dns.TypeTXT) {
@@ -126,7 +123,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 		return
 	}
 
-	if servers := labels.Picker(qtype, labels.MaxHosts); servers != nil {
+	if servers := labels.Picker(labelQtype, labels.MaxHosts); servers != nil {
 		var rrs []dns.RR
 		for _, record := range servers {
 			rr := record.RR
@@ -137,16 +134,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 	}
 
 	if len(m.Answer) == 0 {
-		if labels := z.Labels[label]; labels != nil {
-			if _, ok := labels.Records[dns.TypeCNAME]; ok {
-				cname := labels.firstRR(dns.TypeCNAME)
-				m.Answer = append(m.Answer, cname)
-			} else {
-				m.Ns = append(m.Ns, z.SoaRR())
-			}
-		} else {
-			m.Ns = append(m.Ns, z.SoaRR())
-		}
+		m.Ns = append(m.Ns, z.SoaRR())
 	}
 
 	logPrintln(m)

+ 13 - 16
serve_test.go

@@ -74,14 +74,12 @@ func (s *ServeSuite) TestServingAliases(c *C) {
 	r = exchange(c, "www-alias.test.example.com.", dns.TypeA)
 	c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo.bitnames.com.")
 
-	/*
-		// Alias returning a cname, with geo overrides
-		r = exchangeSubnet(c, "www-alias.test.example.com.", dns.TypeA, "194.239.134.1")
-		c.Check(r.Answer, HasLen, 1)
-		if len(r.Answer) > 0 {
-			c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
-		}
-	*/
+	// Alias returning a cname, with geo overrides
+	r = exchangeSubnet(c, "www-alias.test.example.com.", dns.TypeA, "194.239.134.1")
+	c.Check(r.Answer, HasLen, 1)
+	if len(r.Answer) > 0 {
+		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
+	}
 }
 
 func (s *ServeSuite) TestServingEDNS(c *C) {
@@ -92,15 +90,14 @@ func (s *ServeSuite) TestServingEDNS(c *C) {
 		c.Check(r.Answer[0].(*dns.MX).Mx, Equals, "mx-eu.example.net.")
 	}
 
-	/*
-		c.Log("Testing www.test.example.com from .dk, should match www.europe (a cname)")
+	c.Log("Testing www.test.example.com from .dk, should match www.europe (a cname)")
+
+	r = exchangeSubnet(c, "www.test.example.com.", dns.TypeA, "194.239.134.1")
+	c.Check(r.Answer, HasLen, 1)
+	if len(r.Answer) > 0 {
+		c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
+	}
 
-		r = exchangeSubnet(c, "www.test.example.com.", dns.TypeA, "194.239.134.1")
-		c.Check(r.Answer, HasLen, 1)
-		if len(r.Answer) > 0 {
-			c.Check(r.Answer[0].(*dns.CNAME).Target, Equals, "geo-europe.bitnames.com.")
-		}
-	*/
 }
 
 func exchangeSubnet(c *C, name string, dnstype uint16, ip string) *dns.Msg {

+ 30 - 24
types.go

@@ -46,6 +46,8 @@ type Zone struct {
 	Options   Options
 }
 
+type qTypes []uint16
+
 func (l *Label) firstRR(dnsType uint16) dns.RR {
 	return l.Records[dnsType][0].RR
 }
@@ -68,48 +70,52 @@ func (z *Zone) SoaRR() dns.RR {
 	return z.Labels[""].firstRR(dns.TypeSOA)
 }
 
-func (z *Zone) findLabels(s, cc string, qtype uint16) *Label {
-
-	if qtype == dns.TypeANY {
-		// short-circuit mostly to avoid subtle bugs later
-		// to be correct we should run through all the selectors and
-		// pick types not already picked
-		return z.Labels[s]
-	}
+func (z *Zone) findLabels(s, cc string, qts qTypes) (*Label, uint16) {
 
 	selectors := []string{}
 
 	if len(cc) > 0 {
 		continent := countries.CountryContinent[cc]
+		var s_cc string
 		if len(s) > 0 {
-			cc = s + "." + cc
+			s_cc = s + "." + cc
 			if len(continent) > 0 {
 				continent = s + "." + continent
 			}
+		} else {
+			s_cc = cc
 		}
-		selectors = append(selectors, cc, continent)
+		selectors = append(selectors, s_cc, continent)
 	}
 	selectors = append(selectors, s)
 
 	for _, name := range selectors {
+
 		if label, ok := z.Labels[name]; ok {
 
-			// look for aliases
-			if label.Records[dns.TypeMF] != nil {
-				name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
-				// BUG(ask) - restructure this so it supports chains of aliases
-				label, ok = z.Labels[name]
-				if label == nil {
-					continue
+			for _, qtype := range qts {
+
+				switch qtype {
+				case dns.TypeANY:
+					// short-circuit mostly to avoid subtle bugs later
+					// to be correct we should run through all the selectors and
+					// pick types not already picked
+					return z.Labels[s], qtype
+				case dns.TypeMF:
+					if label.Records[dns.TypeMF] != nil {
+						name = label.firstRR(dns.TypeMF).(*dns.MF).Mf
+						// TODO: need to avoid loops here somehow
+						return z.findLabels(name, cc, qts)
+					}
+				default:
+					// return the label if it has the right record
+					if label.Records[qtype] != nil && len(label.Records[qtype]) > 0 {
+						return label, qtype
+					}
 				}
 			}
-
-			// return the label if it has the right records
-			// TODO(ask) Should this also look for CNAME records?
-			if label.Records[qtype] != nil && len(label.Records[qtype]) > 0 {
-				return label
-			}
 		}
 	}
-	return z.Labels[s]
+
+	return z.Labels[s], 0
 }

+ 12 - 4
zone_test.go

@@ -6,22 +6,30 @@ import (
 )
 
 func (s *ConfigSuite) TestZone(c *C) {
-	ex := s.zones["example.com"]
+
+	ex := s.zones["test.example.com"]
 	c.Check(ex.Labels["weight"].MaxHosts, Equals, 1)
 
 	// Make sure that the empty "no.bar" zone gets skipped and "bar" is used
-	label := ex.findLabels("bar", "no", dns.TypeA)
+	label, qtype := ex.findLabels("bar", "no", qTypes{dns.TypeA})
 	c.Check(label.Records[dns.TypeA], HasLen, 1)
 	c.Check(label.Records[dns.TypeA][0].RR.(*dns.A).A.String(), Equals, "192.168.1.2")
+	c.Check(qtype, Equals, dns.TypeA)
 
-	label = ex.findLabels("", "", dns.TypeMX)
+	label, qtype = ex.findLabels("", "", qTypes{dns.TypeMX})
 	Mxs := label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 2)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx.example.net.")
 	c.Check(Mxs[1].RR.(*dns.MX).Mx, Equals, "mx2.example.net.")
 
-	Mxs = ex.findLabels("", "dk", dns.TypeMX).Records[dns.TypeMX]
+	label, qtype = ex.findLabels("", "dk", qTypes{dns.TypeMX})
+	Mxs = label.Records[dns.TypeMX]
 	c.Check(Mxs, HasLen, 1)
 	c.Check(Mxs[0].RR.(*dns.MX).Mx, Equals, "mx-eu.example.net.")
+	c.Check(qtype, Equals, dns.TypeMX)
 
+	// look for multiple record types
+	label, qtype = ex.findLabels("www", "", qTypes{dns.TypeCNAME, dns.TypeA})
+	c.Check(label.Records[dns.TypeCNAME], HasLen, 1)
+	c.Check(qtype, Equals, dns.TypeCNAME)
 }