Browse Source

Merge branch 'master' of https://github.com/flashmob/go-guerrilla

Philipp Resch 8 years ago
parent
commit
456713b0f4
6 changed files with 311 additions and 40 deletions
  1. 31 7
      backends/backend.go
  2. 37 1
      response/enhanced.go
  3. 18 7
      response/enhanced_test.go
  4. 126 21
      server.go
  5. 77 0
      tests/guerrilla_test.go
  6. 22 4
      util.go

+ 31 - 7
backends/backend.go

@@ -128,10 +128,16 @@ func New(backendName string, backendConfig BackendConfig) (Backend, error) {
 	return gateway, nil
 }
 
-// Distributes an envelope to one of the backend workers
+// Process distributes an envelope to one of the backend workers
 func (gw *BackendGateway) Process(e *envelope.Envelope) BackendResult {
 	if gw.State != BackendStateRunning {
-		return NewBackendResult(response.CustomString(response.OtherOrUndefinedProtocolStatus, 554, response.ClassPermanentFailure, "Transaction failed - backend not running "+strconv.Itoa(gw.State)))
+		resp := &response.Response{
+			EnhancedCode: response.OtherOrUndefinedProtocolStatus,
+			BasicCode:    554,
+			Class:        response.ClassPermanentFailure,
+			Comment:      "Transaction failed - backend not running " + strconv.Itoa(gw.State),
+		}
+		return NewBackendResult(resp.String())
 	}
 
 	to := e.RcptTo
@@ -146,12 +152,31 @@ func (gw *BackendGateway) Process(e *envelope.Envelope) BackendResult {
 	select {
 	case status := <-savedNotify:
 		if status.err != nil {
-			return NewBackendResult(response.CustomString(response.OtherOrUndefinedProtocolStatus, 554, response.ClassPermanentFailure, "Error: "+status.err.Error()))
+			resp := &response.Response{
+				EnhancedCode: response.OtherOrUndefinedProtocolStatus,
+				BasicCode:    554,
+				Class:        response.ClassPermanentFailure,
+				Comment:      "Error: " + status.err.Error(),
+			}
+			return NewBackendResult(resp.String())
+		}
+		resp := &response.Response{
+			EnhancedCode: response.OtherStatus,
+			BasicCode:    250,
+			Class:        response.ClassSuccess,
+			Comment:      fmt.Sprintf("OK : queued as %s", status.hash),
 		}
-		return NewBackendResult(response.CustomString(response.OtherStatus, 250, response.ClassSuccess, fmt.Sprintf("OK : queued as %s", status.hash)))
+		return NewBackendResult(resp.String())
+
 	case <-time.After(time.Second * 30):
 		log.Infof("Backend has timed out")
-		return NewBackendResult(response.CustomString(response.OtherOrUndefinedProtocolStatus, 554, response.ClassPermanentFailure, "Error: transaction timeout"))
+		resp := &response.Response{
+			EnhancedCode: response.OtherOrUndefinedProtocolStatus,
+			BasicCode:    554,
+			Class:        response.ClassPermanentFailure,
+			Comment:      "Error: transaction timeout",
+		}
+		return NewBackendResult(resp.String())
 	}
 }
 func (gw *BackendGateway) Shutdown() error {
@@ -177,9 +202,8 @@ func (gw *BackendGateway) Reinitialize() error {
 	err := gw.Initialize(gw.config)
 	if err != nil {
 		return fmt.Errorf("error while initializing the backend: %s", err)
-	} else {
-		gw.State = BackendStateRunning
 	}
+	gw.State = BackendStateRunning
 	return err
 }
 

+ 37 - 1
response/enhanced.go

@@ -29,6 +29,7 @@ const (
 var codeMap = struct {
 	m map[string]int
 }{m: map[string]int{
+	"2.1.0":  250,
 	"2.1.5":  250,
 	"2.3.0":  250,
 	"2.5.0":  250,
@@ -155,6 +156,41 @@ var defaultTexts = struct {
 	"5.5.1": "Invalid command",
 }}
 
+// Response type for Stringer interface
+type Response struct {
+	EnhancedCode string
+	BasicCode    int
+	Class        int
+	// Comment is optional
+	Comment string
+}
+
+// Custom returns a custom Response Stringer
+func (r *Response) String() string {
+	e := buildEnhancedResponseFromDefaultStatus(r.Class, r.EnhancedCode)
+	basicCode := r.BasicCode
+	comment := r.Comment
+	if len(comment) == 0 && r.BasicCode == 0 {
+		comment = defaultTexts.m[r.EnhancedCode]
+		if len(comment) == 0 {
+			switch r.Class {
+			case 2:
+				comment = "OK"
+			case 4:
+				comment = "Temporary failure."
+			case 5:
+				comment = "Permanent failure."
+			}
+		}
+	}
+	if r.BasicCode == 0 {
+		basicCode = getBasicStatusCode(e)
+	}
+
+	return fmt.Sprintf("%d %s %s", basicCode, e, comment)
+}
+
+/*
 // CustomString builds an enhanced status code string using your custom string and basic code
 func CustomString(enhancedCode string, basicCode, class int, comment string) string {
 	e := buildEnhancedResponseFromDefaultStatus(class, enhancedCode)
@@ -179,7 +215,7 @@ func String(enhancedCode string, class int) string {
 	}
 	return CustomString(enhancedCode, basicCode, class, comment)
 }
-
+*/
 func getBasicStatusCode(enhancedStatusCode string) int {
 	if val, ok := codeMap.m[enhancedStatusCode]; ok {
 		return val

+ 18 - 7
response/enhanced_test.go

@@ -1,6 +1,8 @@
 package response
 
-import "testing"
+import (
+	"testing"
+)
 
 func TestClass(t *testing.T) {
 	if ClassPermanentFailure != 5 {
@@ -31,15 +33,24 @@ func TestGetBasicStatusCode(t *testing.T) {
 // TestString for the String function
 func TestCustomString(t *testing.T) {
 	// Basic testing
-	a := CustomString(OtherStatus, 200, ClassSuccess, "Test")
-	if a != "200 2.0.0 Test" {
-		t.Errorf("CustomString failed. String \"%s\" not expected.", a)
+	resp := &Response{
+		EnhancedCode: OtherStatus,
+		BasicCode:    200,
+		Class:        ClassSuccess,
+		Comment:      "Test",
+	}
+
+	if resp.String() != "200 2.0.0 Test" {
+		t.Errorf("CustomString failed. String \"%s\" not expected.", resp)
 	}
 
 	// Default String
-	b := String(OtherStatus, ClassSuccess)
-	if b != "200 2.0.0 OK" {
-		t.Errorf("String failed. String \"%s\" not expected.", b)
+	resp2 := &Response{
+		EnhancedCode: OtherStatus,
+		Class:        ClassSuccess,
+	}
+	if resp2.String() != "200 2.0.0 OK" {
+		t.Errorf("String failed. String \"%s\" not expected.", resp2)
 	}
 }
 

+ 126 - 21
server.go

@@ -308,7 +308,13 @@ func (server *server) handleClient(client *client) {
 				log.WithError(err).Warnf("Timeout: %s", client.RemoteAddress)
 				return
 			} else if err == LineLimitExceeded {
-				client.responseAdd(response.CustomString(response.InvalidCommand, 554, response.ClassPermanentFailure, "Line too long."))
+				resp := &response.Response{
+					EnhancedCode: response.InvalidCommand,
+					BasicCode:    554,
+					Class:        response.ClassPermanentFailure,
+					Comment:      "Line too long.",
+				}
+				client.responseAdd(resp.String())
 				client.kill()
 				break
 			} else if err != nil {
@@ -327,7 +333,6 @@ func (server *server) handleClient(client *client) {
 				cmdLen = CommandVerbMaxLength
 			}
 			cmd := strings.ToUpper(input[:cmdLen])
-
 			switch {
 			case strings.Index(cmd, "HELO") == 0:
 				client.Helo = strings.Trim(input[4:], " ")
@@ -344,7 +349,13 @@ func (server *server) handleClient(client *client) {
 
 			case strings.Index(cmd, "MAIL FROM:") == 0:
 				if client.isInTransaction() {
-					client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: nested MAIL command"))
+					resp := &response.Response{
+						EnhancedCode: response.InvalidCommand,
+						BasicCode:    503,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Error: nested MAIL command",
+					}
+					client.responseAdd(resp.String())
 					break
 				}
 				// Fix for issue #53 - MAIL FROM may only be <> if it is a bounce
@@ -359,12 +370,22 @@ func (server *server) handleClient(client *client) {
 					client.responseAdd(err.Error())
 				} else {
 					client.MailFrom = from
-					client.responseAdd(response.CustomString(response.OtherAddressStatus, 250, response.ClassSuccess, "OK"))
+					resp := &response.Response{
+						EnhancedCode: response.OtherAddressStatus,
+						Class:        response.ClassSuccess,
+					}
+					client.responseAdd(resp.String())
 				}
 
 			case strings.Index(cmd, "RCPT TO:") == 0:
 				if len(client.RcptTo) > RFC2821LimitRecipients {
-					client.responseAdd(response.CustomString(response.TooManyRecipients, 452, response.ClassTransientFailure, "Too many recipients"))
+					resp := &response.Response{
+						EnhancedCode: response.TooManyRecipients,
+						BasicCode:    452,
+						Class:        response.ClassTransientFailure,
+						Comment:      "Too many recipients",
+					}
+					client.responseAdd(resp.String())
 					break
 				}
 				to, err := extractEmail(input[8:])
@@ -372,50 +393,110 @@ func (server *server) handleClient(client *client) {
 					client.responseAdd(err.Error())
 				} else {
 					if !server.allowsHost(to.Host) {
-						client.responseAdd(response.CustomString(response.BadDestinationMailboxAddress, 454, response.ClassTransientFailure, "Error: Relay access denied: "+to.Host))
+						resp := &response.Response{
+							EnhancedCode: response.BadDestinationMailboxAddress,
+							BasicCode:    454,
+							Class:        response.ClassTransientFailure,
+							Comment:      "Error: Relay access denied: " + to.Host,
+						}
+						client.responseAdd(resp.String())
 					} else {
 						client.RcptTo = append(client.RcptTo, *to)
-						client.responseAdd(response.String(response.DestinationMailboxAddressValid, response.ClassSuccess))
+						resp := &response.Response{
+							EnhancedCode: response.DestinationMailboxAddressValid,
+							Class:        response.ClassSuccess,
+						}
+						client.responseAdd(resp.String())
 					}
 				}
 
 			case strings.Index(cmd, "RSET") == 0:
 				client.resetTransaction()
-				client.responseAdd(response.CustomString(response.OtherAddressStatus, 250, response.ClassSuccess, "OK"))
+				resp := &response.Response{
+					EnhancedCode: response.OtherAddressStatus,
+					Class:        response.ClassSuccess,
+				}
+				client.responseAdd(resp.String())
 
 			case strings.Index(cmd, "VRFY") == 0:
-				client.responseAdd(response.CustomString(response.OtherOrUndefinedProtocolStatus, 252, response.ClassSuccess, "Cannot verify user"))
+				resp := &response.Response{
+					EnhancedCode: response.OtherOrUndefinedProtocolStatus,
+					BasicCode:    252,
+					Class:        response.ClassSuccess,
+					Comment:      "Cannot verify user",
+				}
+				client.responseAdd(resp.String())
 
 			case strings.Index(cmd, "NOOP") == 0:
-				client.responseAdd(response.String(response.DestinationMailboxAddressValid, response.ClassSuccess))
+				resp := &response.Response{
+					EnhancedCode: response.OtherStatus,
+					Class:        response.ClassSuccess,
+				}
+				client.responseAdd(resp.String())
 
 			case strings.Index(cmd, "QUIT") == 0:
-				client.responseAdd(response.CustomString(response.OtherStatus, 221, response.ClassSuccess, "Bye"))
+				resp := &response.Response{
+					EnhancedCode: response.OtherStatus,
+					BasicCode:    221,
+					Class:        response.ClassSuccess,
+					Comment:      "Bye",
+				}
+				client.responseAdd(resp.String())
 				client.kill()
 
 			case strings.Index(cmd, "DATA") == 0:
 				if client.MailFrom.IsEmpty() {
-					client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: No sender"))
+					resp := &response.Response{
+						EnhancedCode: response.InvalidCommand,
+						BasicCode:    503,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Error: No sender",
+					}
+					client.responseAdd(resp.String())
 					break
 				}
 				if len(client.RcptTo) == 0 {
-					client.responseAdd(response.CustomString(response.InvalidCommand, 503, response.ClassPermanentFailure, "Error: No recipients"))
+					resp := &response.Response{
+						EnhancedCode: response.InvalidCommand,
+						BasicCode:    503,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Error: No recipients",
+					}
+					client.responseAdd(resp.String())
 					break
 				}
 				client.responseAdd("354 Enter message, ending with '.' on a line by itself")
 				client.state = ClientData
 
 			case sc.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
-				client.responseAdd(response.CustomString(response.OtherStatus, 220, response.ClassSuccess, "Ready to start TLS"))
+				resp := &response.Response{
+					EnhancedCode: response.OtherStatus,
+					BasicCode:    220,
+					Class:        response.ClassSuccess,
+					Comment:      "Ready to start TLS",
+				}
+				client.responseAdd(resp.String())
 				client.state = ClientStartTLS
 			default:
-
-				client.responseAdd(response.CustomString(response.SyntaxError, 500, response.ClassPermanentFailure, "Unrecognized command: "+cmd))
+				resp := &response.Response{
+					EnhancedCode: response.InvalidCommand,
+					BasicCode:    554,
+					Class:        response.ClassPermanentFailure,
+					Comment:      fmt.Sprintf("Unrecognized command"),
+				}
+				client.responseAdd(resp.String())
 				client.errors++
 				if client.errors > MaxUnrecognizedCommands {
-					client.responseAdd(response.CustomString(response.InvalidCommand, 554, response.ClassPermanentFailure, "Too many unrecognized commands"))
+					resp := &response.Response{
+						EnhancedCode: response.InvalidCommand,
+						BasicCode:    554,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Too many unrecognized commands",
+					}
+					client.responseAdd(resp.String())
 					client.kill()
 				}
+
 			}
 
 		case ClientData:
@@ -430,14 +511,32 @@ func (server *server) handleClient(client *client) {
 			}
 			if err != nil {
 				if err == LineLimitExceeded {
-					client.responseAdd(response.CustomString(response.SyntaxError, 550, response.ClassPermanentFailure, "Error: "+LineLimitExceeded.Error()))
+					resp := &response.Response{
+						EnhancedCode: response.SyntaxError,
+						BasicCode:    550,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Error: " + LineLimitExceeded.Error(),
+					}
+					client.responseAdd(resp.String())
 					client.kill()
 				} else if err == MessageSizeExceeded {
-					client.responseAdd(response.CustomString(response.SyntaxError, 550, response.ClassPermanentFailure, "Error: "+MessageSizeExceeded.Error()))
+					resp := &response.Response{
+						EnhancedCode: response.SyntaxError,
+						BasicCode:    550,
+						Class:        response.ClassPermanentFailure,
+						Comment:      "Error: " + MessageSizeExceeded.Error(),
+					}
+					client.responseAdd(resp.String())
 					client.kill()
 				} else {
+					resp := &response.Response{
+						EnhancedCode: response.OtherOrUndefinedMailSystemStatus,
+						BasicCode:    451,
+						Class:        response.ClassTransientFailure,
+						Comment:      "Error: " + err.Error(),
+					}
+					client.responseAdd(resp.String())
 					client.kill()
-					client.responseAdd(response.CustomString(response.OtherOrUndefinedMailSystemStatus, 451, response.ClassTransientFailure, "Error: "+err.Error()))
 				}
 				log.WithError(err).Warn("Error reading data")
 				break
@@ -466,7 +565,13 @@ func (server *server) handleClient(client *client) {
 			client.state = ClientCmd
 		case ClientShutdown:
 			// shutdown state
-			client.responseAdd(response.CustomString(response.OtherOrUndefinedMailSystemStatus, 421, response.ClassTransientFailure, "Server is shutting down. Please try again later. Sayonara!"))
+			resp := &response.Response{
+				EnhancedCode: response.OtherOrUndefinedMailSystemStatus,
+				BasicCode:    421,
+				Class:        response.ClassTransientFailure,
+				Comment:      "Server is shutting down. Please try again later. Sayonara!",
+			}
+			client.responseAdd(resp.String())
 			client.kill()
 		}
 

+ 77 - 0
tests/guerrilla_test.go

@@ -728,6 +728,83 @@ func TestMailFromCmd(t *testing.T) {
 	logIn.Reset(&logBuffer)
 }
 
+// Test several different inputs to MAIL FROM command
+func TestHeloEhlo(t *testing.T) {
+	if initErr != nil {
+		t.Error(initErr)
+		t.FailNow()
+	}
+	if startErrors := app.Start(); startErrors == nil {
+		conn, bufin, err := Connect(config.Servers[0], 20)
+		hostname := config.Servers[0].Hostname
+		if err != nil {
+			// handle error
+			t.Error(err.Error(), config.Servers[0].ListenInterface)
+			t.FailNow()
+		} else {
+			// Test HELO
+			response, err := Command(conn, bufin, "HELO localtester")
+			if err != nil {
+				t.Error("command failed", err.Error())
+			}
+			expected := fmt.Sprintf("250 %s Hello", hostname)
+			if strings.Index(response, expected) != 0 {
+				t.Error("Server did not respond with", expected, ", it said:"+response)
+			}
+			// Reset
+			response, err = Command(conn, bufin, "RSET")
+			if err != nil {
+				t.Error("command failed", err.Error())
+			}
+			expected = "250 2.1.0 OK"
+			if strings.Index(response, expected) != 0 {
+				t.Error("Server did not respond with", expected, ", it said:"+response)
+			}
+			// Test EHLO
+			// This is tricky as it is a multiline response
+			var fullresp string
+			response, err = Command(conn, bufin, "EHLO localtester")
+			fullresp = fullresp + response
+			if err != nil {
+				t.Error("command failed", err.Error())
+			}
+			for err == nil {
+				response, err = bufin.ReadString('\n')
+				fullresp = fullresp + response
+				if strings.HasPrefix(response, "250 ") { // Last response has a whitespace and no "-"
+					break // bail
+				}
+			}
+
+			expected = fmt.Sprintf("250-%s Hello\r\n250-SIZE 100017\r\n250-PIPELINING\r\n250-STARTTLS\r\n250-ENHANCEDSTATUSCODES\r\n250 HELP\r\n", hostname)
+			if fullresp != expected {
+				t.Error("Server did not respond with [" + expected + "], it said [" + fullresp + "]")
+			}
+			// be kind, QUIT. And we are sure that bufin does not contain fragments from the EHLO command.
+			response, err = Command(conn, bufin, "QUIT")
+			if err != nil {
+				t.Error("command failed", err.Error())
+			}
+			expected = "221 2.0.0 Bye"
+			if strings.Index(response, expected) != 0 {
+				t.Error("Server did not respond with", expected, ", it said:"+response)
+			}
+		}
+		conn.Close()
+		app.Shutdown()
+	} else {
+		if startErrors := app.Start(); startErrors != nil {
+			t.Error(startErrors)
+			app.Shutdown()
+			t.FailNow()
+		}
+	}
+	logOut.Flush()
+	// don't forget to reset
+	logBuffer.Reset()
+	logIn.Reset(&logBuffer)
+}
+
 // It should error when MAIL FROM was given twice
 func TestNestedMailCmd(t *testing.T) {
 	if initErr != nil {

+ 22 - 4
util.go

@@ -5,6 +5,8 @@ import (
 	"regexp"
 	"strings"
 
+	"fmt"
+
 	"github.com/flashmob/go-guerrilla/envelope"
 	"github.com/flashmob/go-guerrilla/response"
 )
@@ -15,7 +17,13 @@ func extractEmail(str string) (*envelope.EmailAddress, error) {
 	email := &envelope.EmailAddress{}
 	var err error
 	if len(str) > RFC2821LimitPath {
-		return email, errors.New(response.CustomString(response.InvalidCommandArguments, 550, response.ClassPermanentFailure, "Path too long"))
+		resp := &response.Response{
+			EnhancedCode: response.InvalidCommandArguments,
+			BasicCode:    550,
+			Class:        response.ClassPermanentFailure,
+			Comment:      "Path too long",
+		}
+		return email, errors.New(resp.String())
 	}
 	if matched := extractEmailRegex.FindStringSubmatch(str); len(matched) > 2 {
 		email.User = matched[1]
@@ -25,12 +33,22 @@ func extractEmail(str string) (*envelope.EmailAddress, error) {
 		email.Host = validHost(res[1])
 	}
 	err = nil
+	resp := &response.Response{
+		EnhancedCode: response.InvalidCommandArguments,
+		BasicCode:    501,
+		Class:        response.ClassPermanentFailure,
+	}
 	if email.User == "" || email.Host == "" {
-		err = errors.New(response.CustomString(response.InvalidCommandArguments, 501, response.ClassPermanentFailure, "Invalid address"))
+		resp.Comment = "Invalid address"
+		err = fmt.Errorf("%s", resp)
 	} else if len(email.User) > RFC2832LimitLocalPart {
-		err = errors.New(response.CustomString(response.InvalidCommandArguments, 550, response.ClassPermanentFailure, "Local part too long, cannot exceed 64 characters"))
+		resp.BasicCode = 550
+		resp.Comment = "Local part too long, cannot exceed 64 characters"
+		err = fmt.Errorf("%s", resp)
 	} else if len(email.Host) > RFC2821LimitDomain {
-		err = errors.New(response.CustomString(response.InvalidCommandArguments, 501, response.ClassPermanentFailure, "Domain cannot exceed 255 characters"))
+		resp.BasicCode = 550
+		resp.Comment = "Domain cannot exceed 255 characters"
+		err = fmt.Errorf("%s", resp)
 	}
 	return email, err
 }