|
@@ -1,6 +1,7 @@
|
|
|
package server
|
|
|
|
|
|
import (
|
|
|
+ "encoding/hex"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"net"
|
|
@@ -10,6 +11,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/abh/geodns/applog"
|
|
|
+ "github.com/abh/geodns/edns"
|
|
|
"github.com/abh/geodns/querylog"
|
|
|
"github.com/abh/geodns/zones"
|
|
|
|
|
@@ -66,36 +68,28 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
|
|
|
|
|
|
z.Metrics.ClientStats.Add(realIP.String())
|
|
|
|
|
|
- var ip net.IP // EDNS or real IP
|
|
|
- var edns *dns.EDNS0_SUBNET
|
|
|
- var opt_rr *dns.OPT
|
|
|
-
|
|
|
- for _, extra := range req.Extra {
|
|
|
-
|
|
|
- switch extra.(type) {
|
|
|
- case *dns.OPT:
|
|
|
- for _, o := range extra.(*dns.OPT).Option {
|
|
|
- opt_rr = extra.(*dns.OPT)
|
|
|
- switch e := o.(type) {
|
|
|
- case *dns.EDNS0_NSID:
|
|
|
- // do stuff with e.Nsid
|
|
|
- case *dns.EDNS0_SUBNET:
|
|
|
- applog.Println("Got edns", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
|
|
|
- if e.Address != nil {
|
|
|
- edns = e
|
|
|
- ip = e.Address
|
|
|
-
|
|
|
- if qle != nil {
|
|
|
- qle.HasECS = true
|
|
|
- qle.ClientAddr = fmt.Sprintf("%s/%d", ip, e.SourceNetmask)
|
|
|
- }
|
|
|
+ var ip net.IP // EDNS CLIENT SUBNET or real IP
|
|
|
+ var ecs *dns.EDNS0_SUBNET
|
|
|
+
|
|
|
+ if option := req.IsEdns0(); option != nil {
|
|
|
+ for _, s := range option.Option {
|
|
|
+ switch e := s.(type) {
|
|
|
+ case *dns.EDNS0_SUBNET:
|
|
|
+ applog.Println("Got edns-client-subnet", e.Address, e.Family, e.SourceNetmask, e.SourceScope)
|
|
|
+ if e.Address != nil {
|
|
|
+ ecs = e
|
|
|
+ ip = e.Address
|
|
|
+
|
|
|
+ if qle != nil {
|
|
|
+ qle.HasECS = true
|
|
|
+ qle.ClientAddr = fmt.Sprintf("%s/%d", ip, e.SourceNetmask)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if len(ip) == 0 { // no edns subnet
|
|
|
+ if len(ip) == 0 { // no edns client subnet
|
|
|
ip = realIP
|
|
|
if qle != nil {
|
|
|
qle.ClientAddr = fmt.Sprintf("%s/%d", ip, len(ip)*8)
|
|
@@ -104,8 +98,9 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
|
|
|
|
|
|
targets, netmask, location := z.Options.Targeting.GetTargets(ip, z.HasClosest)
|
|
|
|
|
|
- m := new(dns.Msg)
|
|
|
+ m := &dns.Msg{}
|
|
|
|
|
|
+ // setup logging of answers and rcode
|
|
|
if qle != nil {
|
|
|
qle.Targets = targets
|
|
|
defer func() {
|
|
@@ -114,23 +109,40 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
|
|
|
}()
|
|
|
}
|
|
|
|
|
|
- m.SetReply(req)
|
|
|
- if e := m.IsEdns0(); e != nil {
|
|
|
- m.SetEdns0(4096, e.Do())
|
|
|
+ mv, err := edns.Version(req)
|
|
|
+ if err != nil {
|
|
|
+ m = mv
|
|
|
+ err := w.WriteMsg(m)
|
|
|
+ if err != nil {
|
|
|
+ applog.Printf("could not write response: %s", err)
|
|
|
+ }
|
|
|
+ return
|
|
|
}
|
|
|
- m.Authoritative = true
|
|
|
|
|
|
- // TODO: set scope to 0 if there are no alternate responses
|
|
|
- if edns != nil {
|
|
|
- if edns.Family != 0 {
|
|
|
- if netmask < 16 {
|
|
|
- netmask = 16
|
|
|
+ m.SetReply(req)
|
|
|
+
|
|
|
+ if option := edns.SetSizeAndDo(req, m); option != nil {
|
|
|
+
|
|
|
+ for _, s := range option.Option {
|
|
|
+ switch e := s.(type) {
|
|
|
+ case *dns.EDNS0_NSID:
|
|
|
+ e.Code = dns.EDNS0NSID
|
|
|
+ e.Nsid = hex.EncodeToString([]byte(srv.info.ID))
|
|
|
+ case *dns.EDNS0_SUBNET:
|
|
|
+ // access e.Family, e.Address, etc.
|
|
|
+ // TODO: set scope to 0 if there are no alternate responses
|
|
|
+ if ecs.Family != 0 {
|
|
|
+ if netmask < 16 {
|
|
|
+ netmask = 16
|
|
|
+ }
|
|
|
+ e.SourceScope = uint8(netmask)
|
|
|
+ }
|
|
|
}
|
|
|
- edns.SourceScope = uint8(netmask)
|
|
|
- m.Extra = append(m.Extra, opt_rr)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ m.Authoritative = true
|
|
|
+
|
|
|
labelMatches := z.FindLabels(qlabel, targets, []uint16{dns.TypeMF, dns.TypeCNAME, qtype})
|
|
|
|
|
|
if len(labelMatches) == 0 {
|
|
@@ -267,7 +279,7 @@ func (srv *Server) serve(w dns.ResponseWriter, req *dns.Msg, z *zones.Zone) {
|
|
|
// should this be in the match loop above?
|
|
|
qle.Rcode = m.Rcode
|
|
|
}
|
|
|
- err := w.WriteMsg(m)
|
|
|
+ err = w.WriteMsg(m)
|
|
|
if err != nil {
|
|
|
// if Pack'ing fails the Write fails. Return SERVFAIL.
|
|
|
applog.Printf("Error writing packet: %q, %s", err, m)
|