瀏覽代碼

Let Backends return their own error code (custom result or error) (#113)

add ability for backends to specify a custom return code, fixes #78
Flashmob 7 年之前
父節點
當前提交
15a3295f1e
共有 9 個文件被更改,包括 269 次插入170 次删除
  1. 58 1
      api_test.go
  2. 22 7
      backends/backend.go
  3. 30 24
      backends/gateway.go
  4. 30 30
      client.go
  5. 1 1
      cmd/guerrillad/serve.go
  6. 109 93
      response/enhanced.go
  7. 13 8
      server.go
  8. 2 2
      tests/guerrilla_test.go
  9. 4 4
      util.go

+ 58 - 1
api_test.go

@@ -2,10 +2,12 @@ package guerrilla
 
 
 import (
 import (
 	"bufio"
 	"bufio"
+	"errors"
 	"fmt"
 	"fmt"
 	"github.com/flashmob/go-guerrilla/backends"
 	"github.com/flashmob/go-guerrilla/backends"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/mail"
 	"github.com/flashmob/go-guerrilla/mail"
+	"github.com/flashmob/go-guerrilla/response"
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"os"
 	"os"
@@ -349,7 +351,7 @@ var funkyLogger = func() backends.Decorator {
 		return backends.ProcessWith(
 		return backends.ProcessWith(
 			func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
 			func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
 				if task == backends.TaskValidateRcpt {
 				if task == backends.TaskValidateRcpt {
-					// validate the last recipient appended to e.Rcpt
+					// log the last recipient appended to e.Rcpt
 					backends.Log().Infof(
 					backends.Log().Infof(
 						"another funky recipient [%s]",
 						"another funky recipient [%s]",
 						e.RcptTo[len(e.RcptTo)-1])
 						e.RcptTo[len(e.RcptTo)-1])
@@ -556,3 +558,58 @@ func TestSkipAllowsHost(t *testing.T) {
 	}
 	}
 
 
 }
 }
+
+var customBackend2 = func() backends.Decorator {
+
+	return func(p backends.Processor) backends.Processor {
+		return backends.ProcessWith(
+			func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
+				if task == backends.TaskValidateRcpt {
+					return p.Process(e, task)
+				} else if task == backends.TaskSaveMail {
+					backends.Log().Info("Another funky email!")
+					err := errors.New("system shock")
+					return backends.NewResult(response.Canned.FailReadErrorDataCmd, response.SP, err), err
+				}
+				return p.Process(e, task)
+			})
+	}
+}
+
+// Test a custom backend response
+func TestCustomBackendResult(t *testing.T) {
+	os.Truncate("tests/testlog", 0)
+	cfg := &AppConfig{
+		LogFile:      "tests/testlog",
+		AllowedHosts: []string{"grr.la"},
+		BackendConfig: backends.BackendConfig{
+			"save_process":     "HeadersParser|Debugger|Custom",
+			"validate_process": "Custom",
+		},
+	}
+	d := Daemon{Config: cfg}
+	d.AddProcessor("Custom", customBackend2)
+
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
+	// lets have a talk with the server
+	talkToServer("127.0.0.1:2525")
+
+	d.Shutdown()
+
+	b, err := ioutil.ReadFile("tests/testlog")
+	if err != nil {
+		t.Error("could not read logfile")
+		return
+	}
+	// lets check for fingerprints
+	if strings.Index(string(b), "451 4.3.0 Error") < 0 {
+		t.Error("did not log: 451 4.3.0 Error")
+	}
+
+	if strings.Index(string(b), "system shock") < 0 {
+		t.Error("did not log: system shock")
+	}
+
+}

+ 22 - 7
backends/backend.go

@@ -1,6 +1,7 @@
 package backends
 package backends
 
 
 import (
 import (
+	"bytes"
 	"fmt"
 	"fmt"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/mail"
 	"github.com/flashmob/go-guerrilla/mail"
@@ -54,6 +55,7 @@ type BaseConfig interface{}
 type notifyMsg struct {
 type notifyMsg struct {
 	err      error
 	err      error
 	queuedID string
 	queuedID string
+	result   Result
 }
 }
 
 
 // Result represents a response to an SMTP client after receiving DATA.
 // Result represents a response to an SMTP client after receiving DATA.
@@ -66,16 +68,18 @@ type Result interface {
 }
 }
 
 
 // Internal implementation of BackendResult for use by backend implementations.
 // Internal implementation of BackendResult for use by backend implementations.
-type result string
+type result struct {
+	bytes.Buffer
+}
 
 
-func (br result) String() string {
-	return string(br)
+func (r *result) String() string {
+	return r.Buffer.String()
 }
 }
 
 
 // Parses the SMTP code from the first 3 characters of the SMTP message.
 // Parses the SMTP code from the first 3 characters of the SMTP message.
 // Returns 554 if code cannot be parsed.
 // Returns 554 if code cannot be parsed.
-func (br result) Code() int {
-	trimmed := strings.TrimSpace(string(br))
+func (r *result) Code() int {
+	trimmed := strings.TrimSpace(r.String())
 	if len(trimmed) < 3 {
 	if len(trimmed) < 3 {
 		return 554
 		return 554
 	}
 	}
@@ -86,8 +90,19 @@ func (br result) Code() int {
 	return code
 	return code
 }
 }
 
 
-func NewResult(message string) Result {
-	return result(message)
+func NewResult(r ...interface{}) Result {
+	buf := new(result)
+	for _, item := range r {
+		switch v := item.(type) {
+		case error:
+			buf.WriteString(v.Error())
+		case fmt.Stringer:
+			buf.WriteString(v.String())
+		case string:
+			buf.WriteString(v)
+		}
+	}
+	return buf
 }
 }
 
 
 type processorInitializer interface {
 type processorInitializer interface {

+ 30 - 24
backends/gateway.go

@@ -128,7 +128,7 @@ func (w *workerMsg) reset(e *mail.Envelope, task SelectTask) {
 // Process distributes an envelope to one of the backend workers with a TaskSaveMail task
 // Process distributes an envelope to one of the backend workers with a TaskSaveMail task
 func (gw *BackendGateway) Process(e *mail.Envelope) Result {
 func (gw *BackendGateway) Process(e *mail.Envelope) Result {
 	if gw.State != BackendStateRunning {
 	if gw.State != BackendStateRunning {
-		return NewResult(response.Canned.FailBackendNotRunning + gw.State.String())
+		return NewResult(response.Canned.FailBackendNotRunning, response.SP, gw.State)
 	}
 	}
 	// borrow a workerMsg from the pool
 	// borrow a workerMsg from the pool
 	workerMsg := workerMsgPool.Get().(*workerMsg)
 	workerMsg := workerMsgPool.Get().(*workerMsg)
@@ -139,11 +139,32 @@ func (gw *BackendGateway) Process(e *mail.Envelope) Result {
 	// or timeout
 	// or timeout
 	select {
 	select {
 	case status := <-workerMsg.notifyMe:
 	case status := <-workerMsg.notifyMe:
-		workerMsgPool.Put(workerMsg) // can be recycled since we used the notifyMe channel
+		// email saving transaction completed
+		if status.result == BackendResultOK && status.queuedID != "" {
+			return NewResult(response.Canned.SuccessMessageQueued, response.SP, status.queuedID)
+		}
+
+		// A custom result, there was probably an error, if so, log it
+		if status.result != nil {
+			if status.err != nil {
+				Log().Error(status.err)
+			}
+			return status.result
+		}
+
+		// if there was no result, but there's an error, then make a new result from the error
 		if status.err != nil {
 		if status.err != nil {
-			return NewResult(response.Canned.FailBackendTransaction + status.err.Error())
+			if _, err := strconv.Atoi(status.err.Error()[:3]); err != nil {
+				return NewResult(response.Canned.FailBackendTransaction, response.SP, status.err)
+			}
+			return NewResult(status.err)
 		}
 		}
-		return NewResult(response.Canned.SuccessMessageQueued + status.queuedID)
+
+		// both result & error are nil (should not happen)
+		err := errors.New("no response from backend - processor did not return a result or an error")
+		Log().Error(err)
+		return NewResult(response.Canned.FailBackendTransaction, response.SP, err)
+
 	case <-time.After(gw.saveTimeout()):
 	case <-time.After(gw.saveTimeout()):
 		Log().Error("Backend has timed out while saving email")
 		Log().Error("Backend has timed out while saving email")
 		e.Lock() // lock the envelope - it's still processing here, we don't want the server to recycle it
 		e.Lock() // lock the envelope - it's still processing here, we don't want the server to recycle it
@@ -434,27 +455,12 @@ func (gw *BackendGateway) workDispatcher(
 			return
 			return
 		case msg = <-workIn:
 		case msg = <-workIn:
 			state = dispatcherStateWorking // recovers from panic if in this state
 			state = dispatcherStateWorking // recovers from panic if in this state
+			result, err := save.Process(msg.e, msg.task)
+			state = dispatcherStateNotify
 			if msg.task == TaskSaveMail {
 			if msg.task == TaskSaveMail {
-				// process the email here
-				result, _ := save.Process(msg.e, TaskSaveMail)
-				state = dispatcherStateNotify
-				if result.Code() < 300 {
-					// if all good, let the gateway know that it was saved
-					msg.notifyMe <- &notifyMsg{nil, msg.e.QueuedId}
-				} else {
-					// notify the gateway about the error
-					msg.notifyMe <- &notifyMsg{err: errors.New(result.String())}
-				}
-			} else if msg.task == TaskValidateRcpt {
-				_, err := validate.Process(msg.e, TaskValidateRcpt)
-				state = dispatcherStateNotify
-				if err != nil {
-					// validation failed
-					msg.notifyMe <- &notifyMsg{err: err}
-				} else {
-					// all good.
-					msg.notifyMe <- &notifyMsg{err: nil}
-				}
+				msg.notifyMe <- &notifyMsg{err: err, result: result, queuedID: msg.e.QueuedId}
+			} else {
+				msg.notifyMe <- &notifyMsg{err: err, result: result}
 			}
 			}
 		}
 		}
 		state = dispatcherStateIdle
 		state = dispatcherStateIdle

+ 30 - 30
client.go

@@ -38,8 +38,9 @@ type client struct {
 	errors       int
 	errors       int
 	state        ClientState
 	state        ClientState
 	messagesSent int
 	messagesSent int
-	// Response to be written to the client
+	// Response to be written to the client (for debugging)
 	response   bytes.Buffer
 	response   bytes.Buffer
+	bufErr     error
 	conn       net.Conn
 	conn       net.Conn
 	bufin      *smtpBufferedReader
 	bufin      *smtpBufferedReader
 	bufout     *bufio.Writer
 	bufout     *bufio.Writer
@@ -69,39 +70,38 @@ func NewClient(conn net.Conn, clientID uint64, logger log.Logger, envelope *mail
 	return c
 	return c
 }
 }
 
 
-// setResponse adds a response to be written on the next turn
+// sendResponse adds a response to be written on the next turn
+// the response gets buffered
 func (c *client) sendResponse(r ...interface{}) {
 func (c *client) sendResponse(r ...interface{}) {
 	c.bufout.Reset(c.conn)
 	c.bufout.Reset(c.conn)
 	if c.log.IsDebug() {
 	if c.log.IsDebug() {
-		// us additional buffer so that we can log the response in debug mode only
+		// an additional buffer so that we can log the response in debug mode only
 		c.response.Reset()
 		c.response.Reset()
 	}
 	}
+	var out string
+	if c.bufErr != nil {
+		c.bufErr = nil
+	}
 	for _, item := range r {
 	for _, item := range r {
 		switch v := item.(type) {
 		switch v := item.(type) {
-		case string:
-			if _, err := c.bufout.WriteString(v); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v)
-			}
 		case error:
 		case error:
-			if _, err := c.bufout.WriteString(v.Error()); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v.Error())
-			}
+			out = v.Error()
 		case fmt.Stringer:
 		case fmt.Stringer:
-			if _, err := c.bufout.WriteString(v.String()); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v.String())
-			}
+			out = v.String()
+		case string:
+			out = v
+		}
+		if _, c.bufErr = c.bufout.WriteString(out); c.bufErr != nil {
+			c.log.WithError(c.bufErr).Error("could not write to c.bufout")
+		}
+		if c.log.IsDebug() {
+			c.response.WriteString(out)
+		}
+		if c.bufErr != nil {
+			return
 		}
 		}
 	}
 	}
-	c.bufout.WriteString("\r\n")
+	_, c.bufErr = c.bufout.WriteString("\r\n")
 	if c.log.IsDebug() {
 	if c.log.IsDebug() {
 		c.response.WriteString("\r\n")
 		c.response.WriteString("\r\n")
 	}
 	}
@@ -176,20 +176,20 @@ func (c *client) getID() uint64 {
 }
 }
 
 
 // UpgradeToTLS upgrades a client connection to TLS
 // UpgradeToTLS upgrades a client connection to TLS
-func (client *client) upgradeToTLS(tlsConfig *tls.Config) error {
+func (c *client) upgradeToTLS(tlsConfig *tls.Config) error {
 	var tlsConn *tls.Conn
 	var tlsConn *tls.Conn
-	// wrap client.conn in a new TLS server side connection
-	tlsConn = tls.Server(client.conn, tlsConfig)
+	// wrap c.conn in a new TLS server side connection
+	tlsConn = tls.Server(c.conn, tlsConfig)
 	// Call handshake here to get any handshake error before reading starts
 	// Call handshake here to get any handshake error before reading starts
 	err := tlsConn.Handshake()
 	err := tlsConn.Handshake()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	// convert tlsConn to net.Conn
 	// convert tlsConn to net.Conn
-	client.conn = net.Conn(tlsConn)
-	client.bufout.Reset(client.conn)
-	client.bufin.Reset(client.conn)
-	client.TLS = true
+	c.conn = net.Conn(tlsConn)
+	c.bufout.Reset(c.conn)
+	c.bufin.Reset(c.conn)
+	c.TLS = true
 	return err
 	return err
 }
 }
 
 

+ 1 - 1
cmd/guerrillad/serve.go

@@ -136,7 +136,7 @@ func readConfig(path string, pidFile string) (*guerrilla.AppConfig, error) {
 	// command line flags can override config values
 	// command line flags can override config values
 	appConfig, err := d.LoadConfig(path)
 	appConfig, err := d.LoadConfig(path)
 	if err != nil {
 	if err != nil {
-		return &appConfig, fmt.Errorf("Could not read config file: %s", err.Error())
+		return &appConfig, fmt.Errorf("could not read config file: %s", err.Error())
 	}
 	}
 	// override config pidFile with with flag from the command line
 	// override config pidFile with with flag from the command line
 	if len(pidFile) > 0 {
 	if len(pidFile) > 0 {

+ 109 - 93
response/enhanced.go

@@ -22,6 +22,9 @@ const (
 	ClassPermanentFailure = 5
 	ClassPermanentFailure = 5
 )
 )
 
 
+// space char
+const SP = " "
+
 // class is a type for ClassSuccess, ClassTransientFailure and ClassPermanentFailure constants
 // class is a type for ClassSuccess, ClassTransientFailure and ClassPermanentFailure constants
 type class int
 type class int
 
 
@@ -118,39 +121,39 @@ var (
 type Responses struct {
 type Responses struct {
 
 
 	// The 500's
 	// The 500's
-	FailLineTooLong              string
-	FailNestedMailCmd            string
-	FailNoSenderDataCmd          string
-	FailNoRecipientsDataCmd      string
-	FailUnrecognizedCmd          string
-	FailMaxUnrecognizedCmd       string
-	FailReadLimitExceededDataCmd string
-	FailMessageSizeExceeded      string
-	FailReadErrorDataCmd         string
-	FailPathTooLong              string
-	FailInvalidAddress           string
-	FailLocalPartTooLong         string
-	FailDomainTooLong            string
-	FailBackendNotRunning        string
-	FailBackendTransaction       string
-	FailBackendTimeout           string
-	FailRcptCmd                  string
+	FailLineTooLong              *Response
+	FailNestedMailCmd            *Response
+	FailNoSenderDataCmd          *Response
+	FailNoRecipientsDataCmd      *Response
+	FailUnrecognizedCmd          *Response
+	FailMaxUnrecognizedCmd       *Response
+	FailReadLimitExceededDataCmd *Response
+	FailMessageSizeExceeded      *Response
+	FailReadErrorDataCmd         *Response
+	FailPathTooLong              *Response
+	FailInvalidAddress           *Response
+	FailLocalPartTooLong         *Response
+	FailDomainTooLong            *Response
+	FailBackendNotRunning        *Response
+	FailBackendTransaction       *Response
+	FailBackendTimeout           *Response
+	FailRcptCmd                  *Response
 
 
 	// The 400's
 	// The 400's
-	ErrorTooManyRecipients string
-	ErrorRelayDenied       string
-	ErrorShutdown          string
+	ErrorTooManyRecipients *Response
+	ErrorRelayDenied       *Response
+	ErrorShutdown          *Response
 
 
 	// The 200's
 	// The 200's
-	SuccessMailCmd       string
-	SuccessRcptCmd       string
-	SuccessResetCmd      string
-	SuccessVerifyCmd     string
-	SuccessNoopCmd       string
-	SuccessQuitCmd       string
-	SuccessDataCmd       string
-	SuccessStartTLSCmd   string
-	SuccessMessageQueued string
+	SuccessMailCmd       *Response
+	SuccessRcptCmd       *Response
+	SuccessResetCmd      *Response
+	SuccessVerifyCmd     *Response
+	SuccessNoopCmd       *Response
+	SuccessQuitCmd       *Response
+	SuccessDataCmd       *Response
+	SuccessStartTLSCmd   *Response
+	SuccessMessageQueued *Response
 }
 }
 
 
 // Called automatically during package load to build up the Responses struct
 // Called automatically during package load to build up the Responses struct
@@ -158,191 +161,195 @@ func init() {
 
 
 	Canned = Responses{}
 	Canned = Responses{}
 
 
-	Canned.FailLineTooLong = (&Response{
+	Canned.FailLineTooLong = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Line too long.",
 		Comment:      "Line too long.",
-	}).String()
+	}
 
 
-	Canned.FailNestedMailCmd = (&Response{
+	Canned.FailNestedMailCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: nested MAIL command",
 		Comment:      "Error: nested MAIL command",
-	}).String()
+	}
 
 
-	Canned.SuccessMailCmd = (&Response{
+	Canned.SuccessMailCmd = &Response{
 		EnhancedCode: OtherAddressStatus,
 		EnhancedCode: OtherAddressStatus,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
 
-	Canned.SuccessRcptCmd = (&Response{
+	Canned.SuccessRcptCmd = &Response{
 		EnhancedCode: DestinationMailboxAddressValid,
 		EnhancedCode: DestinationMailboxAddressValid,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
 
 	Canned.SuccessResetCmd = Canned.SuccessMailCmd
 	Canned.SuccessResetCmd = Canned.SuccessMailCmd
-	Canned.SuccessNoopCmd = (&Response{
+
+	Canned.SuccessNoopCmd = &Response{
 		EnhancedCode: OtherStatus,
 		EnhancedCode: OtherStatus,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
 
-	Canned.SuccessVerifyCmd = (&Response{
+	Canned.SuccessVerifyCmd = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    252,
 		BasicCode:    252,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
 		Comment:      "Cannot verify user",
 		Comment:      "Cannot verify user",
-	}).String()
+	}
 
 
-	Canned.ErrorTooManyRecipients = (&Response{
+	Canned.ErrorTooManyRecipients = &Response{
 		EnhancedCode: TooManyRecipients,
 		EnhancedCode: TooManyRecipients,
 		BasicCode:    452,
 		BasicCode:    452,
 		Class:        ClassTransientFailure,
 		Class:        ClassTransientFailure,
 		Comment:      "Too many recipients",
 		Comment:      "Too many recipients",
-	}).String()
+	}
 
 
-	Canned.ErrorRelayDenied = (&Response{
+	Canned.ErrorRelayDenied = &Response{
 		EnhancedCode: BadDestinationMailboxAddress,
 		EnhancedCode: BadDestinationMailboxAddress,
 		BasicCode:    454,
 		BasicCode:    454,
 		Class:        ClassTransientFailure,
 		Class:        ClassTransientFailure,
-		Comment:      "Error: Relay access denied: ",
-	}).String()
+		Comment:      "Error: Relay access denied:",
+	}
 
 
-	Canned.SuccessQuitCmd = (&Response{
+	Canned.SuccessQuitCmd = &Response{
 		EnhancedCode: OtherStatus,
 		EnhancedCode: OtherStatus,
 		BasicCode:    221,
 		BasicCode:    221,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
 		Comment:      "Bye",
 		Comment:      "Bye",
-	}).String()
+	}
 
 
-	Canned.FailNoSenderDataCmd = (&Response{
+	Canned.FailNoSenderDataCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: No sender",
 		Comment:      "Error: No sender",
-	}).String()
+	}
 
 
-	Canned.FailNoRecipientsDataCmd = (&Response{
+	Canned.FailNoRecipientsDataCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: No recipients",
 		Comment:      "Error: No recipients",
-	}).String()
+	}
 
 
-	Canned.SuccessDataCmd = "354 Enter message, ending with '.' on a line by itself"
+	Canned.SuccessDataCmd = &Response{
+		BasicCode: 354,
+		Comment:   "354 Enter message, ending with '.' on a line by itself",
+	}
 
 
-	Canned.SuccessStartTLSCmd = (&Response{
+	Canned.SuccessStartTLSCmd = &Response{
 		EnhancedCode: OtherStatus,
 		EnhancedCode: OtherStatus,
 		BasicCode:    220,
 		BasicCode:    220,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
 		Comment:      "Ready to start TLS",
 		Comment:      "Ready to start TLS",
-	}).String()
+	}
 
 
-	Canned.FailUnrecognizedCmd = (&Response{
+	Canned.FailUnrecognizedCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Unrecognized command",
 		Comment:      "Unrecognized command",
-	}).String()
+	}
 
 
-	Canned.FailMaxUnrecognizedCmd = (&Response{
+	Canned.FailMaxUnrecognizedCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Too many unrecognized commands",
 		Comment:      "Too many unrecognized commands",
-	}).String()
+	}
 
 
-	Canned.ErrorShutdown = (&Response{
+	Canned.ErrorShutdown = &Response{
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		BasicCode:    421,
 		BasicCode:    421,
 		Class:        ClassTransientFailure,
 		Class:        ClassTransientFailure,
 		Comment:      "Server is shutting down. Please try again later. Sayonara!",
 		Comment:      "Server is shutting down. Please try again later. Sayonara!",
-	}).String()
+	}
 
 
-	Canned.FailReadLimitExceededDataCmd = (&Response{
+	Canned.FailReadLimitExceededDataCmd = &Response{
 		EnhancedCode: SyntaxError,
 		EnhancedCode: SyntaxError,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
 
-	Canned.FailMessageSizeExceeded = (&Response{
+	Canned.FailMessageSizeExceeded = &Response{
 		EnhancedCode: SyntaxError,
 		EnhancedCode: SyntaxError,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
 
-	Canned.FailReadErrorDataCmd = (&Response{
+	Canned.FailReadErrorDataCmd = &Response{
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		BasicCode:    451,
 		BasicCode:    451,
 		Class:        ClassTransientFailure,
 		Class:        ClassTransientFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
 
-	Canned.FailPathTooLong = (&Response{
+	Canned.FailPathTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Path too long",
 		Comment:      "Path too long",
-	}).String()
+	}
 
 
-	Canned.FailInvalidAddress = (&Response{
+	Canned.FailInvalidAddress = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    501,
 		BasicCode:    501,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Invalid address",
 		Comment:      "Invalid address",
-	}).String()
+	}
 
 
-	Canned.FailLocalPartTooLong = (&Response{
+	Canned.FailLocalPartTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Local part too long, cannot exceed 64 characters",
 		Comment:      "Local part too long, cannot exceed 64 characters",
-	}).String()
+	}
 
 
-	Canned.FailDomainTooLong = (&Response{
+	Canned.FailDomainTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Domain cannot exceed 255 characters",
 		Comment:      "Domain cannot exceed 255 characters",
-	}).String()
+	}
 
 
-	Canned.FailBackendNotRunning = (&Response{
+	Canned.FailBackendNotRunning = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
-		Comment:      "Transaction failed - backend not running ",
-	}).String()
+		Comment:      "Transaction failed - backend not running",
+	}
 
 
-	Canned.FailBackendTransaction = (&Response{
+	Canned.FailBackendTransaction = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
 
-	Canned.SuccessMessageQueued = (&Response{
+	Canned.SuccessMessageQueued = &Response{
 		EnhancedCode: OtherStatus,
 		EnhancedCode: OtherStatus,
 		BasicCode:    250,
 		BasicCode:    250,
 		Class:        ClassSuccess,
 		Class:        ClassSuccess,
-		Comment:      "OK : queued as ",
-	}).String()
+		Comment:      "OK: queued as",
+	}
 
 
-	Canned.FailBackendTimeout = (&Response{
+	Canned.FailBackendTimeout = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: transaction timeout",
 		Comment:      "Error: transaction timeout",
-	}).String()
+	}
 
 
-	Canned.FailRcptCmd = (&Response{
+	Canned.FailRcptCmd = &Response{
 		EnhancedCode: BadDestinationMailboxAddress,
 		EnhancedCode: BadDestinationMailboxAddress,
 		BasicCode:    550,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Class:        ClassPermanentFailure,
 		Comment:      "User unknown in local recipient table",
 		Comment:      "User unknown in local recipient table",
-	}).String()
+	}
 
 
 }
 }
 
 
@@ -409,6 +416,7 @@ type Response struct {
 	Class        class
 	Class        class
 	// Comment is optional
 	// Comment is optional
 	Comment string
 	Comment string
+	cached  string
 }
 }
 
 
 // it looks like this ".5.4"
 // it looks like this ".5.4"
@@ -428,6 +436,14 @@ func (e EnhancedStatusCode) String() string {
 // String returns a custom Response as a string
 // String returns a custom Response as a string
 func (r *Response) String() string {
 func (r *Response) String() string {
 
 
+	if r.cached != "" {
+		return r.cached
+	}
+	if r.EnhancedCode == "" {
+		r.cached = r.Comment
+		return r.Comment
+	}
+
 	basicCode := r.BasicCode
 	basicCode := r.BasicCode
 	comment := r.Comment
 	comment := r.Comment
 	if len(comment) == 0 && r.BasicCode == 0 {
 	if len(comment) == 0 && r.BasicCode == 0 {
@@ -447,8 +463,8 @@ func (r *Response) String() string {
 	if r.BasicCode == 0 {
 	if r.BasicCode == 0 {
 		basicCode = getBasicStatusCode(e)
 		basicCode = getBasicStatusCode(e)
 	}
 	}
-
-	return fmt.Sprintf("%d %s %s", basicCode, e.String(), comment)
+	r.cached = fmt.Sprintf("%d %s %s", basicCode, e.String(), comment)
+	return r.cached
 }
 }
 
 
 // getBasicStatusCode gets the basic status code from codeMap, or fallback code if not mapped
 // getBasicStatusCode gets the basic status code from codeMap, or fallback code if not mapped

+ 13 - 8
server.go

@@ -430,7 +430,7 @@ func (server *server) handleClient(client *client) {
 
 
 			case strings.Index(cmd, "HELP") == 0:
 			case strings.Index(cmd, "HELP") == 0:
 				quote := response.GetQuote()
 				quote := response.GetQuote()
-				client.sendResponse("214-OK\r\n" + quote)
+				client.sendResponse("214-OK\r\n", quote)
 
 
 			case sc.XClientOn && strings.Index(cmd, "XCLIENT ") == 0:
 			case sc.XClientOn && strings.Index(cmd, "XCLIENT ") == 0:
 				if toks := strings.Split(input[8:], " "); len(toks) > 0 {
 				if toks := strings.Split(input[8:], " "); len(toks) > 0 {
@@ -482,13 +482,13 @@ func (server *server) handleClient(client *client) {
 					client.sendResponse(err.Error())
 					client.sendResponse(err.Error())
 				} else {
 				} else {
 					if !server.allowsHost(to.Host) {
 					if !server.allowsHost(to.Host) {
-						client.sendResponse(response.Canned.ErrorRelayDenied, to.Host)
+						client.sendResponse(response.Canned.ErrorRelayDenied, " ", to.Host)
 					} else {
 					} else {
 						client.PushRcpt(to)
 						client.PushRcpt(to)
 						rcptError := server.backend().ValidateRcpt(client.Envelope)
 						rcptError := server.backend().ValidateRcpt(client.Envelope)
 						if rcptError != nil {
 						if rcptError != nil {
 							client.PopRcpt()
 							client.PopRcpt()
-							client.sendResponse(response.Canned.FailRcptCmd + " " + rcptError.Error())
+							client.sendResponse(response.Canned.FailRcptCmd, " ", rcptError.Error())
 						} else {
 						} else {
 							client.sendResponse(response.Canned.SuccessRcptCmd)
 							client.sendResponse(response.Canned.SuccessRcptCmd)
 						}
 						}
@@ -543,13 +543,13 @@ func (server *server) handleClient(client *client) {
 			}
 			}
 			if err != nil {
 			if err != nil {
 				if err == LineLimitExceeded {
 				if err == LineLimitExceeded {
-					client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, LineLimitExceeded.Error())
+					client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, " ", LineLimitExceeded.Error())
 					client.kill()
 					client.kill()
 				} else if err == MessageSizeExceeded {
 				} else if err == MessageSizeExceeded {
-					client.sendResponse(response.Canned.FailMessageSizeExceeded, MessageSizeExceeded.Error())
+					client.sendResponse(response.Canned.FailMessageSizeExceeded, " ", MessageSizeExceeded.Error())
 					client.kill()
 					client.kill()
 				} else {
 				} else {
-					client.sendResponse(response.Canned.FailReadErrorDataCmd, err.Error())
+					client.sendResponse(response.Canned.FailReadErrorDataCmd, " ", err.Error())
 					client.kill()
 					client.kill()
 				}
 				}
 				server.log().WithError(err).Warn("Error reading data")
 				server.log().WithError(err).Warn("Error reading data")
@@ -561,7 +561,7 @@ func (server *server) handleClient(client *client) {
 			if res.Code() < 300 {
 			if res.Code() < 300 {
 				client.messagesSent++
 				client.messagesSent++
 			}
 			}
-			client.sendResponse(res.String())
+			client.sendResponse(res)
 			client.state = ClientCmd
 			client.state = ClientCmd
 			if server.isShuttingDown() {
 			if server.isShuttingDown() {
 				client.state = ClientShutdown
 				client.state = ClientShutdown
@@ -589,13 +589,18 @@ func (server *server) handleClient(client *client) {
 			client.kill()
 			client.kill()
 		}
 		}
 
 
+		if client.bufErr != nil {
+			server.log().WithError(client.bufErr).Debug("client could not buffer a response")
+			return
+		}
+		// flush the response buffer
 		if client.bufout.Buffered() > 0 {
 		if client.bufout.Buffered() > 0 {
 			if server.log().IsDebug() {
 			if server.log().IsDebug() {
 				server.log().Debugf("Writing response to client: \n%s", client.response.String())
 				server.log().Debugf("Writing response to client: \n%s", client.response.String())
 			}
 			}
 			err := server.flushResponse(client)
 			err := server.flushResponse(client)
 			if err != nil {
 			if err != nil {
-				server.log().WithError(err).Debug("Error writing response")
+				server.log().WithError(err).Debug("error writing response")
 				return
 				return
 			}
 			}
 		}
 		}

+ 2 - 2
tests/guerrilla_test.go

@@ -1011,7 +1011,7 @@ func TestDataMaxLength(t *testing.T) {
 			//expected := "500 Line too long"
 			//expected := "500 Line too long"
 			expected := "451 4.3.0 Error: Maximum DATA size exceeded"
 			expected := "451 4.3.0 Error: Maximum DATA size exceeded"
 			if strings.Index(response, expected) != 0 {
 			if strings.Index(response, expected) != 0 {
-				t.Error("Server did not respond with", expected, ", it said:"+response, err)
+				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 			}
 
 
 		}
 		}
@@ -1105,7 +1105,7 @@ func TestDataCommand(t *testing.T) {
 				bufin,
 				bufin,
 				email+"\r\n.\r\n")
 				email+"\r\n.\r\n")
 			//expected := "500 Line too long"
 			//expected := "500 Line too long"
-			expected := "250 2.0.0 OK : queued as "
+			expected := "250 2.0.0 OK: queued as "
 			if strings.Index(response, expected) != 0 {
 			if strings.Index(response, expected) != 0 {
 				t.Error("Server did not respond with", expected, ", it said:"+response, err)
 				t.Error("Server did not respond with", expected, ", it said:"+response, err)
 			}
 			}

+ 4 - 4
util.go

@@ -15,7 +15,7 @@ func extractEmail(str string) (mail.Address, error) {
 	email := mail.Address{}
 	email := mail.Address{}
 	var err error
 	var err error
 	if len(str) > RFC2821LimitPath {
 	if len(str) > RFC2821LimitPath {
-		return email, errors.New(response.Canned.FailPathTooLong)
+		return email, errors.New(response.Canned.FailPathTooLong.String())
 	}
 	}
 	if matched := extractEmailRegex.FindStringSubmatch(str); len(matched) > 2 {
 	if matched := extractEmailRegex.FindStringSubmatch(str); len(matched) > 2 {
 		email.User = matched[1]
 		email.User = matched[1]
@@ -26,11 +26,11 @@ func extractEmail(str string) (mail.Address, error) {
 	}
 	}
 	err = nil
 	err = nil
 	if email.User == "" || email.Host == "" {
 	if email.User == "" || email.Host == "" {
-		err = errors.New(response.Canned.FailInvalidAddress)
+		err = errors.New(response.Canned.FailInvalidAddress.String())
 	} else if len(email.User) > RFC2832LimitLocalPart {
 	} else if len(email.User) > RFC2832LimitLocalPart {
-		err = errors.New(response.Canned.FailLocalPartTooLong)
+		err = errors.New(response.Canned.FailLocalPartTooLong.String())
 	} else if len(email.Host) > RFC2821LimitDomain {
 	} else if len(email.Host) > RFC2821LimitDomain {
-		err = errors.New(response.Canned.FailDomainTooLong)
+		err = errors.New(response.Canned.FailDomainTooLong.String())
 	}
 	}
 	return email, err
 	return email, err
 }
 }