|
@@ -10,6 +10,9 @@ import (
|
|
"time"
|
|
"time"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+// Maximum number of TCP queries before we close the socket.
|
|
|
|
+const maxTCPQueries = 128
|
|
|
|
+
|
|
// Handler is implemented by any value that implements ServeDNS.
|
|
// Handler is implemented by any value that implements ServeDNS.
|
|
type Handler interface {
|
|
type Handler interface {
|
|
ServeDNS(w ResponseWriter, r *Msg)
|
|
ServeDNS(w ResponseWriter, r *Msg)
|
|
@@ -47,6 +50,7 @@ type response struct {
|
|
tcp *net.TCPConn // i/o connection if TCP was used
|
|
tcp *net.TCPConn // i/o connection if TCP was used
|
|
udpSession *SessionUDP // oob data to get egress interface right
|
|
udpSession *SessionUDP // oob data to get egress interface right
|
|
remoteAddr net.Addr // address of the client
|
|
remoteAddr net.Addr // address of the client
|
|
|
|
+ writer Writer // writer to output the raw DNS bits
|
|
}
|
|
}
|
|
|
|
|
|
// ServeMux is an DNS request multiplexer. It matches the
|
|
// ServeMux is an DNS request multiplexer. It matches the
|
|
@@ -158,9 +162,9 @@ func (mux *ServeMux) HandleRemove(pattern string) {
|
|
if pattern == "" {
|
|
if pattern == "" {
|
|
panic("dns: invalid pattern " + pattern)
|
|
panic("dns: invalid pattern " + pattern)
|
|
}
|
|
}
|
|
- // don't need a mutex here, because deleting is OK, even if the
|
|
|
|
- // entry is note there.
|
|
|
|
|
|
+ mux.m.Lock()
|
|
delete(mux.z, Fqdn(pattern))
|
|
delete(mux.z, Fqdn(pattern))
|
|
|
|
+ mux.m.Unlock()
|
|
}
|
|
}
|
|
|
|
|
|
// ServeDNS dispatches the request to the handler whose
|
|
// ServeDNS dispatches the request to the handler whose
|
|
@@ -197,6 +201,43 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
|
DefaultServeMux.HandleFunc(pattern, handler)
|
|
DefaultServeMux.HandleFunc(pattern, handler)
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// Writer writes raw DNS messages; each call to Write should send an entire message.
|
|
|
|
+type Writer interface {
|
|
|
|
+ io.Writer
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
|
|
|
|
+type Reader interface {
|
|
|
|
+ // ReadTCP reads a raw message from a TCP connection. Implementations may alter
|
|
|
|
+ // connection properties, for example the read-deadline.
|
|
|
|
+ ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error)
|
|
|
|
+ // ReadUDP reads a raw message from a UDP connection. Implementations may alter
|
|
|
|
+ // connection properties, for example the read-deadline.
|
|
|
|
+ ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// defaultReader is an adapter for the Server struct that implements the Reader interface
|
|
|
|
+// using the readTCP and readUDP func of the embedded Server.
|
|
|
|
+type defaultReader struct {
|
|
|
|
+ *Server
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
|
|
|
|
+ return dr.readTCP(conn, timeout)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
|
|
|
+ return dr.readUDP(conn, timeout)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
|
|
|
|
+// Implementations should never return a nil Reader.
|
|
|
|
+type DecorateReader func(Reader) Reader
|
|
|
|
+
|
|
|
|
+// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
|
|
|
|
+// Implementations should never return a nil Writer.
|
|
|
|
+type DecorateWriter func(Writer) Writer
|
|
|
|
+
|
|
// A Server defines parameters for running an DNS server.
|
|
// A Server defines parameters for running an DNS server.
|
|
type Server struct {
|
|
type Server struct {
|
|
// Address to listen on, ":dns" if empty.
|
|
// Address to listen on, ":dns" if empty.
|
|
@@ -223,8 +264,12 @@ type Server struct {
|
|
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
|
|
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
|
|
// the handler. It will specfically not check if the query has the QR bit not set.
|
|
// the handler. It will specfically not check if the query has the QR bit not set.
|
|
Unsafe bool
|
|
Unsafe bool
|
|
- // If NotifyStartedFunc is set is is called, once the server has started listening.
|
|
|
|
|
|
+ // If NotifyStartedFunc is set it is called once the server has started listening.
|
|
NotifyStartedFunc func()
|
|
NotifyStartedFunc func()
|
|
|
|
+ // DecorateReader is optional, allows customization of the process that reads raw DNS messages.
|
|
|
|
+ DecorateReader DecorateReader
|
|
|
|
+ // DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
|
|
|
|
+ DecorateWriter DecorateWriter
|
|
|
|
|
|
// For graceful shutdown.
|
|
// For graceful shutdown.
|
|
stopUDP chan bool
|
|
stopUDP chan bool
|
|
@@ -246,7 +291,6 @@ func (srv *Server) ListenAndServe() error {
|
|
}
|
|
}
|
|
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
|
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
|
srv.started = true
|
|
srv.started = true
|
|
- srv.lock.Unlock()
|
|
|
|
addr := srv.Addr
|
|
addr := srv.Addr
|
|
if addr == "" {
|
|
if addr == "" {
|
|
addr = ":domain"
|
|
addr = ":domain"
|
|
@@ -265,6 +309,7 @@ func (srv *Server) ListenAndServe() error {
|
|
return e
|
|
return e
|
|
}
|
|
}
|
|
srv.Listener = l
|
|
srv.Listener = l
|
|
|
|
+ srv.lock.Unlock()
|
|
return srv.serveTCP(l)
|
|
return srv.serveTCP(l)
|
|
case "udp", "udp4", "udp6":
|
|
case "udp", "udp4", "udp6":
|
|
a, e := net.ResolveUDPAddr(srv.Net, addr)
|
|
a, e := net.ResolveUDPAddr(srv.Net, addr)
|
|
@@ -279,8 +324,10 @@ func (srv *Server) ListenAndServe() error {
|
|
return e
|
|
return e
|
|
}
|
|
}
|
|
srv.PacketConn = l
|
|
srv.PacketConn = l
|
|
|
|
+ srv.lock.Unlock()
|
|
return srv.serveUDP(l)
|
|
return srv.serveUDP(l)
|
|
}
|
|
}
|
|
|
|
+ srv.lock.Unlock()
|
|
return &Error{err: "bad network"}
|
|
return &Error{err: "bad network"}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -294,20 +341,22 @@ func (srv *Server) ActivateAndServe() error {
|
|
}
|
|
}
|
|
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
|
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
|
srv.started = true
|
|
srv.started = true
|
|
|
|
+ pConn := srv.PacketConn
|
|
|
|
+ l := srv.Listener
|
|
srv.lock.Unlock()
|
|
srv.lock.Unlock()
|
|
- if srv.PacketConn != nil {
|
|
|
|
|
|
+ if pConn != nil {
|
|
if srv.UDPSize == 0 {
|
|
if srv.UDPSize == 0 {
|
|
srv.UDPSize = MinMsgSize
|
|
srv.UDPSize = MinMsgSize
|
|
}
|
|
}
|
|
- if t, ok := srv.PacketConn.(*net.UDPConn); ok {
|
|
|
|
|
|
+ if t, ok := pConn.(*net.UDPConn); ok {
|
|
if e := setUDPSocketOptions(t); e != nil {
|
|
if e := setUDPSocketOptions(t); e != nil {
|
|
return e
|
|
return e
|
|
}
|
|
}
|
|
return srv.serveUDP(t)
|
|
return srv.serveUDP(t)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- if srv.Listener != nil {
|
|
|
|
- if t, ok := srv.Listener.(*net.TCPListener); ok {
|
|
|
|
|
|
+ if l != nil {
|
|
|
|
+ if t, ok := l.(*net.TCPListener); ok {
|
|
return srv.serveTCP(t)
|
|
return srv.serveTCP(t)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -316,7 +365,7 @@ func (srv *Server) ActivateAndServe() error {
|
|
|
|
|
|
// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
|
|
// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
|
|
// ActivateAndServe will return. All in progress queries are completed before the server
|
|
// ActivateAndServe will return. All in progress queries are completed before the server
|
|
-// is taken down. If the Shutdown is taking longer than the reading timeout and error
|
|
|
|
|
|
+// is taken down. If the Shutdown is taking longer than the reading timeout an error
|
|
// is returned.
|
|
// is returned.
|
|
func (srv *Server) Shutdown() error {
|
|
func (srv *Server) Shutdown() error {
|
|
srv.lock.Lock()
|
|
srv.lock.Lock()
|
|
@@ -325,7 +374,6 @@ func (srv *Server) Shutdown() error {
|
|
return &Error{err: "server not started"}
|
|
return &Error{err: "server not started"}
|
|
}
|
|
}
|
|
srv.started = false
|
|
srv.started = false
|
|
- srv.lock.Unlock()
|
|
|
|
net, addr := srv.Net, srv.Addr
|
|
net, addr := srv.Net, srv.Addr
|
|
switch {
|
|
switch {
|
|
case srv.Listener != nil:
|
|
case srv.Listener != nil:
|
|
@@ -335,6 +383,7 @@ func (srv *Server) Shutdown() error {
|
|
a := srv.PacketConn.LocalAddr()
|
|
a := srv.PacketConn.LocalAddr()
|
|
net, addr = a.Network(), a.String()
|
|
net, addr = a.Network(), a.String()
|
|
}
|
|
}
|
|
|
|
+ srv.lock.Unlock()
|
|
|
|
|
|
fin := make(chan bool)
|
|
fin := make(chan bool)
|
|
switch net {
|
|
switch net {
|
|
@@ -382,6 +431,11 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|
srv.NotifyStartedFunc()
|
|
srv.NotifyStartedFunc()
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ reader := Reader(&defaultReader{srv})
|
|
|
|
+ if srv.DecorateReader != nil {
|
|
|
|
+ reader = srv.DecorateReader(reader)
|
|
|
|
+ }
|
|
|
|
+
|
|
handler := srv.Handler
|
|
handler := srv.Handler
|
|
if handler == nil {
|
|
if handler == nil {
|
|
handler = DefaultServeMux
|
|
handler = DefaultServeMux
|
|
@@ -393,7 +447,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|
if e != nil {
|
|
if e != nil {
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
- m, e := srv.readTCP(rw, rtimeout)
|
|
|
|
|
|
+ m, e := reader.ReadTCP(rw, rtimeout)
|
|
select {
|
|
select {
|
|
case <-srv.stopTCP:
|
|
case <-srv.stopTCP:
|
|
return nil
|
|
return nil
|
|
@@ -405,7 +459,6 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|
srv.wgTCP.Add(1)
|
|
srv.wgTCP.Add(1)
|
|
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
|
|
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
|
|
}
|
|
}
|
|
- panic("dns: not reached")
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// serveUDP starts a UDP listener for the server.
|
|
// serveUDP starts a UDP listener for the server.
|
|
@@ -417,6 +470,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|
srv.NotifyStartedFunc()
|
|
srv.NotifyStartedFunc()
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ reader := Reader(&defaultReader{srv})
|
|
|
|
+ if srv.DecorateReader != nil {
|
|
|
|
+ reader = srv.DecorateReader(reader)
|
|
|
|
+ }
|
|
|
|
+
|
|
handler := srv.Handler
|
|
handler := srv.Handler
|
|
if handler == nil {
|
|
if handler == nil {
|
|
handler = DefaultServeMux
|
|
handler = DefaultServeMux
|
|
@@ -424,7 +482,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|
rtimeout := srv.getReadTimeout()
|
|
rtimeout := srv.getReadTimeout()
|
|
// deadline is not used here
|
|
// deadline is not used here
|
|
for {
|
|
for {
|
|
- m, s, e := srv.readUDP(l, rtimeout)
|
|
|
|
|
|
+ m, s, e := reader.ReadUDP(l, rtimeout)
|
|
select {
|
|
select {
|
|
case <-srv.stopUDP:
|
|
case <-srv.stopUDP:
|
|
return nil
|
|
return nil
|
|
@@ -436,13 +494,19 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|
srv.wgUDP.Add(1)
|
|
srv.wgUDP.Add(1)
|
|
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
|
|
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
|
|
}
|
|
}
|
|
- panic("dns: not reached")
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// Serve a new connection.
|
|
// Serve a new connection.
|
|
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
|
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
|
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
|
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
|
- q := 0
|
|
|
|
|
|
+ if srv.DecorateWriter != nil {
|
|
|
|
+ w.writer = srv.DecorateWriter(w)
|
|
|
|
+ } else {
|
|
|
|
+ w.writer = w
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ q := 0 // counter for the amount of TCP queries we get
|
|
|
|
+
|
|
defer func() {
|
|
defer func() {
|
|
if u != nil {
|
|
if u != nil {
|
|
srv.wgUDP.Done()
|
|
srv.wgUDP.Done()
|
|
@@ -451,6 +515,11 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
|
|
srv.wgTCP.Done()
|
|
srv.wgTCP.Done()
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
|
|
+
|
|
|
|
+ reader := Reader(&defaultReader{srv})
|
|
|
|
+ if srv.DecorateReader != nil {
|
|
|
|
+ reader = srv.DecorateReader(reader)
|
|
|
|
+ }
|
|
Redo:
|
|
Redo:
|
|
req := new(Msg)
|
|
req := new(Msg)
|
|
err := req.Unpack(m)
|
|
err := req.Unpack(m)
|
|
@@ -479,6 +548,12 @@ Redo:
|
|
h.ServeDNS(w, req) // Writes back to the client
|
|
h.ServeDNS(w, req) // Writes back to the client
|
|
|
|
|
|
Exit:
|
|
Exit:
|
|
|
|
+ // TODO(miek): make this number configurable?
|
|
|
|
+ if q > maxTCPQueries { // close socket after this many queries
|
|
|
|
+ w.Close()
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
if w.hijacked {
|
|
if w.hijacked {
|
|
return // client calls Close()
|
|
return // client calls Close()
|
|
}
|
|
}
|
|
@@ -490,14 +565,9 @@ Exit:
|
|
if srv.IdleTimeout != nil {
|
|
if srv.IdleTimeout != nil {
|
|
idleTimeout = srv.IdleTimeout()
|
|
idleTimeout = srv.IdleTimeout()
|
|
}
|
|
}
|
|
- m, e := srv.readTCP(w.tcp, idleTimeout)
|
|
|
|
|
|
+ m, e := reader.ReadTCP(w.tcp, idleTimeout)
|
|
if e == nil {
|
|
if e == nil {
|
|
q++
|
|
q++
|
|
- // TODO(miek): make this number configurable?
|
|
|
|
- if q > 128 { // close socket after this many queries
|
|
|
|
- w.Close()
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
goto Redo
|
|
goto Redo
|
|
}
|
|
}
|
|
w.Close()
|
|
w.Close()
|
|
@@ -562,7 +632,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
- _, err = w.Write(data)
|
|
|
|
|
|
+ _, err = w.writer.Write(data)
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -570,7 +640,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
- _, err = w.Write(data)
|
|
|
|
|
|
+ _, err = w.writer.Write(data)
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
|