zones_closest_test.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package zones
  2. import (
  3. "net"
  4. "reflect"
  5. "sort"
  6. "testing"
  7. "github.com/miekg/dns"
  8. )
  9. func TestClosest(t *testing.T) {
  10. muxm := loadZones(t)
  11. t.Log("test closests")
  12. tests := []struct {
  13. Label string
  14. ClientIP string
  15. ExpectedA []string
  16. QType uint16
  17. MaxHosts int
  18. }{
  19. {"closest", "212.237.144.84", []string{"194.106.223.155"}, dns.TypeA, 1},
  20. {"closest", "208.113.157.108", []string{"207.171.7.49", "207.171.7.59"}, dns.TypeA, 2},
  21. {"closest", "2620:0:872::1", []string{"2607:f238:3::1:45"}, dns.TypeAAAA, 1},
  22. // {"closest", "208.113.157.108", []string{"207.171.7.59"}, 1},
  23. }
  24. for _, x := range tests {
  25. ip := net.ParseIP(x.ClientIP)
  26. if ip == nil {
  27. t.Fatalf("Invalid ClientIP: %s", x.ClientIP)
  28. }
  29. tz := muxm.zonelist["test.example.com"]
  30. targets, netmask, location := tz.Options.Targeting.GetTargets(ip, true)
  31. t.Logf("targets: %q, netmask: %d, location: %+v", targets, netmask, location)
  32. // This is a weird API, but it's what serve() uses now. Fixing it
  33. // isn't super straight forward. Moving some of the exceptions from serve()
  34. // into configuration and making the "find the best answer" code have
  35. // a better API should be possible though. Some day.
  36. labelMatches := tz.FindLabels(
  37. x.Label,
  38. targets,
  39. []uint16{dns.TypeMF, dns.TypeCNAME, x.QType},
  40. )
  41. if len(labelMatches) == 0 {
  42. t.Fatalf("no labelmatches")
  43. }
  44. for _, match := range labelMatches {
  45. label := match.Label
  46. labelQtype := match.Type
  47. records := tz.Picker(label, labelQtype, x.MaxHosts, location)
  48. if records == nil {
  49. t.Fatalf("didn't get closest records")
  50. }
  51. if len(x.ExpectedA) == 0 {
  52. if len(records) > 0 {
  53. t.Logf("Expected 0 records but got %d", len(records))
  54. t.Fail()
  55. }
  56. }
  57. if len(x.ExpectedA) != len(records) {
  58. t.Logf("Expected %d records, got %d", len(x.ExpectedA), len(records))
  59. t.Fail()
  60. }
  61. ips := []string{}
  62. for _, r := range records {
  63. switch rr := r.RR.(type) {
  64. case *dns.A:
  65. ips = append(ips, rr.A.String())
  66. case *dns.AAAA:
  67. ips = append(ips, rr.AAAA.String())
  68. default:
  69. t.Fatalf("unexpected RR type: %s", rr.Header().String())
  70. }
  71. }
  72. sort.Strings(ips)
  73. sort.Strings(x.ExpectedA)
  74. if !reflect.DeepEqual(ips, x.ExpectedA) {
  75. t.Logf("Got '%+v', expected '%+v'", ips, x.ExpectedA)
  76. t.Fail()
  77. }
  78. }
  79. }
  80. }