|
@@ -1,7 +1,6 @@
|
|
package guerrilla
|
|
package guerrilla
|
|
|
|
|
|
import (
|
|
import (
|
|
- "bufio"
|
|
|
|
"crypto/rand"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"fmt"
|
|
@@ -15,6 +14,8 @@ import (
|
|
"runtime"
|
|
"runtime"
|
|
|
|
|
|
log "github.com/Sirupsen/logrus"
|
|
log "github.com/Sirupsen/logrus"
|
|
|
|
+
|
|
|
|
+ "sync"
|
|
)
|
|
)
|
|
|
|
|
|
const (
|
|
const (
|
|
@@ -26,29 +27,32 @@ const (
|
|
|
|
|
|
// Server listens for SMTP clients on the port specified in its config
|
|
// Server listens for SMTP clients on the port specified in its config
|
|
type server struct {
|
|
type server struct {
|
|
- config *ServerConfig
|
|
|
|
- backend Backend
|
|
|
|
- tlsConfig *tls.Config
|
|
|
|
- maxSize int64
|
|
|
|
- timeout time.Duration
|
|
|
|
- sem chan int
|
|
|
|
|
|
+ config *ServerConfig
|
|
|
|
+ backend Backend
|
|
|
|
+ tlsConfig *tls.Config
|
|
|
|
+ maxSize int64
|
|
|
|
+ timeout time.Duration
|
|
|
|
+ clientPool *Pool
|
|
|
|
+ wg sync.WaitGroup // for waiting to shutdown
|
|
|
|
+ listener net.Listener
|
|
|
|
+ closedListener chan (bool)
|
|
}
|
|
}
|
|
|
|
|
|
// Creates and returns a new ready-to-run Server from a configuration
|
|
// Creates and returns a new ready-to-run Server from a configuration
|
|
-func newServer(sc *ServerConfig, b Backend) (*server, error) {
|
|
|
|
|
|
+func newServer(sc ServerConfig, b *Backend) (*server, error) {
|
|
server := &server{
|
|
server := &server{
|
|
- config: sc,
|
|
|
|
- backend: b,
|
|
|
|
- maxSize: sc.MaxSize,
|
|
|
|
- sem: make(chan int, sc.MaxClients),
|
|
|
|
|
|
+ config: &sc,
|
|
|
|
+ backend: *b,
|
|
|
|
+ maxSize: sc.MaxSize,
|
|
|
|
+ timeout: time.Duration(sc.Timeout),
|
|
|
|
+ clientPool: NewPool(sc.MaxClients),
|
|
|
|
+ closedListener: make(chan (bool), 1),
|
|
}
|
|
}
|
|
-
|
|
|
|
if server.config.TLSAlwaysOn || server.config.StartTLSOn {
|
|
if server.config.TLSAlwaysOn || server.config.StartTLSOn {
|
|
cert, err := tls.LoadX509KeyPair(server.config.PublicKeyFile, server.config.PrivateKeyFile)
|
|
cert, err := tls.LoadX509KeyPair(server.config.PublicKeyFile, server.config.PrivateKeyFile)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error loading TLS certificate: %s", err.Error())
|
|
return nil, fmt.Errorf("Error loading TLS certificate: %s", err.Error())
|
|
}
|
|
}
|
|
-
|
|
|
|
server.tlsConfig = &tls.Config{
|
|
server.tlsConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
@@ -56,44 +60,74 @@ func newServer(sc *ServerConfig, b Backend) (*server, error) {
|
|
Rand: rand.Reader,
|
|
Rand: rand.Reader,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
- server.timeout = time.Duration(server.config.Timeout) * time.Second
|
|
|
|
-
|
|
|
|
return server, nil
|
|
return server, nil
|
|
}
|
|
}
|
|
|
|
|
|
// Begin accepting SMTP clients
|
|
// Begin accepting SMTP clients
|
|
-func (server *server) Start() error {
|
|
|
|
|
|
+func (server *server) Start(startWG *sync.WaitGroup) error {
|
|
|
|
+ var clientID uint64
|
|
|
|
+ clientID = 0
|
|
|
|
+
|
|
listener, err := net.Listen("tcp", server.config.ListenInterface)
|
|
listener, err := net.Listen("tcp", server.config.ListenInterface)
|
|
|
|
+ server.listener = listener
|
|
if err != nil {
|
|
if err != nil {
|
|
- return fmt.Errorf("Cannot listen on port: %s", err.Error())
|
|
|
|
|
|
+ return fmt.Errorf("[%s] Cannot listen on port: %s ", server.config.ListenInterface, err.Error())
|
|
}
|
|
}
|
|
|
|
|
|
log.Infof("Listening on TCP %s", server.config.ListenInterface)
|
|
log.Infof("Listening on TCP %s", server.config.ListenInterface)
|
|
- var clientID int64
|
|
|
|
- clientID = 1
|
|
|
|
|
|
+ startWG.Done() // start successful
|
|
|
|
+
|
|
for {
|
|
for {
|
|
- log.Debugf("Waiting for a new client. Client ID: %d", clientID)
|
|
|
|
|
|
+ log.Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.config.ListenInterface, clientID+1)
|
|
conn, err := listener.Accept()
|
|
conn, err := listener.Accept()
|
|
|
|
+ clientID++
|
|
if err != nil {
|
|
if err != nil {
|
|
- log.WithError(err).Info("Error accepting client")
|
|
|
|
|
|
+ if e, ok := err.(net.Error); ok && !e.Temporary() {
|
|
|
|
+ log.Infof("Server [%s] has stopped accepting new clients", server.config.ListenInterface)
|
|
|
|
+ // the listener has been closed, wait for clients to exit
|
|
|
|
+ log.Infof("shutting down pool [%s]", server.config.ListenInterface)
|
|
|
|
+ server.clientPool.ShutdownWait()
|
|
|
|
+ server.closedListener <- true
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+ log.WithError(err).Info("Temporary error accepting client")
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
- server.sem <- 1
|
|
|
|
- go server.handleClient(&client{
|
|
|
|
- Envelope: &Envelope{
|
|
|
|
- RemoteAddress: conn.RemoteAddr().String(),
|
|
|
|
- },
|
|
|
|
- conn: conn,
|
|
|
|
- ConnectedAt: time.Now(),
|
|
|
|
- bufin: newSMTPBufferedReader(conn),
|
|
|
|
- bufout: bufio.NewWriter(conn),
|
|
|
|
- ID: clientID,
|
|
|
|
- })
|
|
|
|
- clientID++
|
|
|
|
|
|
+ go func(p Poolable, borrow_err error) {
|
|
|
|
+ c := p.(*client)
|
|
|
|
+ if borrow_err == nil {
|
|
|
|
+ server.handleClient(c)
|
|
|
|
+ server.clientPool.Return(c)
|
|
|
|
+ } else {
|
|
|
|
+ log.WithError(borrow_err).Info("couldn't borrow a new client")
|
|
|
|
+ // we could not get a client, so close the connection.
|
|
|
|
+ conn.Close()
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+ // intentionally placed Borrow in args so that it's called in the
|
|
|
|
+ // same main goroutine.
|
|
|
|
+ }(server.clientPool.Borrow(conn, clientID))
|
|
|
|
+
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (server *server) Shutdown() {
|
|
|
|
+ server.clientPool.ShutdownState()
|
|
|
|
+ if server.listener != nil {
|
|
|
|
+ server.listener.Close()
|
|
|
|
+ // wait for the listener to close.
|
|
|
|
+ <-server.closedListener
|
|
|
|
+ // At this point Start will exit and close down the pool
|
|
|
|
+ } else {
|
|
|
|
+ // listener already closed, wait for clients to exit
|
|
|
|
+ server.clientPool.ShutdownWait()
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (server *server) GetActiveClientsCount() int {
|
|
|
|
+ return server.clientPool.GetActiveClientsCount()
|
|
|
|
+}
|
|
|
|
+
|
|
// Verifies that the host is a valid recipient.
|
|
// Verifies that the host is a valid recipient.
|
|
func (server *server) allowsHost(host string) bool {
|
|
func (server *server) allowsHost(host string) bool {
|
|
for _, allowed := range server.config.AllowedHosts {
|
|
for _, allowed := range server.config.AllowedHosts {
|
|
@@ -113,8 +147,8 @@ func (server *server) upgradeToTLS(client *client) bool {
|
|
return false
|
|
return false
|
|
}
|
|
}
|
|
client.conn = net.Conn(tlsConn)
|
|
client.conn = net.Conn(tlsConn)
|
|
- client.bufin = newSMTPBufferedReader(client.conn)
|
|
|
|
- client.bufout = bufio.NewWriter(client.conn)
|
|
|
|
|
|
+ client.bufout.Reset(client.conn)
|
|
|
|
+ client.bufin.Reset(client.conn)
|
|
client.TLS = true
|
|
client.TLS = true
|
|
|
|
|
|
return true
|
|
return true
|
|
@@ -124,7 +158,6 @@ func (server *server) upgradeToTLS(client *client) bool {
|
|
func (server *server) closeConn(client *client) {
|
|
func (server *server) closeConn(client *client) {
|
|
client.conn.Close()
|
|
client.conn.Close()
|
|
client.conn = nil
|
|
client.conn = nil
|
|
- <-server.sem
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// Reads from the client until a terminating sequence is encountered,
|
|
// Reads from the client until a terminating sequence is encountered,
|
|
@@ -141,7 +174,7 @@ func (server *server) read(client *client) (string, error) {
|
|
}
|
|
}
|
|
|
|
|
|
for {
|
|
for {
|
|
- client.conn.SetDeadline(time.Now().Add(server.timeout))
|
|
|
|
|
|
+ client.setTimeout(server.timeout)
|
|
reply, err = client.bufin.ReadString('\n')
|
|
reply, err = client.bufin.ReadString('\n')
|
|
input = input + reply
|
|
input = input + reply
|
|
if client.state == ClientData && reply != "" {
|
|
if client.state == ClientData && reply != "" {
|
|
@@ -163,7 +196,7 @@ func (server *server) read(client *client) (string, error) {
|
|
|
|
|
|
// Writes a response to the client.
|
|
// Writes a response to the client.
|
|
func (server *server) writeResponse(client *client) error {
|
|
func (server *server) writeResponse(client *client) error {
|
|
- client.conn.SetDeadline(time.Now().Add(server.timeout))
|
|
|
|
|
|
+ client.setTimeout(server.timeout)
|
|
size, err := client.bufout.WriteString(client.response)
|
|
size, err := client.bufout.WriteString(client.response)
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
@@ -176,15 +209,20 @@ func (server *server) writeResponse(client *client) error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (server *server) isShuttingDown() bool {
|
|
|
|
+ return server.clientPool.IsShuttingDown()
|
|
|
|
+}
|
|
|
|
+
|
|
// Handles an entire client SMTP exchange
|
|
// Handles an entire client SMTP exchange
|
|
func (server *server) handleClient(client *client) {
|
|
func (server *server) handleClient(client *client) {
|
|
defer server.closeConn(client)
|
|
defer server.closeConn(client)
|
|
|
|
+
|
|
log.Infof("Handle client [%s], id: %d", client.RemoteAddress, client.ID)
|
|
log.Infof("Handle client [%s], id: %d", client.RemoteAddress, client.ID)
|
|
|
|
|
|
// Initial greeting
|
|
// Initial greeting
|
|
greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
|
|
greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
|
|
server.config.Hostname, Version, client.ID,
|
|
server.config.Hostname, Version, client.ID,
|
|
- len(server.sem), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
|
|
|
|
|
|
+ server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
|
|
|
|
|
|
helo := fmt.Sprintf("250 %s Hello", server.config.Hostname)
|
|
helo := fmt.Sprintf("250 %s Hello", server.config.Hostname)
|
|
ehlo := fmt.Sprintf("250-%s Hello", server.config.Hostname)
|
|
ehlo := fmt.Sprintf("250-%s Hello", server.config.Hostname)
|
|
@@ -212,7 +250,6 @@ func (server *server) handleClient(client *client) {
|
|
case ClientGreeting:
|
|
case ClientGreeting:
|
|
client.responseAdd(greeting)
|
|
client.responseAdd(greeting)
|
|
client.state = ClientCmd
|
|
client.state = ClientCmd
|
|
-
|
|
|
|
case ClientCmd:
|
|
case ClientCmd:
|
|
client.bufin.setLimit(CommandLineMaxLength)
|
|
client.bufin.setLimit(CommandLineMaxLength)
|
|
input, err := server.read(client)
|
|
input, err := server.read(client)
|
|
@@ -232,6 +269,10 @@ func (server *server) handleClient(client *client) {
|
|
client.kill()
|
|
client.kill()
|
|
break
|
|
break
|
|
}
|
|
}
|
|
|
|
+ if server.isShuttingDown() {
|
|
|
|
+ client.state = ClientShutdown
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
|
|
input = strings.Trim(input, " \r\n")
|
|
input = strings.Trim(input, " \r\n")
|
|
cmdLen := len(input)
|
|
cmdLen := len(input)
|
|
@@ -322,6 +363,9 @@ func (server *server) handleClient(client *client) {
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
client.state = ClientCmd
|
|
client.state = ClientCmd
|
|
|
|
+ if server.isShuttingDown() {
|
|
|
|
+ client.state = ClientShutdown
|
|
|
|
+ }
|
|
|
|
|
|
if client.MailFrom.isEmpty() {
|
|
if client.MailFrom.isEmpty() {
|
|
client.responseAdd("550 Error: No sender")
|
|
client.responseAdd("550 Error: No sender")
|
|
@@ -351,6 +395,10 @@ func (server *server) handleClient(client *client) {
|
|
}
|
|
}
|
|
// change to command state
|
|
// change to command state
|
|
client.state = ClientCmd
|
|
client.state = ClientCmd
|
|
|
|
+ case ClientShutdown:
|
|
|
|
+ // shutdown state
|
|
|
|
+ client.responseAdd("421 Server is shutting down. Please try again later. Sayonara!")
|
|
|
|
+ client.kill()
|
|
}
|
|
}
|
|
|
|
|
|
if len(client.response) > 0 {
|
|
if len(client.response) > 0 {
|