瀏覽代碼

Calculated weight for 'weighted' labels, sort by server weight

Ask Bjørn Hansen 13 年之前
父節點
當前提交
eb7a56f99a
共有 3 個文件被更改,包括 60 次插入18 次删除
  1. 26 10
      geodns.go
  2. 33 0
      picker.go
  3. 1 8
      serve.go

+ 26 - 10
geodns.go

@@ -8,6 +8,7 @@ import (
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"path"
 	"path"
+	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 )
 )
@@ -22,11 +23,21 @@ type Record struct {
 	Weight int
 	Weight int
 }
 }
 
 
+type Records []Record
+
+func (s Records) Len() int      { return len(s) }
+func (s Records) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type RecordsByWeight struct{ Records }
+
+func (s RecordsByWeight) Less(i, j int) bool { return s.Records[i].Weight > s.Records[j].Weight }
+
 type Label struct {
 type Label struct {
 	Label    string
 	Label    string
 	MaxHosts int
 	MaxHosts int
 	Ttl      int
 	Ttl      int
-	Records  map[uint16][]Record
+	Records  map[uint16]Records
+	Weight   map[uint16]int
 }
 }
 
 
 type labels map[string]*Label
 type labels map[string]*Label
@@ -181,7 +192,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 			fmt.Printf("rdata %s TYPE-R %T\n", rdata, rdata)
 			fmt.Printf("rdata %s TYPE-R %T\n", rdata, rdata)
 
 
-			Records := make(map[string][]interface{})
+			records := make(map[string][]interface{})
 
 
 			switch rdata.(type) {
 			switch rdata.(type) {
 			case map[string]interface{}:
 			case map[string]interface{}:
@@ -193,22 +204,23 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 					}
 					}
 					tmp = append(tmp, []string{rdata_k, rdata_v.(string)})
 					tmp = append(tmp, []string{rdata_k, rdata_v.(string)})
 				}
 				}
-				Records[rType] = tmp
+				records[rType] = tmp
 			default:
 			default:
-				Records[rType] = rdata.([]interface{})
+				records[rType] = rdata.([]interface{})
 			}
 			}
 
 
 			//fmt.Printf("RECORDS %s TYPE-REC %T\n", Records, Records)
 			//fmt.Printf("RECORDS %s TYPE-REC %T\n", Records, Records)
 
 
 			if label.Records == nil {
 			if label.Records == nil {
-				label.Records = make(map[uint16][]Record)
+				label.Records = make(map[uint16]Records)
+				label.Weight = make(map[uint16]int)
 			}
 			}
 
 
-			label.Records[dnsType] = make([]Record, len(Records[rType]))
+			label.Records[dnsType] = make(Records, len(records[rType]))
 
 
-			for i := 0; i < len(Records[rType]); i++ {
+			for i := 0; i < len(records[rType]); i++ {
 
 
-				fmt.Printf("RT %T %#v\n", Records[rType][i], Records[rType][i])
+				fmt.Printf("RT %T %#v\n", records[rType][i], records[rType][i])
 
 
 				record := new(Record)
 				record := new(Record)
 
 
@@ -220,7 +232,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 				switch dnsType {
 				switch dnsType {
 				case dns.TypeA, dns.TypeAAAA:
 				case dns.TypeA, dns.TypeAAAA:
-					rec := Records[rType][i].([]interface{})
+					rec := records[rType][i].([]interface{})
 					ip := rec[0].(string)
 					ip := rec[0].(string)
 					var err error
 					var err error
 					switch rec[1].(type) {
 					switch rec[1].(type) {
@@ -229,6 +241,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 						if err != nil {
 						if err != nil {
 							panic("Error converting weight to integer")
 							panic("Error converting weight to integer")
 						}
 						}
+						label.Weight[dnsType] += record.Weight
 					case float64:
 					case float64:
 						record.Weight = int(rec[1].(float64))
 						record.Weight = int(rec[1].(float64))
 					}
 					}
@@ -249,7 +262,7 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 						record.RR = rr
 						record.RR = rr
 					}
 					}
 				case dns.TypeNS:
 				case dns.TypeNS:
-					rec := Records[rType][i]
+					rec := records[rType][i]
 					rr := &dns.RR_NS{Hdr: h}
 					rr := &dns.RR_NS{Hdr: h}
 
 
 					switch rec.(type) {
 					switch rec.(type) {
@@ -283,6 +296,9 @@ func setupZoneData(data map[string]interface{}, Zone *Zone) {
 
 
 				label.Records[dnsType][i] = *record
 				label.Records[dnsType][i] = *record
 			}
 			}
+			if label.Weight > 0 {
+				sort.Sort(RecordsByWeight{label.Records[dnsType]})
+			}
 		}
 		}
 	}
 	}
 
 

+ 33 - 0
picker.go

@@ -0,0 +1,33 @@
+package main
+
+import (
+	"fmt"
+)
+
+func (label *Label) Picker(dnsType uint16, max int) Records {
+
+	if label_rr := label.Records[dnsType]; label_rr != nil {
+
+		//fmt.Printf("REGION_RR %T %v\n", label_rr, label_rr)
+
+		// not "balanced", just return all
+		if label.Weight[dnsType] == 0 {
+			return label_rr
+		}
+
+		rr_count := len(label_rr)
+		if max > rr_count {
+			max = rr_count
+		}
+
+		fmt.Println("Total weight", label.Weight[dnsType])
+
+		// TODO(ask) Pick random servers based on weight, not just the first 'max' entries
+		servers := label_rr[0:max]
+
+		fmt.Println("SERVERS", servers)
+
+		return servers
+	}
+	return nil
+}

+ 1 - 8
serve.go

@@ -59,14 +59,7 @@ func serve(w dns.ResponseWriter, req *dns.Msg, z *Zone) {
 
 
 	fmt.Println("Has the label, looking for records")
 	fmt.Println("Has the label, looking for records")
 
 
-	if region_rr := labels.Records[qtype]; region_rr != nil {
-		//fmt.Printf("REGION_RR %T %v\n", region_rr, region_rr)
-		max := len(region_rr)
-		if max > 4 {
-			max = 4
-		}
-		// TODO(ask) Pick random servers based on weight, not just the first 'max' entries
-		servers := region_rr[0:max]
+	if servers := labels.Picker(qtype, 4); servers != nil {
 		var rrs []dns.RR
 		var rrs []dns.RR
 		for _, record := range servers {
 		for _, record := range servers {
 			rr := record.RR
 			rr := record.RR