瀏覽代碼

Merge branch 'master' into dashboard

flashmob 6 年之前
父節點
當前提交
e1c505be41

+ 1 - 1
README.md

@@ -267,7 +267,7 @@ Using Nginx as a proxy
 ======================
 
 For such purposes as load balancing, terminating TLS early,
- or supporting SSL versions not supported by Go (highly not recommenced if you
+ or supporting SSL versions not supported by Go (highly not recommended if you
  want to use older SSL versions), 
  it is possible to [use NGINX as a proxy](https://github.com/flashmob/go-guerrilla/wiki/Using-Nginx-as-a-proxy).
 

+ 2 - 4
api.go

@@ -29,8 +29,6 @@ type deferredSub struct {
 	fn    interface{}
 }
 
-const defaultInterface = "127.0.0.1:2525"
-
 // AddProcessor adds a processor constructor to the backend.
 // name is the identifier to be used in the config. See backends docs for more info.
 func (d *Daemon) AddProcessor(name string, pc backends.ProcessorConstructor) {
@@ -64,7 +62,7 @@ func (d *Daemon) Start() (err error) {
 			return err
 		}
 		for i := range d.subs {
-			d.Subscribe(d.subs[i].topic, d.subs[i].fn)
+			_ = d.Subscribe(d.subs[i].topic, d.subs[i].fn)
 
 		}
 		d.subs = make([]deferredSub, 0)
@@ -164,10 +162,10 @@ func (d *Daemon) ReopenLogs() error {
 // Subscribe for subscribing to config change events
 func (d *Daemon) Subscribe(topic Event, fn interface{}) error {
 	if d.g == nil {
+		// defer the subscription until the daemon is started
 		d.subs = append(d.subs, deferredSub{topic, fn})
 		return nil
 	}
-
 	return d.g.Subscribe(topic, fn)
 }
 

+ 187 - 34
api_test.go

@@ -2,10 +2,12 @@ package guerrilla
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"github.com/flashmob/go-guerrilla/backends"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/mail"
+	"github.com/flashmob/go-guerrilla/response"
 	"io/ioutil"
 	"net"
 	"os"
@@ -210,7 +212,9 @@ func TestSMTPLoadFile(t *testing.T) {
 			return
 		}
 
-		d.ReloadConfigFile("goguerrilla.conf.api")
+		if err = d.ReloadConfigFile("goguerrilla.conf.api"); err != nil {
+			t.Error(err)
+		}
 
 		if d.Config.LogFile != "./tests/testlog2" {
 			t.Error("d.Config.LogFile != \"./tests/testlog\"")
@@ -225,7 +229,9 @@ func TestSMTPLoadFile(t *testing.T) {
 }
 
 func TestReopenLog(t *testing.T) {
-	os.Truncate("test/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := &AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -238,7 +244,9 @@ func TestReopenLog(t *testing.T) {
 	if err != nil {
 		t.Error("start error", err)
 	} else {
-		d.ReopenLogs()
+		if err = d.ReopenLogs(); err != nil {
+			t.Error(err)
+		}
 		time.Sleep(time.Second * 2)
 
 		d.Shutdown()
@@ -259,7 +267,9 @@ func TestReopenLog(t *testing.T) {
 
 func TestSetConfig(t *testing.T) {
 
-	os.Truncate("test/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -303,7 +313,9 @@ func TestSetConfig(t *testing.T) {
 
 func TestSetConfigError(t *testing.T) {
 
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -349,7 +361,7 @@ var funkyLogger = func() backends.Decorator {
 		return backends.ProcessWith(
 			func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
 				if task == backends.TaskValidateRcpt {
-					// validate the last recipient appended to e.Rcpt
+					// log the last recipient appended to e.Rcpt
 					backends.Log().Infof(
 						"another funky recipient [%s]",
 						e.RcptTo[len(e.RcptTo)-1])
@@ -367,7 +379,9 @@ var funkyLogger = func() backends.Decorator {
 
 // How about a custom processor?
 func TestSetAddProcessor(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := &AppConfig{
 		LogFile:      "tests/testlog",
 		AllowedHosts: []string{"grr.la"},
@@ -379,9 +393,13 @@ func TestSetAddProcessor(t *testing.T) {
 	d := Daemon{Config: cfg}
 	d.AddProcessor("FunkyLogger", funkyLogger)
 
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	// lets have a talk with the server
-	talkToServer("127.0.0.1:2525")
+	if err := talkToServer("127.0.0.1:2525"); err != nil {
+		t.Error(err)
+	}
 
 	d.Shutdown()
 
@@ -408,38 +426,87 @@ func TestSetAddProcessor(t *testing.T) {
 
 }
 
-func talkToServer(address string) {
+func talkToServer(address string) (err error) {
 
 	conn, err := net.Dial("tcp", address)
 	if err != nil {
-
 		return
 	}
 	in := bufio.NewReader(conn)
 	str, err := in.ReadString('\n')
-	fmt.Fprint(conn, "HELO maildiranasaurustester\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "HELO maildiranasaurustester\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "RCPT TO:[email protected]\r\n")
+	if err != nil {
+		return err
+	}
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "DATA\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "DATA\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "Subject: Test subject\r\n")
-	fmt.Fprint(conn, "\r\n")
-	fmt.Fprint(conn, "A an email body\r\n")
-	fmt.Fprint(conn, ".\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "Subject: Test subject\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "A an email body\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, ".\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
+	if err != nil {
+		return err
+	}
 	_ = str
+	return nil
 }
 
 // Test hot config reload
 // Here we forgot to add FunkyLogger so backend will fail to init
 
 func TestReloadConfig(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	d := Daemon{}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	defer d.Shutdown()
 	cfg := AppConfig{
 		LogFile:      "tests/testlog",
@@ -450,16 +517,22 @@ func TestReloadConfig(t *testing.T) {
 		},
 	}
 	// Look mom, reloading the config without shutting down!
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
 }
 
 func TestPubSubAPI(t *testing.T) {
 
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 
 	d := Daemon{Config: &AppConfig{LogFile: "tests/testlog"}}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	defer d.Shutdown()
 	// new config
 	cfg := AppConfig{
@@ -480,14 +553,22 @@ func TestPubSubAPI(t *testing.T) {
 		}
 		d.Logger.Info("number", i)
 	}
-	d.Subscribe(EventConfigPidFile, pidEvHandler)
+	if err := d.Subscribe(EventConfigPidFile, pidEvHandler); err != nil {
+		t.Error(err)
+	}
 
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
-	d.Unsubscribe(EventConfigPidFile, pidEvHandler)
+	if err := d.Unsubscribe(EventConfigPidFile, pidEvHandler); err != nil {
+		t.Error(err)
+	}
 	cfg.PidFile = "tests/pidfile2.pid"
 	d.Publish(EventConfigPidFile, &cfg)
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
 	b, err := ioutil.ReadFile("tests/testlog")
 	if err != nil {
@@ -502,7 +583,9 @@ func TestPubSubAPI(t *testing.T) {
 }
 
 func TestAPILog(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	d := Daemon{}
 	l := d.Log()
 	l.Info("logtest1") // to stderr
@@ -538,7 +621,9 @@ func TestSkipAllowsHost(t *testing.T) {
 	defer d.Shutdown()
 	// setting the allowed hosts to a single entry with a dot will let any host through
 	d.Config = &AppConfig{AllowedHosts: []string{"."}, LogFile: "off"}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 
 	conn, err := net.Dial("tcp", d.Config.Servers[0].ListenInterface)
 	if err != nil {
@@ -546,13 +631,81 @@ func TestSkipAllowsHost(t *testing.T) {
 		return
 	}
 	in := bufio.NewReader(conn)
-	fmt.Fprint(conn, "HELO test\r\n")
-	fmt.Fprint(conn, "RCPT TO: [email protected]\r\n")
-	in.ReadString('\n')
-	in.ReadString('\n')
+	if _, err := fmt.Fprint(conn, "HELO test\r\n"); err != nil {
+		t.Error(err)
+	}
+	if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+		t.Error(err)
+	}
+
+	if _, err := in.ReadString('\n'); err != nil {
+		t.Error(err)
+	}
+	if _, err := in.ReadString('\n'); err != nil {
+		t.Error(err)
+	}
 	str, _ := in.ReadString('\n')
 	if strings.Index(str, "250") != 0 {
 		t.Error("expected 250 reply, got:", str)
 	}
 
 }
+
+var customBackend2 = func() backends.Decorator {
+
+	return func(p backends.Processor) backends.Processor {
+		return backends.ProcessWith(
+			func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
+				if task == backends.TaskValidateRcpt {
+					return p.Process(e, task)
+				} else if task == backends.TaskSaveMail {
+					backends.Log().Info("Another funky email!")
+					err := errors.New("system shock")
+					return backends.NewResult(response.Canned.FailReadErrorDataCmd, response.SP, err), err
+				}
+				return p.Process(e, task)
+			})
+	}
+}
+
+// Test a custom backend response
+func TestCustomBackendResult(t *testing.T) {
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
+	cfg := &AppConfig{
+		LogFile:      "tests/testlog",
+		AllowedHosts: []string{"grr.la"},
+		BackendConfig: backends.BackendConfig{
+			"save_process":     "HeadersParser|Debugger|Custom",
+			"validate_process": "Custom",
+		},
+	}
+	d := Daemon{Config: cfg}
+	d.AddProcessor("Custom", customBackend2)
+
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
+	// lets have a talk with the server
+	if err := talkToServer("127.0.0.1:2525"); err != nil {
+		t.Error(err)
+	}
+
+	d.Shutdown()
+
+	b, err := ioutil.ReadFile("tests/testlog")
+	if err != nil {
+		t.Error("could not read logfile")
+		return
+	}
+	// lets check for fingerprints
+	if strings.Index(string(b), "451 4.3.0 Error") < 0 {
+		t.Error("did not log: 451 4.3.0 Error")
+	}
+
+	if strings.Index(string(b), "system shock") < 0 {
+		t.Error("did not log: system shock")
+	}
+
+}

+ 35 - 19
backends/backend.go

@@ -1,6 +1,7 @@
 package backends
 
 import (
+	"bytes"
 	"fmt"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/mail"
@@ -54,6 +55,7 @@ type BaseConfig interface{}
 type notifyMsg struct {
 	err      error
 	queuedID string
+	result   Result
 }
 
 // Result represents a response to an SMTP client after receiving DATA.
@@ -66,16 +68,19 @@ type Result interface {
 }
 
 // Internal implementation of BackendResult for use by backend implementations.
-type result string
+type result struct {
+	// we're going to use a bytes.Buffer for building a string
+	bytes.Buffer
+}
 
-func (br result) String() string {
-	return string(br)
+func (r *result) String() string {
+	return r.Buffer.String()
 }
 
 // Parses the SMTP code from the first 3 characters of the SMTP message.
 // Returns 554 if code cannot be parsed.
-func (br result) Code() int {
-	trimmed := strings.TrimSpace(string(br))
+func (r *result) Code() int {
+	trimmed := strings.TrimSpace(r.String())
 	if len(trimmed) < 3 {
 		return 554
 	}
@@ -86,8 +91,19 @@ func (br result) Code() int {
 	return code
 }
 
-func NewResult(message string) Result {
-	return result(message)
+func NewResult(r ...interface{}) Result {
+	buf := new(result)
+	for _, item := range r {
+		switch v := item.(type) {
+		case error:
+			_, _ = buf.WriteString(v.Error())
+		case fmt.Stringer:
+			_, _ = buf.WriteString(v.String())
+		case string:
+			_, _ = buf.WriteString(v)
+		}
+	}
+	return buf
 }
 
 type processorInitializer interface {
@@ -240,13 +256,13 @@ func (s *service) ExtractConfig(configData BackendConfig, configType BaseConfig)
 	for i := 0; i < v.NumField(); i++ {
 		f := v.Field(i)
 		// read the tags of the config struct
-		field_name := t.Field(i).Tag.Get("json")
+		fieldName := t.Field(i).Tag.Get("json")
 		omitempty := false
-		if len(field_name) > 0 {
+		if len(fieldName) > 0 {
 			// parse the tag to
 			// get the field name from struct tag
-			split := strings.Split(field_name, ",")
-			field_name = split[0]
+			split := strings.Split(fieldName, ",")
+			fieldName = split[0]
 			if len(split) > 1 {
 				if split[1] == "omitempty" {
 					omitempty = true
@@ -255,30 +271,30 @@ func (s *service) ExtractConfig(configData BackendConfig, configType BaseConfig)
 		} else {
 			// could have no tag
 			// so use the reflected field name
-			field_name = typeOfT.Field(i).Name
+			fieldName = typeOfT.Field(i).Name
 		}
 		if f.Type().Name() == "int" {
 			// in json, there is no int, only floats...
-			if intVal, converted := configData[field_name].(float64); converted {
+			if intVal, converted := configData[fieldName].(float64); converted {
 				v.Field(i).SetInt(int64(intVal))
-			} else if intVal, converted := configData[field_name].(int); converted {
+			} else if intVal, converted := configData[fieldName].(int); converted {
 				v.Field(i).SetInt(int64(intVal))
 			} else if !omitempty {
-				return configType, convertError("property missing/invalid: '" + field_name + "' of expected type: " + f.Type().Name())
+				return configType, convertError("property missing/invalid: '" + fieldName + "' of expected type: " + f.Type().Name())
 			}
 		}
 		if f.Type().Name() == "string" {
-			if stringVal, converted := configData[field_name].(string); converted {
+			if stringVal, converted := configData[fieldName].(string); converted {
 				v.Field(i).SetString(stringVal)
 			} else if !omitempty {
-				return configType, convertError("missing/invalid: '" + field_name + "' of type: " + f.Type().Name())
+				return configType, convertError("missing/invalid: '" + fieldName + "' of type: " + f.Type().Name())
 			}
 		}
 		if f.Type().Name() == "bool" {
-			if boolVal, converted := configData[field_name].(bool); converted {
+			if boolVal, converted := configData[fieldName].(bool); converted {
 				v.Field(i).SetBool(boolVal)
 			} else if !omitempty {
-				return configType, convertError("missing/invalid: '" + field_name + "' of type: " + f.Type().Name())
+				return configType, convertError("missing/invalid: '" + fieldName + "' of type: " + f.Type().Name())
 			}
 		}
 	}

+ 30 - 24
backends/gateway.go

@@ -129,7 +129,7 @@ func (w *workerMsg) reset(e *mail.Envelope, task SelectTask) {
 // Process distributes an envelope to one of the backend workers with a TaskSaveMail task
 func (gw *BackendGateway) Process(e *mail.Envelope) Result {
 	if gw.State != BackendStateRunning {
-		return NewResult(response.Canned.FailBackendNotRunning + gw.State.String())
+		return NewResult(response.Canned.FailBackendNotRunning, response.SP, gw.State)
 	}
 	// borrow a workerMsg from the pool
 	workerMsg := workerMsgPool.Get().(*workerMsg)
@@ -140,11 +140,32 @@ func (gw *BackendGateway) Process(e *mail.Envelope) Result {
 	// or timeout
 	select {
 	case status := <-workerMsg.notifyMe:
-		workerMsgPool.Put(workerMsg) // can be recycled since we used the notifyMe channel
+		// email saving transaction completed
+		if status.result == BackendResultOK && status.queuedID != "" {
+			return NewResult(response.Canned.SuccessMessageQueued, response.SP, status.queuedID)
+		}
+
+		// A custom result, there was probably an error, if so, log it
+		if status.result != nil {
+			if status.err != nil {
+				Log().Error(status.err)
+			}
+			return status.result
+		}
+
+		// if there was no result, but there's an error, then make a new result from the error
 		if status.err != nil {
-			return NewResult(response.Canned.FailBackendTransaction + status.err.Error())
+			if _, err := strconv.Atoi(status.err.Error()[:3]); err != nil {
+				return NewResult(response.Canned.FailBackendTransaction, response.SP, status.err)
+			}
+			return NewResult(status.err)
 		}
-		return NewResult(response.Canned.SuccessMessageQueued + status.queuedID)
+
+		// both result & error are nil (should not happen)
+		err := errors.New("no response from backend - processor did not return a result or an error")
+		Log().Error(err)
+		return NewResult(response.Canned.FailBackendTransaction, response.SP, err)
+
 	case <-time.After(gw.saveTimeout()):
 		Log().Error("Backend has timed out while saving email")
 		e.Lock() // lock the envelope - it's still processing here, we don't want the server to recycle it
@@ -435,27 +456,12 @@ func (gw *BackendGateway) workDispatcher(
 			return
 		case msg = <-workIn:
 			state = dispatcherStateWorking // recovers from panic if in this state
+			result, err := save.Process(msg.e, msg.task)
+			state = dispatcherStateNotify
 			if msg.task == TaskSaveMail {
-				// process the email here
-				result, _ := save.Process(msg.e, TaskSaveMail)
-				state = dispatcherStateNotify
-				if result.Code() < 300 {
-					// if all good, let the gateway know that it was saved
-					msg.notifyMe <- &notifyMsg{nil, msg.e.QueuedId}
-				} else {
-					// notify the gateway about the error
-					msg.notifyMe <- &notifyMsg{err: errors.New(result.String())}
-				}
-			} else if msg.task == TaskValidateRcpt {
-				_, err := validate.Process(msg.e, TaskValidateRcpt)
-				state = dispatcherStateNotify
-				if err != nil {
-					// validation failed
-					msg.notifyMe <- &notifyMsg{err: err}
-				} else {
-					// all good.
-					msg.notifyMe <- &notifyMsg{err: nil}
-				}
+				msg.notifyMe <- &notifyMsg{err: err, result: result, queuedID: msg.e.QueuedId}
+			} else {
+				msg.notifyMe <- &notifyMsg{err: err, result: result}
 			}
 		}
 		state = dispatcherStateIdle

+ 3 - 3
backends/p_compressor.go

@@ -77,9 +77,9 @@ func (c *compressor) String() string {
 	var r *bytes.Reader
 	w, _ := zlib.NewWriterLevel(b, zlib.BestSpeed)
 	r = bytes.NewReader(c.extraHeaders)
-	io.Copy(w, r)
-	io.Copy(w, c.data)
-	w.Close()
+	_, _ = io.Copy(w, r)
+	_, _ = io.Copy(w, c.data)
+	_ = w.Close()
 	return b.String()
 }
 

+ 44 - 20
backends/p_guerrilla_db_redis.go

@@ -13,7 +13,6 @@ import (
 	"time"
 
 	"github.com/flashmob/go-guerrilla/mail"
-	"github.com/garyburd/redigo/redis"
 )
 
 // ----------------------------------------------------------------------------------
@@ -85,7 +84,7 @@ func (g *GuerrillaDBAndRedisBackend) getNumberOfWorkers() int {
 
 type redisClient struct {
 	isConnected bool
-	conn        redis.Conn
+	conn        RedisConn
 	time        int
 }
 
@@ -131,9 +130,9 @@ func (c *compressedData) String() string {
 	var r *bytes.Reader
 	w, _ := zlib.NewWriterLevel(b, zlib.BestSpeed)
 	r = bytes.NewReader(c.extraHeaders)
-	io.Copy(w, r)
-	io.Copy(w, c.data)
-	w.Close()
+	_, _ = io.Copy(w, r)
+	_, _ = io.Copy(w, c.data)
+	_ = w.Close()
 	return b.String()
 }
 
@@ -151,9 +150,25 @@ func (g *GuerrillaDBAndRedisBackend) prepareInsertQuery(rows int, db *sql.DB) *s
 	if g.cache[rows-1] != nil {
 		return g.cache[rows-1]
 	}
-	sqlstr := "INSERT INTO " + g.config.Table + " "
-	sqlstr += "(`date`, `to`, `from`, `subject`, `body`, `charset`, `mail`, `spam_score`, `hash`, `content_type`, `recipient`, `has_attach`, `ip_addr`, `return_path`, `is_tls`)"
-	sqlstr += " values "
+	sqlstr := "INSERT INTO " + g.config.Table + "" +
+		"(" +
+		"`date`, " +
+		"`to`, " +
+		"`from`, " +
+		"`subject`, " +
+		"`body`, " +
+		"`charset`, " +
+		"`mail`, " +
+		"`spam_score`, " +
+		"`hash`, " +
+		"`content_type`, " +
+		"`recipient`, " +
+		"`has_attach`, " +
+		"`ip_addr`, " +
+		"`return_path`, " +
+		"`is_tls`" +
+		")" +
+		" values "
 	values := "(NOW(), ?, ?, ?, ? , 'UTF-8' , ?, 0, ?, '', ?, 0, ?, ?, ?)"
 	// add more rows
 	comma := ""
@@ -328,7 +343,7 @@ func (g *GuerrillaDBAndRedisBackend) sqlConnect() (*sql.DB, error) {
 
 func (c *redisClient) redisConnection(redisInterface string) (err error) {
 	if c.isConnected == false {
-		c.conn, err = redis.Dial("tcp", redisInterface)
+		c.conn, err = RedisDialer("tcp", redisInterface)
 		if err != nil {
 			// handle error
 			return err
@@ -347,12 +362,12 @@ func GuerrillaDbRedis() Decorator {
 	g := GuerrillaDBAndRedisBackend{}
 	redisClient := &redisClient{}
 
-	var db *sql.DB
-	var to, body string
-
-	var redisErr error
-
-	var feeders []feedChan
+	var (
+		db       *sql.DB
+		to, body string
+		redisErr error
+		feeders  []feedChan
+	)
 
 	g.batcherStoppers = make([]chan bool, 0)
 
@@ -388,11 +403,17 @@ func GuerrillaDbRedis() Decorator {
 	}))
 
 	Svc.AddShutdowner(ShutdownWith(func() error {
-		db.Close()
-		Log().Infof("closed sql")
+		if err := db.Close(); err != nil {
+			Log().WithError(err).Error("close mysql failed")
+		} else {
+			Log().Infof("closed mysql")
+		}
 		if redisClient.conn != nil {
-			Log().Infof("closed redis")
-			redisClient.conn.Close()
+			if err := redisClient.conn.Close(); err != nil {
+				Log().WithError(err).Error("close redis failed")
+			} else {
+				Log().Infof("closed redis")
+			}
 		}
 		// send a close signal to all query batchers to exit.
 		for i := range g.batcherStoppers {
@@ -414,12 +435,15 @@ func GuerrillaDbRedis() Decorator {
 				e.Helo = trimToLimit(e.Helo, 255)
 				e.RcptTo[0].Host = trimToLimit(e.RcptTo[0].Host, 255)
 				ts := fmt.Sprintf("%d", time.Now().UnixNano())
-				e.ParseHeaders()
+				if err := e.ParseHeaders(); err != nil {
+					Log().WithError(err).Error("failed to parse headers")
+				}
 				hash := MD5Hex(
 					to,
 					e.MailFrom.String(),
 					e.Subject,
 					ts)
+				e.QueuedId = hash
 				// Add extra headers
 				var addHead string
 				addHead += "Delivered-To: " + to + "\r\n"

+ 6 - 2
backends/p_guerrilla_db_redis_test.go

@@ -19,12 +19,16 @@ func TestCompressedData(t *testing.T) {
 	cd.set([]byte(sbj), &b)
 
 	// compress
-	fmt.Fprint(&out, cd)
+	if _, err := fmt.Fprint(&out, cd); err != nil {
+		t.Error(err)
+	}
 
 	// decompress
 	var result bytes.Buffer
 	zReader, _ := zlib.NewReader(bytes.NewReader(out.Bytes()))
-	io.Copy(&result, zReader)
+	if _, err := io.Copy(&result, zReader); err != nil {
+		t.Error(err)
+	}
 	expect := sbj + str
 	if delta := strings.Compare(expect, result.String()); delta != 0 {
 		t.Error(delta, "compression did match, expected", expect, "but got", result.String())

+ 4 - 4
backends/p_hasher.go

@@ -38,13 +38,13 @@ func Hasher() Decorator {
 				// base hash, use subject from and timestamp-nano
 				h := md5.New()
 				ts := fmt.Sprintf("%d", time.Now().UnixNano())
-				io.Copy(h, strings.NewReader(e.MailFrom.String()))
-				io.Copy(h, strings.NewReader(e.Subject))
-				io.Copy(h, strings.NewReader(ts))
+				_, _ = io.Copy(h, strings.NewReader(e.MailFrom.String()))
+				_, _ = io.Copy(h, strings.NewReader(e.Subject))
+				_, _ = io.Copy(h, strings.NewReader(ts))
 				// using the base hash, calculate a unique hash for each recipient
 				for i := range e.RcptTo {
 					h2 := h
-					io.Copy(h2, strings.NewReader(e.RcptTo[i].String()))
+					_, _ = io.Copy(h2, strings.NewReader(e.RcptTo[i].String()))
 					sum := h2.Sum([]byte{})
 					e.Hashes = append(e.Hashes, fmt.Sprintf("%x", sum))
 				}

+ 3 - 1
backends/p_headers_parser.go

@@ -25,7 +25,9 @@ func HeadersParser() Decorator {
 	return func(p Processor) Processor {
 		return ProcessWith(func(e *mail.Envelope, task SelectTask) (Result, error) {
 			if task == TaskSaveMail {
-				e.ParseHeaders()
+				if err := e.ParseHeaders(); err != nil {
+					Log().WithError(err).Error("parse headers error")
+				}
 				// next processor
 				return p.Process(e, task)
 			} else {

+ 4 - 6
backends/p_redis.go

@@ -5,8 +5,6 @@ import (
 
 	"github.com/flashmob/go-guerrilla/mail"
 	"github.com/flashmob/go-guerrilla/response"
-
-	"github.com/garyburd/redigo/redis"
 )
 
 // ----------------------------------------------------------------------------------
@@ -39,12 +37,12 @@ type RedisProcessorConfig struct {
 
 type RedisProcessor struct {
 	isConnected bool
-	conn        redis.Conn
+	conn        RedisConn
 }
 
 func (r *RedisProcessor) redisConnection(redisInterface string) (err error) {
 	if r.isConnected == false {
-		r.conn, err = redis.Dial("tcp", redisInterface)
+		r.conn, err = RedisDialer("tcp", redisInterface)
 		if err != nil {
 			// handle error
 			return err
@@ -69,7 +67,7 @@ func Redis() Decorator {
 		}
 		config = bcfg.(*RedisProcessorConfig)
 		if redisErr := redisClient.redisConnection(config.RedisInterface); redisErr != nil {
-			err := fmt.Errorf("Redis cannot connect, check your settings: %s", redisErr)
+			err := fmt.Errorf("redis cannot connect, check your settings: %s", redisErr)
 			return err
 		}
 		return nil
@@ -113,7 +111,7 @@ func Redis() Decorator {
 					}
 					e.Values["redis"] = "redis" // the next processor will know to look in redis for the message data
 				} else {
-					Log().Error("Redis needs a Hash() process before it")
+					Log().Error("Redis needs a Hasher() process before it")
 					result := NewResult(response.Canned.FailBackendTransaction)
 					return result, StorageError
 				}

+ 62 - 0
backends/p_redis_test.go

@@ -0,0 +1,62 @@
+package backends
+
+import (
+	"github.com/flashmob/go-guerrilla/log"
+	"github.com/flashmob/go-guerrilla/mail"
+	"io/ioutil"
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestRedisGeneric(t *testing.T) {
+
+	e := mail.NewEnvelope("127.0.0.1", 1)
+	e.RcptTo = append(e.RcptTo, mail.Address{User: "test", Host: "grr.la"})
+
+	l, _ := log.GetLogger("./test_redis.log", "debug")
+	g, err := New(BackendConfig{
+		"save_process":         "Hasher|Redis",
+		"redis_interface":      "127.0.0.1:6379",
+		"redis_expire_seconds": 7200,
+	}, l)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	err = g.Start()
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	defer func() {
+		err := g.Shutdown()
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+	if gateway, ok := g.(*BackendGateway); ok {
+		r := gateway.Process(e)
+		if strings.Index(r.String(), "250 2.0.0 OK") == -1 {
+			t.Error("redis processor didn't result with expected result, it said", r)
+		}
+	}
+	// check the log
+	if _, err := os.Stat("./test_redis.log"); err != nil {
+		t.Error(err)
+		return
+	}
+	if b, err := ioutil.ReadFile("./test_redis.log"); err != nil {
+		t.Error(err)
+		return
+	} else {
+		if strings.Index(string(b), "SETEX") == -1 {
+			t.Error("Log did not contain SETEX, the log was: ", string(b))
+		}
+	}
+
+	if err := os.Remove("./test_redis.log"); err != nil {
+		t.Error(err)
+	}
+
+}

+ 0 - 1
backends/p_sql.go

@@ -181,7 +181,6 @@ func SQL() Decorator {
 
 				hash := ""
 				if len(e.Hashes) > 0 {
-					// if saved in redis, hash will be the redis key
 					hash = e.Hashes[0]
 					e.QueuedId = e.Hashes[0]
 				}

+ 3 - 1
backends/p_sql_test.go

@@ -73,7 +73,9 @@ func findRows(hash string) ([]string, error) {
 	if err != nil {
 		return nil, err
 	}
-	defer db.Close()
+	defer func() {
+		_ = db.Close()
+	}()
 
 	stmt := fmt.Sprintf(`SELECT hash FROM %s WHERE hash = ?`, *mailTableFlag)
 	rows, err := db.Query(stmt, hash)

+ 45 - 0
backends/redis_generic.go

@@ -0,0 +1,45 @@
+package backends
+
+import (
+	"net"
+	"time"
+)
+
+func init() {
+	RedisDialer = func(network, address string, options ...RedisDialOption) (RedisConn, error) {
+		return new(RedisMockConn), nil
+	}
+}
+
+// RedisConn interface provides a generic way to access Redis via drivers
+type RedisConn interface {
+	Close() error
+	Do(commandName string, args ...interface{}) (reply interface{}, err error)
+}
+
+type RedisMockConn struct{}
+
+func (m *RedisMockConn) Close() error {
+	return nil
+}
+
+func (m *RedisMockConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
+	Log().Info("redis mock driver command: ", commandName)
+	return nil, nil
+}
+
+type dialOptions struct {
+	readTimeout  time.Duration
+	writeTimeout time.Duration
+	dial         func(network, addr string) (net.Conn, error)
+	db           int
+	password     string
+}
+
+type RedisDialOption struct {
+	f func(*dialOptions)
+}
+
+type redisDial func(network, address string, options ...RedisDialOption) (RedisConn, error)
+
+var RedisDialer redisDial

+ 10 - 0
backends/storage/redigo/driver.go

@@ -0,0 +1,10 @@
+package redigo_driver
+
+import "github.com/flashmob/go-guerrilla/backends"
+import redigo "github.com/gomodule/redigo/redis"
+
+func init() {
+	backends.RedisDialer = func(network, address string, options ...backends.RedisDialOption) (backends.RedisConn, error) {
+		return redigo.Dial(network, address)
+	}
+}

+ 3 - 3
backends/util.go

@@ -41,7 +41,7 @@ func MD5Hex(stringArguments ...string) string {
 	var r *strings.Reader
 	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(stringArguments[i])
-		io.Copy(h, r)
+		_, _ = io.Copy(h, r)
 	}
 	sum := h.Sum([]byte{})
 	return fmt.Sprintf("%x", sum)
@@ -54,8 +54,8 @@ func Compress(stringArguments ...string) string {
 	w, _ := zlib.NewWriterLevel(&b, zlib.BestSpeed)
 	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(stringArguments[i])
-		io.Copy(w, r)
+		_, _ = io.Copy(w, r)
 	}
-	w.Close()
+	_ = w.Close()
 	return b.String()
 }

+ 75 - 39
client.go

@@ -4,14 +4,16 @@ import (
 	"bufio"
 	"bytes"
 	"crypto/tls"
+	"errors"
 	"fmt"
+	"github.com/flashmob/go-guerrilla/log"
+	"github.com/flashmob/go-guerrilla/mail"
+	"github.com/flashmob/go-guerrilla/mail/rfc5321"
+	"github.com/flashmob/go-guerrilla/response"
 	"net"
 	"net/textproto"
 	"sync"
 	"time"
-
-	"github.com/flashmob/go-guerrilla/log"
-	"github.com/flashmob/go-guerrilla/mail"
 )
 
 // ClientState indicates which part of the SMTP transaction a given client is in.
@@ -39,8 +41,9 @@ type client struct {
 	errors       int
 	state        ClientState
 	messagesSent int
-	// Response to be written to the client
+	// Response to be written to the client (for debugging)
 	response   bytes.Buffer
+	bufErr     error
 	conn       net.Conn
 	bufin      *smtpBufferedReader
 	bufout     *bufio.Writer
@@ -49,6 +52,7 @@ type client struct {
 	// guards access to conn
 	connGuard sync.Mutex
 	log       log.Logger
+	parser    rfc5321.Parser
 }
 
 // NewClient allocates a new client.
@@ -70,39 +74,38 @@ func NewClient(conn net.Conn, clientID uint64, logger log.Logger, envelope *mail
 	return c
 }
 
-// setResponse adds a response to be written on the next turn
+// sendResponse adds a response to be written on the next turn
+// the response gets buffered
 func (c *client) sendResponse(r ...interface{}) {
 	c.bufout.Reset(c.conn)
 	if c.log.IsDebug() {
-		// us additional buffer so that we can log the response in debug mode only
+		// an additional buffer so that we can log the response in debug mode only
 		c.response.Reset()
 	}
+	var out string
+	if c.bufErr != nil {
+		c.bufErr = nil
+	}
 	for _, item := range r {
 		switch v := item.(type) {
-		case string:
-			if _, err := c.bufout.WriteString(v); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v)
-			}
 		case error:
-			if _, err := c.bufout.WriteString(v.Error()); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v.Error())
-			}
+			out = v.Error()
 		case fmt.Stringer:
-			if _, err := c.bufout.WriteString(v.String()); err != nil {
-				c.log.WithError(err).Error("could not write to c.bufout")
-			}
-			if c.log.IsDebug() {
-				c.response.WriteString(v.String())
-			}
+			out = v.String()
+		case string:
+			out = v
+		}
+		if _, c.bufErr = c.bufout.WriteString(out); c.bufErr != nil {
+			c.log.WithError(c.bufErr).Error("could not write to c.bufout")
+		}
+		if c.log.IsDebug() {
+			c.response.WriteString(out)
+		}
+		if c.bufErr != nil {
+			return
 		}
 	}
-	c.bufout.WriteString("\r\n")
+	_, c.bufErr = c.bufout.WriteString("\r\n")
 	if c.log.IsDebug() {
 		c.response.WriteString("\r\n")
 	}
@@ -112,7 +115,7 @@ func (c *client) sendResponse(r ...interface{}) {
 // Transaction ends on:
 // -HELO/EHLO/REST command
 // -End of DATA command
-// TLS handhsake
+// TLS handshake
 func (c *client) resetTransaction() {
 	c.Envelope.ResetTransaction()
 }
@@ -121,8 +124,7 @@ func (c *client) resetTransaction() {
 // A transaction starts after a MAIL command gets issued by the client.
 // Call resetTransaction to end the transaction
 func (c *client) isInTransaction() bool {
-	isMailFromEmpty := c.MailFrom == (mail.Address{})
-	if isMailFromEmpty {
+	if len(c.MailFrom.User) == 0 && !c.MailFrom.NullPath {
 		return false
 	}
 	return true
@@ -139,19 +141,20 @@ func (c *client) isAlive() bool {
 }
 
 // setTimeout adjust the timeout on the connection, goroutine safe
-func (c *client) setTimeout(t time.Duration) {
+func (c *client) setTimeout(t time.Duration) (err error) {
 	defer c.connGuard.Unlock()
 	c.connGuard.Lock()
 	if c.conn != nil {
-		c.conn.SetDeadline(time.Now().Add(t * time.Second))
+		err = c.conn.SetDeadline(time.Now().Add(t * time.Second))
 	}
+	return
 }
 
 // closeConn closes a client connection, , goroutine safe
 func (c *client) closeConn() {
 	defer c.connGuard.Unlock()
 	c.connGuard.Lock()
-	c.conn.Close()
+	_ = c.conn.Close()
 	c.conn = nil
 }
 
@@ -177,20 +180,20 @@ func (c *client) getID() uint64 {
 }
 
 // UpgradeToTLS upgrades a client connection to TLS
-func (client *client) upgradeToTLS(tlsConfig *tls.Config) error {
+func (c *client) upgradeToTLS(tlsConfig *tls.Config) error {
 	var tlsConn *tls.Conn
-	// wrap client.conn in a new TLS server side connection
-	tlsConn = tls.Server(client.conn, tlsConfig)
+	// wrap c.conn in a new TLS server side connection
+	tlsConn = tls.Server(c.conn, tlsConfig)
 	// Call handshake here to get any handshake error before reading starts
 	err := tlsConn.Handshake()
 	if err != nil {
 		return err
 	}
 	// convert tlsConn to net.Conn
-	client.conn = net.Conn(tlsConn)
-	client.bufout.Reset(client.conn)
-	client.bufin.Reset(client.conn)
-	client.TLS = true
+	c.conn = net.Conn(tlsConn)
+	c.bufout.Reset(c.conn)
+	c.bufin.Reset(c.conn)
+	c.TLS = true
 	return err
 }
 
@@ -202,3 +205,36 @@ func getRemoteAddr(conn net.Conn) string {
 		return conn.RemoteAddr().Network()
 	}
 }
+
+type pathParser func([]byte) error
+
+func (c *client) parsePath(in []byte, p pathParser) (mail.Address, error) {
+	address := mail.Address{}
+	var err error
+	if len(in) > rfc5321.LimitPath {
+		return address, errors.New(response.Canned.FailPathTooLong.String())
+	}
+	if err = p(in); err != nil {
+		return address, errors.New(response.Canned.FailInvalidAddress.String())
+	} else if c.parser.NullPath {
+		// bounce has empty from address
+		address = mail.Address{}
+	} else if len(c.parser.LocalPart) > rfc5321.LimitLocalPart {
+		err = errors.New(response.Canned.FailLocalPartTooLong.String())
+	} else if len(c.parser.Domain) > rfc5321.LimitDomain {
+		err = errors.New(response.Canned.FailDomainTooLong.String())
+	} else {
+		address = mail.Address{
+			User:       c.parser.LocalPart,
+			Host:       c.parser.Domain,
+			ADL:        c.parser.ADL,
+			PathParams: c.parser.PathParams,
+			NullPath:   c.parser.NullPath,
+		}
+	}
+	return address, err
+}
+
+func (s *server) rcptTo(in []byte) (address mail.Address, err error) {
+	return address, err
+}

+ 18 - 33
cmd/guerrillad/serve.go

@@ -3,17 +3,20 @@ package main
 import (
 	"fmt"
 	"os"
-	"os/exec"
 	"os/signal"
-	"strconv"
-	"strings"
 	"syscall"
 	"time"
 
 	"github.com/flashmob/go-guerrilla"
 	"github.com/flashmob/go-guerrilla/log"
+
+	// enable the Redis redigo driver
+	_ "github.com/flashmob/go-guerrilla/backends/storage/redigo"
+
+	// Choose iconv or mail/encoding package which uses golang.org/x/net/html/charset
 	//_ "github.com/flashmob/go-guerrilla/mail/iconv"
 	_ "github.com/flashmob/go-guerrilla/mail/encoding"
+
 	"github.com/spf13/cobra"
 
 	_ "github.com/go-sql-driver/mysql"
@@ -66,17 +69,20 @@ func sigHandler() {
 		syscall.SIGINT,
 		syscall.SIGKILL,
 		syscall.SIGUSR1,
+		os.Kill,
 	)
 	for sig := range signalChannel {
 		if sig == syscall.SIGHUP {
 			if ac, err := readConfig(configPath, pidFile); err == nil {
-				d.ReloadConfig(*ac)
+				_ = d.ReloadConfig(*ac)
 			} else {
 				mainlog.WithError(err).Error("Could not reload config")
 			}
 		} else if sig == syscall.SIGUSR1 {
-			d.ReopenLogs()
-		} else if sig == syscall.SIGTERM || sig == syscall.SIGQUIT || sig == syscall.SIGINT {
+			if err := d.ReopenLogs(); err != nil {
+				mainlog.WithError(err).Error("reopening logs failed")
+			}
+		} else if sig == syscall.SIGTERM || sig == syscall.SIGQUIT || sig == syscall.SIGINT || sig == os.Kill {
 			mainlog.Infof("Shutdown signal caught")
 			go func() {
 				select {
@@ -99,24 +105,16 @@ func sigHandler() {
 func serve(cmd *cobra.Command, args []string) {
 	logVersion()
 	d = guerrilla.Daemon{Logger: mainlog}
-	ac, err := readConfig(configPath, pidFile)
+	c, err := readConfig(configPath, pidFile)
 	if err != nil {
 		mainlog.WithError(err).Fatal("Error while reading config")
 	}
-	d.SetConfig(*ac)
+	_ = d.SetConfig(*c)
 
 	// Check that max clients is not greater than system open file limit.
-	fileLimit := getFileLimit()
-
-	if fileLimit > 0 {
-		maxClients := 0
-		for _, s := range ac.Servers {
-			maxClients += s.MaxClients
-		}
-		if maxClients > fileLimit {
-			mainlog.Fatalf("Combined max clients for all servers (%d) is greater than open file limit (%d). "+
-				"Please increase your open file limit or decrease max clients.", maxClients, fileLimit)
-		}
+	if ok, maxClients, fileLimit := guerrilla.CheckFileLimit(c); !ok {
+		mainlog.Fatalf("Combined max clients for all servers (%d) is greater than open file limit (%d). "+
+			"Please increase your open file limit or decrease max clients.", maxClients, fileLimit)
 	}
 
 	err = d.Start()
@@ -136,7 +134,7 @@ func readConfig(path string, pidFile string) (*guerrilla.AppConfig, error) {
 	// command line flags can override config values
 	appConfig, err := d.LoadConfig(path)
 	if err != nil {
-		return &appConfig, fmt.Errorf("Could not read config file: %s", err.Error())
+		return &appConfig, fmt.Errorf("could not read config file: %s", err.Error())
 	}
 	// override config pidFile with with flag from the command line
 	if len(pidFile) > 0 {
@@ -149,16 +147,3 @@ func readConfig(path string, pidFile string) (*guerrilla.AppConfig, error) {
 	}
 	return &appConfig, nil
 }
-
-func getFileLimit() int {
-	cmd := exec.Command("ulimit", "-n")
-	out, err := cmd.Output()
-	if err != nil {
-		return -1
-	}
-	limit, err := strconv.Atoi(strings.TrimSpace(string(out)))
-	if err != nil {
-		return -1
-	}
-	return limit
-}

File diff suppressed because it is too large
+ 435 - 198
cmd/guerrillad/serve_test.go


+ 1 - 1
cmd/guerrillad/version.go

@@ -3,7 +3,7 @@ package main
 import (
 	"github.com/spf13/cobra"
 
-	guerrilla "github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla"
 )
 
 var versionCmd = &cobra.Command{

+ 33 - 17
config.go

@@ -8,6 +8,7 @@ import (
 	"os"
 	"reflect"
 	"strings"
+	"time"
 
 	"github.com/flashmob/go-guerrilla/backends"
 	"github.com/flashmob/go-guerrilla/dashboard"
@@ -52,8 +53,8 @@ type ServerConfig struct {
 	// Listen interface specified in <ip>:<port> - defaults to 127.0.0.1:2525
 	ListenInterface string `json:"listen_interface"`
 
-	// MaxClients controls how many maxiumum clients we can handle at once.
-	// Defaults to 100
+	// MaxClients controls how many maximum clients we can handle at once.
+	// Defaults to defaultMaxClients
 	MaxClients int `json:"max_clients"`
 	// LogFile is where the logs go. Use path to file, or "stderr", "stdout" or "off".
 	// defaults to AppConfig.Log file setting
@@ -91,12 +92,12 @@ type ServerTLSConfig struct {
 	// Use Go's default if empty
 	ClientAuthType string `json:"client_auth_type,omitempty"`
 	// controls whether the server selects the
-	// client's most preferred ciphersuite
+	// client's most preferred cipher suite
 	PreferServerCipherSuites bool `json:"prefer_server_cipher_suites,omitempty"`
 
 	// The following used to watch certificate changes so that the TLS can be reloaded
-	_privateKeyFile_mtime int64
-	_publicKeyFile_mtime  int64
+	_privateKeyFileMtime int64
+	_publicKeyFileMtime  int64
 }
 
 // https://golang.org/pkg/crypto/tls/#pkg-constants
@@ -152,6 +153,11 @@ var TLSClientAuthTypes = map[string]tls.ClientAuthType{
 	"RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert,
 }
 
+const defaultMaxClients = 100
+const defaultTimeout = 30
+const defaultInterface = "127.0.0.1:2525"
+const defaultMaxSize = int64(10 << 20) // 10 Mebibytes
+
 // Unmarshalls json data into AppConfig struct and any other initialization of the struct
 // also does validation, returns error if validation failed or something went wrong
 func (c *AppConfig) Load(jsonBytes []byte) error {
@@ -175,7 +181,9 @@ func (c *AppConfig) Load(jsonBytes []byte) error {
 
 	// read the timestamps for the ssl keys, to determine if they need to be reloaded
 	for i := 0; i < len(c.Servers); i++ {
-		c.Servers[i].loadTlsKeyTimestamps()
+		if err := c.Servers[i].loadTlsKeyTimestamps(); err != nil {
+			return err
+		}
 	}
 	return nil
 }
@@ -222,8 +230,8 @@ func (c *AppConfig) EmitChangeEvents(oldConfig *AppConfig, app Guerrilla) {
 
 	}
 	// remove any servers that don't exist anymore
-	for _, oldserver := range oldServers {
-		app.Publish(EventConfigServerRemove, oldserver)
+	for _, oldServer := range oldServers {
+		app.Publish(EventConfigServerRemove, oldServer)
 	}
 }
 
@@ -279,9 +287,9 @@ func (c *AppConfig) setDefaults() error {
 		sc.ListenInterface = defaultInterface
 		sc.IsEnabled = true
 		sc.Hostname = h
-		sc.MaxClients = 100
-		sc.Timeout = 30
-		sc.MaxSize = 10 << 20 // 10 Mebibytes
+		sc.MaxClients = defaultMaxClients
+		sc.Timeout = defaultTimeout
+		sc.MaxSize = defaultMaxSize
 		c.Servers = append(c.Servers, sc)
 	} else {
 		// make sure each server has defaults correctly configured
@@ -290,13 +298,13 @@ func (c *AppConfig) setDefaults() error {
 				c.Servers[i].Hostname = h
 			}
 			if c.Servers[i].MaxClients == 0 {
-				c.Servers[i].MaxClients = 100
+				c.Servers[i].MaxClients = defaultMaxClients
 			}
 			if c.Servers[i].Timeout == 0 {
-				c.Servers[i].Timeout = 20
+				c.Servers[i].Timeout = defaultTimeout
 			}
 			if c.Servers[i].MaxSize == 0 {
-				c.Servers[i].MaxSize = 10 << 20 // 10 Mebibytes
+				c.Servers[i].MaxSize = defaultMaxSize // 10 Mebibytes
 			}
 			if c.Servers[i].ListenInterface == "" {
 				return errors.New(fmt.Sprintf("Listen interface not specified for server at index %d", i))
@@ -410,13 +418,21 @@ func (sc *ServerConfig) loadTlsKeyTimestamps() error {
 				iface,
 				err.Error()))
 	}
+	if sc.TLS.PrivateKeyFile == "" {
+		sc.TLS._privateKeyFileMtime = time.Now().Unix()
+		return nil
+	}
+	if sc.TLS.PublicKeyFile == "" {
+		sc.TLS._publicKeyFileMtime = time.Now().Unix()
+		return nil
+	}
 	if info, err := os.Stat(sc.TLS.PrivateKeyFile); err == nil {
-		sc.TLS._privateKeyFile_mtime = info.ModTime().Unix()
+		sc.TLS._privateKeyFileMtime = info.ModTime().Unix()
 	} else {
 		return statErr(sc.ListenInterface, err)
 	}
 	if info, err := os.Stat(sc.TLS.PublicKeyFile); err == nil {
-		sc.TLS._publicKeyFile_mtime = info.ModTime().Unix()
+		sc.TLS._publicKeyFileMtime = info.ModTime().Unix()
 	} else {
 		return statErr(sc.ListenInterface, err)
 	}
@@ -449,7 +465,7 @@ func (sc *ServerConfig) Validate() error {
 // Gets the timestamp of the TLS certificates. Returns a unix time of when they were last modified
 // when the config was read. We use this info to determine if TLS needs to be re-loaded.
 func (stc *ServerTLSConfig) getTlsKeyTimestamps() (int64, int64) {
-	return stc._privateKeyFile_mtime, stc._publicKeyFile_mtime
+	return stc._privateKeyFileMtime, stc._publicKeyFileMtime
 }
 
 // Returns value changes between struct a & struct b.

+ 47 - 18
config_test.go

@@ -11,10 +11,6 @@ import (
 	"time"
 )
 
-func init() {
-	testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/")
-}
-
 // a configuration file with a dummy backend
 
 //
@@ -114,8 +110,8 @@ var configJsonB = `
             "listen_interface":"127.0.0.1:2526",
             "max_clients": 3,
 			"tls" : {
- 				"private_key_file":"./config_test.go",
-            	"public_key_file":"./config_test.go",
+ 				"private_key_file":"./tests/mail2.guerrillamail.com.key.pem",
+            	"public_key_file": "./tests/mail2.guerrillamail.com.cert.pem",
 				"start_tls_on":false,
             	"tls_always_on":true
 			}
@@ -161,7 +157,7 @@ var configJsonB = `
 			"tls" : {
 				"private_key_file":"config_test.go",
             	"public_key_file":"config_test.go",
-				"start_tls_on":true,
+				"start_tls_on":false,
             	"tls_always_on":false
 			}
         }
@@ -170,6 +166,18 @@ var configJsonB = `
 `
 
 func TestConfigLoad(t *testing.T) {
+	if err := testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/"); err != nil {
+		t.Error(err)
+	}
+	defer func() {
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.cert.pem"); err != nil {
+			t.Error(err)
+		}
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.key.pem"); err != nil {
+			t.Error(err)
+		}
+	}()
+
 	ac := &AppConfig{}
 	if err := ac.Load([]byte(configJsonA)); err != nil {
 		t.Error("Cannot load config |", err)
@@ -181,8 +189,8 @@ func TestConfigLoad(t *testing.T) {
 		t.SkipNow()
 	}
 	// did we got the timestamps?
-	if ac.Servers[0].TLS._privateKeyFile_mtime <= 0 {
-		t.Error("failed to read timestamp for _privateKeyFile_mtime, got", ac.Servers[0].TLS._privateKeyFile_mtime)
+	if ac.Servers[0].TLS._privateKeyFileMtime <= 0 {
+		t.Error("failed to read timestamp for _privateKeyFileMtime, got", ac.Servers[0].TLS._privateKeyFileMtime)
 	}
 }
 
@@ -205,9 +213,22 @@ func TestSampleConfig(t *testing.T) {
 
 // make sure that we get all the config change events
 func TestConfigChangeEvents(t *testing.T) {
+	if err := testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/"); err != nil {
+		t.Error(err)
+	}
+	defer func() {
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.cert.pem"); err != nil {
+			t.Error(err)
+		}
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.key.pem"); err != nil {
+			t.Error(err)
+		}
+	}()
 
 	oldconf := &AppConfig{}
-	oldconf.Load([]byte(configJsonA))
+	if err := oldconf.Load([]byte(configJsonA)); err != nil {
+		t.Error(err)
+	}
 	logger, _ := log.GetLogger(oldconf.LogFile, oldconf.LogLevel)
 	bcfg := backends.BackendConfig{"log_received_mails": true}
 	backend, err := backends.New(bcfg, logger)
@@ -221,10 +242,16 @@ func TestConfigChangeEvents(t *testing.T) {
 	// simulate timestamp change
 
 	time.Sleep(time.Second + time.Millisecond*500)
-	os.Chtimes(oldconf.Servers[1].TLS.PrivateKeyFile, time.Now(), time.Now())
-	os.Chtimes(oldconf.Servers[1].TLS.PublicKeyFile, time.Now(), time.Now())
+	if err := os.Chtimes(oldconf.Servers[1].TLS.PrivateKeyFile, time.Now(), time.Now()); err != nil {
+		t.Error(err)
+	}
+	if err := os.Chtimes(oldconf.Servers[1].TLS.PublicKeyFile, time.Now(), time.Now()); err != nil {
+		t.Error(err)
+	}
 	newconf := &AppConfig{}
-	newconf.Load([]byte(configJsonB))
+	if err := newconf.Load([]byte(configJsonB)); err != nil {
+		t.Error(err)
+	}
 	newconf.Servers[0].LogFile = log.OutputOff.String() // test for log file change
 	newconf.LogLevel = log.InfoLevel.String()
 	newconf.LogFile = "off"
@@ -253,14 +280,14 @@ func TestConfigChangeEvents(t *testing.T) {
 				f := func(c *AppConfig) {
 					expectedEvents[e] = true
 				}
-				app.Subscribe(event, f)
+				_ = app.Subscribe(event, f)
 				toUnsubscribe[event] = f
 			} else {
 				// must be a server config change then
 				f := func(c *ServerConfig) {
 					expectedEvents[e] = true
 				}
-				app.Subscribe(event, f)
+				_ = app.Subscribe(event, f)
 				toUnsubscribeSrv[event] = f
 			}
 
@@ -271,10 +298,10 @@ func TestConfigChangeEvents(t *testing.T) {
 	newconf.EmitChangeEvents(oldconf, app)
 	// unsubscribe
 	for unevent, unfun := range toUnsubscribe {
-		app.Unsubscribe(unevent, unfun)
+		_ = app.Unsubscribe(unevent, unfun)
 	}
 	for unevent, unfun := range toUnsubscribeSrv {
-		app.Unsubscribe(unevent, unfun)
+		_ = app.Unsubscribe(unevent, unfun)
 	}
 	for event, val := range expectedEvents {
 		if val == false {
@@ -285,5 +312,7 @@ func TestConfigChangeEvents(t *testing.T) {
 	}
 
 	// don't forget to reset
-	os.Truncate(oldconf.LogFile, 0)
+	if err := os.Truncate(oldconf.LogFile, 0); err != nil {
+		t.Error(err)
+	}
 }

+ 14 - 14
glide.lock

@@ -1,18 +1,14 @@
-hash: d845af9d7a26647c61c850a305d94006a0528611a1ae81eccea766c432e7aac0
-updated: 2018-05-31T15:05:08.573435967+10:00
+hash: 3e14b9859f4843eb8d20af4208ed1a1ea2a7b69cdca427ea59a243914e54242c
+updated: 2019-01-31T01:06:42.940723143+11:00
 imports:
 - name: github.com/asaskevich/EventBus
   version: 68a521d7cbbb7a859c2608b06342f384b3bd5f5a
-- name: github.com/garyburd/redigo
-  version: 8873b2f1995f59d4bcdd2b0dc9858e2cb9bf0c13
-  subpackages:
-  - internal
-  - redis
 - name: github.com/go-sql-driver/mysql
-  version: a0583e0143b1624142adab07e0e97fe106d99561
+  version: 72cd26f257d44c1114970e19afddcd812016007e
 - name: github.com/gomodule/redigo
-  version: 8873b2f1995f59d4bcdd2b0dc9858e2cb9bf0c13
+  version: 9c11da706d9b7902c6da69c592f75637793fe121
   subpackages:
+  - internal
   - redis
 - name: github.com/gorilla/context
   version: 08b5f424b9271eedf6f9f0ce86cb9396ed337a42
@@ -23,21 +19,21 @@ imports:
 - name: github.com/inconshreveable/mousetrap
   version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75
 - name: github.com/rakyll/statik
-  version: fd36b3595eb2ec8da4b8153b107f7ea08504899d
+  version: 1355192d24db2566a83c3914e187e2a7e7679832
   subpackages:
   - fs
 - name: github.com/sirupsen/logrus
-  version: c155da19408a8799da419ed3eeb0cb5db0ad5dbc
+  version: 3e01752db0189b9157070a0e1668a620f9a85da2
 - name: github.com/spf13/cobra
   version: b62566898a99f2db9c68ed0026aa0a052e59678d
 - name: github.com/spf13/pflag
   version: 25f8b5b07aece3207895bf19f7ab517eb3b22a40
 - name: golang.org/x/crypto
-  version: ab813273cd59e1333f7ae7bff5d027d4aadf528c
+  version: b01c7a72566457eb1420261cdafef86638fc3861
   subpackages:
   - ssh/terminal
 - name: golang.org/x/net
-  version: 1e491301e022f8f977054da4c2d852decd59571f
+  version: d26f9f9a57f3fab6a695bec0d84433c2c50f8bbf
   subpackages:
   - html
   - html/atom
@@ -48,7 +44,7 @@ imports:
   - unix
   - windows
 - name: golang.org/x/text
-  version: 5c1cf69b5978e5a34c5f9ba09a83e56acc4b7877
+  version: e6919f6577db79269a6443b9dc46d18f2238fb5d
   subpackages:
   - encoding
   - encoding/charmap
@@ -67,6 +63,10 @@ imports:
   - language
   - runes
   - transform
+- name: google.golang.org/appengine
+  version: e9657d882bb81064595ca3b56cbe2546bbabf7b1
+  subpackages:
+  - cloudsql
 - name: gopkg.in/iconv.v1
   version: 16a760eb7e186ae0e3aedda00d4a1daa4d0701d8
 testImports: []

+ 1 - 1
glide.yaml

@@ -3,7 +3,7 @@ import:
 - package: github.com/sirupsen/logrus
   version: ~1.0.4
 - package: github.com/gomodule/redigo
-  version: ~1.0.0
+  version: ~2.0.0
   subpackages:
   - redis
 - package: github.com/spf13/cobra

+ 113 - 52
guerrilla.go

@@ -13,12 +13,12 @@ import (
 )
 
 const (
-	// server has just been created
-	GuerrillaStateNew = iota
-	// Server has been started and is running
-	GuerrillaStateStarted
-	// Server has just been stopped
-	GuerrillaStateStopped
+	// all configured servers were just been created
+	daemonStateNew = iota
+	// ... been started and running
+	daemonStateStarted
+	// ... been stopped
+	daemonStateStopped
 )
 
 type Errors []error
@@ -64,6 +64,9 @@ type backendStore struct {
 	atomic.Value
 }
 
+type daemonEvent func(c *AppConfig)
+type serverEvent func(sc *ServerConfig)
+
 // Get loads the log.logger in an atomic operation. Returns a stderr logger if not able to load
 func (ls *logStore) mainlog() log.Logger {
 	if v, ok := ls.Load().(log.Logger); ok {
@@ -73,7 +76,7 @@ func (ls *logStore) mainlog() log.Logger {
 	return l
 }
 
-// storeMainlog stores the log value in an atomic operation
+// setMainlog stores the log value in an atomic operation
 func (ls *logStore) setMainlog(log log.Logger) {
 	ls.Store(log)
 }
@@ -98,17 +101,18 @@ func New(ac *AppConfig, b backends.Backend, l log.Logger) (Guerrilla, error) {
 			}
 		}
 	}
+	// Write the process id (pid) to a file
+	// we should still be able to continue even if we can't write the pid, error will be logged by writePid()
+	_ = g.writePid()
 
-	g.state = GuerrillaStateNew
+	g.state = daemonStateNew
 	err := g.makeServers()
 
 	// start backend for processing email
 	err = g.backend().Start()
-
 	if err != nil {
 		return g, err
 	}
-	g.writePid()
 
 	// subscribe for any events that may come in while running
 	g.subscribeEvents()
@@ -142,7 +146,7 @@ func (g *guerrilla) makeServers() error {
 		}
 	}
 	if len(g.servers) == 0 {
-		errs = append(errs, errors.New("There are no servers that can start, please check your config"))
+		errs = append(errs, errors.New("there are no servers that can start, please check your config"))
 	}
 	if len(errs) == 0 {
 		return nil
@@ -197,13 +201,13 @@ func (g *guerrilla) mapServers(callback func(*server)) map[string]*server {
 // subscribeEvents subscribes event handlers for configuration change events
 func (g *guerrilla) subscribeEvents() {
 
+	events := map[Event]interface{}{}
 	// main config changed
-	g.Subscribe(EventConfigNewConfig, func(c *AppConfig) {
+	events[EventConfigNewConfig] = daemonEvent(func(c *AppConfig) {
 		g.setConfig(c)
 	})
-
 	// allowed_hosts changed, set for all servers
-	g.Subscribe(EventConfigAllowedHosts, func(c *AppConfig) {
+	events[EventConfigAllowedHosts] = daemonEvent(func(c *AppConfig) {
 		g.mapServers(func(server *server) {
 			server.setAllowedHosts(c.AllowedHosts)
 		})
@@ -211,7 +215,7 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// the main log file changed
-	g.Subscribe(EventConfigLogFile, func(c *AppConfig) {
+	events[EventConfigLogFile] = daemonEvent(func(c *AppConfig) {
 		var err error
 		var l log.Logger
 		if l, err = log.GetLogger(c.LogFile, c.LogLevel); err == nil {
@@ -231,13 +235,17 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// re-open the main log file (file not changed)
-	g.Subscribe(EventConfigLogReopen, func(c *AppConfig) {
-		g.mainlog().Reopen()
+	events[EventConfigLogReopen] = daemonEvent(func(c *AppConfig) {
+		err := g.mainlog().Reopen()
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("main log file [%s] failed to re-open", c.LogFile)
+			return
+		}
 		g.mainlog().Infof("re-opened main log file [%s]", c.LogFile)
 	})
 
 	// when log level changes, apply to mainlog and server logs
-	g.Subscribe(EventConfigLogLevel, func(c *AppConfig) {
+	events[EventConfigLogLevel] = daemonEvent(func(c *AppConfig) {
 		l, err := log.GetLogger(g.mainlog().GetLogDest(), c.LogLevel)
 		if err == nil {
 			if c.Dashboard.Enabled {
@@ -252,18 +260,18 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// write out our pid whenever the file name changes in the config
-	g.Subscribe(EventConfigPidFile, func(ac *AppConfig) {
-		g.writePid()
+	events[EventConfigPidFile] = daemonEvent(func(ac *AppConfig) {
+		_ = g.writePid()
 	})
 
 	// server config was updated
-	g.Subscribe(EventConfigServerConfig, func(sc *ServerConfig) {
+	events[EventConfigServerConfig] = serverEvent(func(sc *ServerConfig) {
 		g.setServerConfig(sc)
 		g.mainlog().Infof("server %s config change event, a new config has been saved", sc.ListenInterface)
 	})
 
 	// add a new server to the config & start
-	g.Subscribe(EventConfigServerNew, func(sc *ServerConfig) {
+	events[EventConfigServerNew] = serverEvent(func(sc *ServerConfig) {
 		g.mainlog().Debugf("event fired [%s] %s", EventConfigServerNew, sc.ListenInterface)
 		if _, err := g.findServer(sc.ListenInterface); err != nil {
 			// not found, lets add it
@@ -273,7 +281,7 @@ func (g *guerrilla) subscribeEvents() {
 				return
 			}
 			g.mainlog().Infof("New server added [%s]", sc.ListenInterface)
-			if g.state == GuerrillaStateStarted {
+			if g.state == daemonStateStarted {
 				err := g.Start()
 				if err != nil {
 					g.mainlog().WithError(err).Info("Event server_change:new_server returned errors when starting")
@@ -283,8 +291,9 @@ func (g *guerrilla) subscribeEvents() {
 			g.mainlog().Debugf("new event, but server already fund")
 		}
 	})
+
 	// start a server that already exists in the config and has been enabled
-	g.Subscribe(EventConfigServerStart, func(sc *ServerConfig) {
+	events[EventConfigServerStart] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if server.state == ServerStateStopped || server.state == ServerStateNew {
 				g.mainlog().Infof("Starting server [%s]", server.listenInterface)
@@ -295,8 +304,9 @@ func (g *guerrilla) subscribeEvents() {
 			}
 		}
 	})
+
 	// stop running a server
-	g.Subscribe(EventConfigServerStop, func(sc *ServerConfig) {
+	events[EventConfigServerStop] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if server.state == ServerStateRunning {
 				server.Shutdown()
@@ -304,8 +314,9 @@ func (g *guerrilla) subscribeEvents() {
 			}
 		}
 	})
+
 	// server was removed from config
-	g.Subscribe(EventConfigServerRemove, func(sc *ServerConfig) {
+	events[EventConfigServerRemove] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			server.Shutdown()
 			g.removeServer(sc.ListenInterface)
@@ -314,7 +325,7 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// TLS changes
-	g.Subscribe(EventConfigServerTLSConfig, func(sc *ServerConfig) {
+	events[EventConfigServerTLSConfig] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if err := server.configureSSL(); err == nil {
 				g.mainlog().Infof("Server [%s] new TLS configuration loaded", sc.ListenInterface)
@@ -324,19 +335,19 @@ func (g *guerrilla) subscribeEvents() {
 		}
 	})
 	// when server's timeout change.
-	g.Subscribe(EventConfigServerTimeout, func(sc *ServerConfig) {
+	events[EventConfigServerTimeout] = serverEvent(func(sc *ServerConfig) {
 		g.mapServers(func(server *server) {
 			server.setTimeout(sc.Timeout)
 		})
 	})
 	// when server's max clients change.
-	g.Subscribe(EventConfigServerMaxClients, func(sc *ServerConfig) {
+	events[EventConfigServerMaxClients] = serverEvent(func(sc *ServerConfig) {
 		g.mapServers(func(server *server) {
 			// TODO resize the pool somehow
 		})
 	})
 	// when a server's log file changes
-	g.Subscribe(EventConfigServerLogFile, func(sc *ServerConfig) {
+	events[EventConfigServerLogFile] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			var err error
 			var l log.Logger
@@ -360,14 +371,17 @@ func (g *guerrilla) subscribeEvents() {
 		}
 	})
 	// when the daemon caught a sighup, event for individual server
-	g.Subscribe(EventConfigServerLogReopen, func(sc *ServerConfig) {
+	events[EventConfigServerLogReopen] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
-			server.log().Reopen()
+			if err = server.log().Reopen(); err != nil {
+				g.mainlog().WithError(err).Errorf("server [%s] log file [%s] failed to re-open", sc.ListenInterface, sc.LogFile)
+				return
+			}
 			g.mainlog().Infof("Server [%s] re-opened log file [%s]", sc.ListenInterface, sc.LogFile)
 		}
 	})
 	// when the backend changes
-	g.Subscribe(EventConfigBackendConfig, func(appConfig *AppConfig) {
+	events[EventConfigBackendConfig] = daemonEvent(func(appConfig *AppConfig) {
 		logger, _ := log.GetLogger(appConfig.LogFile, appConfig.LogLevel)
 		// shutdown the backend first.
 		var err error
@@ -398,6 +412,19 @@ func (g *guerrilla) subscribeEvents() {
 			g.storeBackend(newBackend)
 		}
 	})
+	var err error
+	for topic, fn := range events {
+		switch f := fn.(type) {
+		case daemonEvent:
+			err = g.Subscribe(topic, f)
+		case serverEvent:
+			err = g.Subscribe(topic, f)
+		}
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("failed to subscribe on topic [%s]", topic)
+			break
+		}
+	}
 
 }
 
@@ -420,16 +447,20 @@ func (g *guerrilla) Start() error {
 	var startErrors Errors
 	g.guard.Lock()
 	defer func() {
-		g.state = GuerrillaStateStarted
+		g.state = daemonStateStarted
 		g.guard.Unlock()
 	}()
 	if len(g.servers) == 0 {
-		return append(startErrors, errors.New("No servers to start, please check the config"))
+		return append(startErrors, errors.New("no servers to start, please check the config"))
 	}
-	if g.state == GuerrillaStateStopped {
+	if g.state == daemonStateStopped {
 		// when a backend is shutdown, we need to re-initialize before it can be started again
-		g.backend().Reinitialize()
-		g.backend().Start()
+		if err := g.backend().Reinitialize(); err != nil {
+			startErrors = append(startErrors, err)
+		}
+		if err := g.backend().Start(); err != nil {
+			startErrors = append(startErrors, err)
+		}
 	}
 	// channel for reading errors
 	errs := make(chan error, len(g.servers))
@@ -485,7 +516,7 @@ func (g *guerrilla) Shutdown() {
 
 	g.guard.Lock()
 	defer func() {
-		g.state = GuerrillaStateStopped
+		g.state = daemonStateStopped
 		defer g.guard.Unlock()
 	}()
 	if err := g.backend().Shutdown(); err != nil {
@@ -503,22 +534,52 @@ func (g *guerrilla) SetLogger(l log.Logger) {
 
 // writePid writes the pid (process id) to the file specified in the config.
 // Won't write anything if no file specified
-func (g *guerrilla) writePid() error {
-	if len(g.Config.PidFile) > 0 {
-		if f, err := os.Create(g.Config.PidFile); err == nil {
-			defer f.Close()
-			pid := os.Getpid()
-			if _, err := f.WriteString(fmt.Sprintf("%d", pid)); err == nil {
-				f.Sync()
-				g.mainlog().Infof("pid_file (%s) written with pid:%v", g.Config.PidFile, pid)
-			} else {
-				g.mainlog().WithError(err).Errorf("Error while writing pidFile (%s)", g.Config.PidFile)
-				return err
+func (g *guerrilla) writePid() (err error) {
+	var f *os.File
+	defer func() {
+		if f != nil {
+			if closeErr := f.Close(); closeErr != nil {
+				err = closeErr
 			}
-		} else {
-			g.mainlog().WithError(err).Errorf("Error while creating pidFile (%s)", g.Config.PidFile)
+		}
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("error while writing pidFile (%s)", g.Config.PidFile)
+		}
+	}()
+	if len(g.Config.PidFile) > 0 {
+		if f, err = os.Create(g.Config.PidFile); err != nil {
+			return err
+		}
+		pid := os.Getpid()
+		if _, err := f.WriteString(fmt.Sprintf("%d", pid)); err != nil {
 			return err
 		}
+		if err = f.Sync(); err != nil {
+			return err
+		}
+		g.mainlog().Infof("pid_file (%s) written with pid:%v", g.Config.PidFile, pid)
 	}
 	return nil
 }
+
+// CheckFileLimit checks the number of files we can open (works on OS'es that support the ulimit command)
+func CheckFileLimit(c *AppConfig) (bool, int, uint64) {
+	fileLimit, err := getFileLimit()
+	maxClients := 0
+	if err != nil {
+		// since we can't get the limit, return true to indicate the check passed
+		return true, maxClients, fileLimit
+	}
+	if c.Servers == nil {
+		// no servers have been configured, assuming default
+		maxClients = defaultMaxClients
+	} else {
+		for _, s := range c.Servers {
+			maxClients += s.MaxClients
+		}
+	}
+	if uint64(maxClients) > fileLimit {
+		return false, maxClients, fileLimit
+	}
+	return true, maxClients, fileLimit
+}

+ 16 - 0
guerrilla_notunix.go

@@ -0,0 +1,16 @@
+// +build !darwin
+// +build !dragonfly
+// +build !freebsd
+// +build !linux
+// +build !netbsd
+// +build !openbsd
+
+package guerrilla
+
+import "errors"
+
+// getFileLimit checks how many files we can open
+// Don't know how to get that info (yet?), so returns false information & error
+func getFileLimit() (uint64, error) {
+	return 1000000, errors.New("syscall.RLIMIT_NOFILE not supported on your OS/platform")
+}

+ 15 - 0
guerrilla_unix.go

@@ -0,0 +1,15 @@
+// +build darwin dragonfly freebsd linux netbsd openbsd
+
+package guerrilla
+
+import "syscall"
+
+// getFileLimit checks how many files we can open
+func getFileLimit() (uint64, error) {
+	var rLimit syscall.Rlimit
+	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+	if err != nil {
+		return 0, err
+	}
+	return rLimit.Max, nil
+}

+ 1 - 1
log/hook.go

@@ -153,7 +153,7 @@ func (hook *LogrusHook) Fire(entry *log.Entry) error {
 				return err
 			}
 			if hook.fd != nil {
-				hook.fd.Sync()
+				err = hook.fd.Sync()
 			}
 		}
 		return err

+ 1 - 2
log/log.go

@@ -221,8 +221,7 @@ func (l *HookedLogger) GetLogDest() string {
 
 // WithConn extends logrus to be able to log with a net.Conn
 func (l *HookedLogger) WithConn(conn net.Conn) *log.Entry {
-	var addr string = "unknown"
-
+	var addr = "unknown"
 	if conn != nil {
 		addr = conn.RemoteAddr().String()
 	}

+ 15 - 10
mail/envelope.go

@@ -27,12 +27,20 @@ func init() {
 	Dec = mime.WordDecoder{}
 }
 
-const maxHeaderChunk = 1 + (3 << 10) // 3KB
+const maxHeaderChunk = 1 + (4 << 10) // 4KB
 
 // Address encodes an email address of the form `<user@host>`
 type Address struct {
+	// User is local part
 	User string
+	// Host is the domain
 	Host string
+	// ADL is at-domain list if matched
+	ADL []string
+	// PathParams contains any ESTMP parameters that were matched
+	PathParams [][]string
+	// NullPath is true if <> was received
+	NullPath bool
 }
 
 func (ep *Address) String() string {
@@ -114,18 +122,15 @@ func (e *Envelope) ParseHeaders() error {
 	if e.Header != nil {
 		return errors.New("headers already parsed")
 	}
-	buf := bytes.NewBuffer(e.Data.Bytes())
+	buf := e.Data.Bytes()
 	// find where the header ends, assuming that over 30 kb would be max
-	max := maxHeaderChunk
-	if buf.Len() < max {
-		max = buf.Len()
+	if len(buf) > maxHeaderChunk {
+		buf = buf[:maxHeaderChunk]
 	}
-	// read in the chunk which we'll scan for the header
-	chunk := make([]byte, max)
-	buf.Read(chunk)
-	headerEnd := strings.Index(string(chunk), "\n\n") // the first two new-lines chars are the End Of Header
+
+	headerEnd := bytes.Index(buf, []byte{'\n', '\n'}) // the first two new-lines chars are the End Of Header
 	if headerEnd > -1 {
-		header := chunk[0:headerEnd]
+		header := buf[0:headerEnd]
 		headerReader := textproto.NewReader(bufio.NewReader(bytes.NewBuffer(header)))
 		e.Header, err = headerReader.ReadMIMEHeader()
 		if err != nil {

+ 6 - 2
mail/envelope_test.go

@@ -1,6 +1,7 @@
 package mail
 
 import (
+	"io"
 	"io/ioutil"
 	"strings"
 	"testing"
@@ -61,9 +62,12 @@ func TestEnvelope(t *testing.T) {
 
 	data, _ := ioutil.ReadAll(r)
 	if len(data) != e.Len() {
-		t.Error("e.Len() is inccorrect, it shown ", e.Len(), " but we wanted ", len(data))
+		t.Error("e.Len() is incorrect, it shown ", e.Len(), " but we wanted ", len(data))
+	}
+	if err := e.ParseHeaders(); err != nil && err != io.EOF {
+		t.Error("cannot parse headers:", err)
+		return
 	}
-	e.ParseHeaders()
 	if e.Subject != "Test" {
 		t.Error("Subject expecting: Test, got:", e.Subject)
 	}

+ 595 - 0
mail/rfc5321/parse.go

@@ -0,0 +1,595 @@
+package rfc5321
+
+// Parse RFC5321 productions, no regex
+
+import (
+	"bytes"
+	"errors"
+	"net"
+	"strconv"
+)
+
+const (
+	// The maximum total length of a reverse-path or forward-path is 256
+	LimitPath = 256
+	// The maximum total length of a user name or other local-part is 64
+	// however, here we double it, since a few major services don't respect that and go over
+	LimitLocalPart = 64 * 2
+	// //The maximum total length of a domain name or number is 255
+	LimitDomain = 255
+	// The minimum total number of recipients that must be buffered is 100
+	LimitRecipients = 100
+)
+
+// Parse Email Addresses according to https://tools.ietf.org/html/rfc5321
+type Parser struct {
+	NullPath  bool
+	LocalPart string
+	Domain    string
+
+	ADL        []string
+	PathParams [][]string
+
+	pos int
+	ch  byte
+
+	buf    []byte
+	accept bytes.Buffer
+}
+
+func NewParser(buf []byte) *Parser {
+	s := new(Parser)
+	s.buf = buf
+	s.pos = -1
+	return s
+}
+
+func (s *Parser) Reset() {
+	s.buf = s.buf[:0]
+	if s.pos != -1 {
+		s.pos = -1
+		s.ADL = nil
+		s.PathParams = nil
+		s.NullPath = false
+		s.LocalPart = ""
+		s.Domain = ""
+		s.accept.Reset()
+	}
+}
+
+func (s *Parser) set(input []byte) {
+	s.Reset()
+	s.buf = input
+}
+
+func (s *Parser) next() byte {
+	s.pos++
+	if s.pos < len(s.buf) {
+		s.ch = s.buf[s.pos]
+		return s.ch
+	}
+	return 0
+}
+
+func (s *Parser) peek() byte {
+	if s.pos+1 < len(s.buf) {
+		return s.buf[s.pos+1]
+	}
+	return 0
+}
+
+func (s *Parser) reversePath() (err error) {
+	if s.peek() == ' ' {
+		s.next() // tolerate a space at the front
+	}
+	if i := bytes.Index(s.buf[s.pos+1:], []byte{'<', '>'}); i == 0 {
+		s.NullPath = true
+		return nil
+	}
+	if err = s.path(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *Parser) forwardPath() (err error) {
+	if s.peek() == ' ' {
+		s.next() // tolerate a space at the front
+	}
+	if i := bytes.Index(bytes.ToLower(s.buf[s.pos+1:]), []byte(postmasterPath)); i == 0 {
+		s.LocalPart = postmasterLocalPart
+		return nil
+	}
+	if err = s.path(); err != nil {
+		return err
+	}
+	return nil
+}
+
+//MailFrom accepts the following syntax: Reverse-path [SP Mail-parameters] CRLF
+func (s *Parser) MailFrom(input []byte) (err error) {
+	s.set(input)
+	if err := s.reversePath(); err != nil {
+		return err
+	}
+	s.next()
+	if p := s.next(); p == ' ' {
+		// parse Rcpt-parameters
+		// The optional <mail-parameters> are associated with negotiated SMTP
+		//  service extensions
+		if tup, err := s.parameters(); err != nil {
+			return errors.New("param parse error")
+		} else if len(tup) > 0 {
+			s.PathParams = tup
+		}
+	}
+	return nil
+}
+
+const postmasterPath = "<postmaster>"
+const postmasterLocalPart = "Postmaster"
+
+//RcptTo accepts the following syntax: ( "<Postmaster@" Domain ">" / "<Postmaster>" /
+//                  Forward-path ) [SP Rcpt-parameters] CRLF
+func (s *Parser) RcptTo(input []byte) (err error) {
+	s.set(input)
+	if err := s.forwardPath(); err != nil {
+		return err
+	}
+	s.next()
+	if p := s.next(); p == ' ' {
+		// parse Rcpt-parameters
+		if tup, err := s.parameters(); err != nil {
+			return errors.New("param parse error")
+		} else if len(tup) > 0 {
+			s.PathParams = tup
+		}
+	}
+	return nil
+}
+
+// esmtp-param *(SP esmtp-param)
+func (s *Parser) parameters() ([][]string, error) {
+	params := make([][]string, 0)
+	for {
+		if result, err := s.param(); err != nil {
+			return params, err
+		} else {
+			params = append(params, result)
+		}
+		if p := s.next(); p != ' ' {
+			return params, nil
+		}
+	}
+}
+
+func isESMTPValue(c byte) bool {
+	if ('!' <= c && c <= '<') ||
+		('>' <= c && c <= '~') {
+		return true
+	}
+	return false
+}
+
+// esmtp-param    = esmtp-keyword ["=" esmtp-value]
+// esmtp-keyword  = (ALPHA / DIGIT) *(ALPHA / DIGIT / "-")
+// esmtp-value    = 1*(%d33-60 / %d62-126)
+func (s *Parser) param() (result []string, err error) {
+	state := 0
+	var key, value string
+	defer func() {
+		result = append(result, key, value)
+		s.accept.Reset()
+	}()
+	for c := s.next(); ; c = s.next() {
+		switch state {
+		case 0:
+			// first char must be let-dig
+			if !isLetDig(c) {
+				return result, errors.New("parse error")
+			}
+			// accept
+			s.accept.WriteByte(c)
+			state = 1
+		case 1:
+			// *(ALPHA / DIGIT / "-")
+			if !isLetDig(c) {
+				if c == '=' {
+					key = s.accept.String()
+					s.accept.Reset()
+					state = 2
+					continue
+				} else if c == '-' {
+					// cannot have - at the end of a keyword
+					if p := s.peek(); !isLetDig(p) && p != '-' {
+						return result, errors.New("parse error")
+					}
+					s.accept.WriteByte(c)
+					continue
+
+				}
+				key = s.accept.String()
+				return result, nil
+			}
+			s.accept.WriteByte(c)
+		case 2:
+			// start of value, must match at least 1
+			if !isESMTPValue(c) {
+				return result, errors.New("parse error")
+			}
+			s.accept.WriteByte(c)
+			if !isESMTPValue(s.peek()) {
+				value = s.accept.String()
+				return result, nil
+			}
+			state = 3
+		case 3:
+			// 1*(%d33-60 / %d62-126)
+			s.accept.WriteByte(c)
+			if !isESMTPValue(s.peek()) {
+				value = s.accept.String()
+				return result, nil
+			}
+		}
+	}
+}
+
+// "<" [ A-d-l ":" ] Mailbox ">"
+func (s *Parser) path() (err error) {
+	if s.next() == '<' && s.peek() == '@' {
+		if err = s.adl(); err == nil {
+			s.next()
+			if s.ch != ':' {
+				return errors.New("syntax error")
+			}
+		}
+	}
+	if err = s.mailbox(); err != nil {
+		return err
+	}
+	if p := s.peek(); p != '>' {
+		return errors.New("missing closing >")
+	}
+	return nil
+}
+
+// At-domain *( "," At-domain )
+func (s *Parser) adl() error {
+	for {
+		if err := s.atDomain(); err != nil {
+			return err
+		}
+		s.ADL = append(s.ADL, s.accept.String())
+		s.accept.Reset()
+		if s.peek() != ',' {
+			break
+		}
+		s.next()
+	}
+	return nil
+}
+
+// At-domain = "@" Domain
+func (s *Parser) atDomain() error {
+	if s.next() == '@' {
+		s.accept.WriteByte('@')
+		return s.domain()
+	}
+	return errors.New("syntax error")
+}
+
+// sub-domain *("." sub-domain)
+func (s *Parser) domain() error {
+	for {
+		if err := s.subdomain(); err != nil {
+			return err
+		}
+		if p := s.peek(); p != '.' {
+			if p != ':' && p != ',' && p != '>' && p != 0 {
+				return errors.New("domain parse error")
+			}
+
+			break
+		}
+		s.accept.WriteByte(s.next())
+	}
+	return nil
+}
+
+// Let-dig [Ldh-str]
+func (s *Parser) subdomain() error {
+	state := 0
+	for c := s.next(); ; c = s.next() {
+		switch state {
+		case 0:
+			p := s.peek()
+			if isLetDig(c) {
+				s.accept.WriteByte(c)
+				if !isLetDig(p) && p != '-' {
+					return nil
+				}
+				state = 1
+				continue
+			}
+			return errors.New("parse err")
+		case 1:
+			p := s.peek()
+			if isLetDig(c) || c == '-' {
+				s.accept.WriteByte(c)
+			}
+			if !isLetDig(p) && p != '-' {
+				if c == '-' {
+					return errors.New("parse err")
+				}
+				return nil
+			}
+		}
+	}
+}
+
+// Local-part "@" ( Domain / address-literal )
+func (s *Parser) mailbox() error {
+	defer func() {
+		if s.accept.Len() > 0 {
+			s.Domain = s.accept.String()
+			s.accept.Reset()
+		}
+	}()
+	err := s.localPart()
+	if err != nil {
+		return err
+	}
+	if s.ch != '@' {
+		return errors.New("@ expected as part of mailbox")
+	}
+	if p := s.peek(); p == '[' {
+		return s.addressLiteral()
+	} else {
+		return s.domain()
+	}
+}
+
+// "[" ( IPv4-address-literal /
+//                    IPv6-address-literal /
+//                    General-address-literal ) "]"
+func (s *Parser) addressLiteral() error {
+	ch := s.next()
+	if ch == '[' {
+		p := s.peek()
+		var err error
+		if p == 'I' || p == 'i' {
+			for i := 0; i < 5; i++ {
+				s.next() // IPv6:
+			}
+			err = s.ipv6AddressLiteral()
+		} else if p >= 48 && p <= 57 {
+			err = s.ipv4AddressLiteral()
+		}
+		if err != nil {
+			return err
+		}
+		if s.ch != ']' {
+			return errors.New("] expected for address literal")
+		}
+		return nil
+	}
+	return nil
+}
+
+// Snum 3("."  Snum)
+func (s *Parser) ipv4AddressLiteral() error {
+	for i := 0; i < 4; i++ {
+		if err := s.snum(); err != nil {
+			return err
+		}
+		if s.ch != '.' {
+			break
+		}
+		s.accept.WriteByte(s.ch)
+	}
+	return nil
+}
+
+// 1*3DIGIT
+// representing a decimal integer
+// value accept the range 0 through 255
+func (s *Parser) snum() error {
+	state := 0
+	var num bytes.Buffer
+	for i := 4; i > 0; i-- {
+		c := s.next()
+		if state == 0 {
+			if !(c >= 48 && c <= 57) {
+				return errors.New("parse error")
+			} else {
+				num.WriteByte(s.ch)
+				s.accept.WriteByte(s.ch)
+				state = 1
+				continue
+			}
+		}
+		if state == 1 {
+			if !(c >= 48 && c <= 57) {
+				if v, err := strconv.Atoi(num.String()); err != nil {
+					return err
+				} else if v >= 0 && v <= 255 {
+					return nil
+				} else {
+					return errors.New("invalid ipv4")
+				}
+			} else {
+				num.WriteByte(s.ch)
+				s.accept.WriteByte(s.ch)
+			}
+		}
+	}
+	return errors.New("too many digits")
+}
+
+//IPv6:" IPv6-addr
+func (s *Parser) ipv6AddressLiteral() error {
+	var ip bytes.Buffer
+	for c := s.next(); ; c = s.next() {
+		if !(c >= 48 && c <= 57) &&
+			!(c >= 65 && c <= 70) &&
+			!(c >= 97 && c <= 102) &&
+			c != ':' && c != '.' {
+			ipstr := ip.String()
+			if v := net.ParseIP(ipstr); v != nil {
+				s.accept.WriteString(ipstr)
+				return nil
+			}
+			return errors.New("invalid ipv6")
+		} else {
+			ip.WriteByte(c)
+		}
+	}
+}
+
+// Dot-string / Quoted-string
+func (s *Parser) localPart() error {
+	defer func() {
+		if s.accept.Len() > 0 {
+			s.LocalPart = s.accept.String()
+			s.accept.Reset()
+		}
+	}()
+	p := s.peek()
+	if p == '"' {
+		return s.quotedString()
+	} else {
+		return s.dotString()
+	}
+}
+
+// DQUOTE *QcontentSMTP DQUOTE
+func (s *Parser) quotedString() error {
+	if s.next() == '"' {
+		if err := s.QcontentSMTP(); err != nil {
+			return err
+		}
+		if s.ch != '"' {
+			return errors.New("quoted string not closed")
+		} else {
+			// accept the "
+			s.next()
+		}
+	}
+	return nil
+}
+
+// qtextSMTP / quoted-pairSMTP
+// quoted-pairSMTP = %d92 %d32-126
+// qtextSMTP = %d32-33 / %d35-91 / %d93-126
+func (s *Parser) QcontentSMTP() error {
+	state := 0
+	for {
+		ch := s.next()
+		switch state {
+		case 0:
+			if ch == '\\' {
+				state = 1
+				s.accept.WriteByte(ch)
+				continue
+			} else if ch == 32 || ch == 33 ||
+				(ch >= 35 && ch <= 91) ||
+				(ch >= 93 && ch <= 126) {
+				s.accept.WriteByte(ch)
+				continue
+			}
+			return nil
+		case 1:
+			// escaped character state
+			if ch >= 32 && ch <= 126 {
+				s.accept.WriteByte(ch)
+				state = 0
+				continue
+			} else {
+				return errors.New("non-printable character found")
+			}
+		}
+	}
+}
+
+//Dot-string     = Atom *("."  Atom)
+func (s *Parser) dotString() error {
+	for {
+		if err := s.atom(); err != nil {
+			return err
+		}
+		if s.ch != '.' {
+			break
+		}
+		s.accept.WriteByte(s.ch)
+	}
+	return nil
+}
+
+// 1*atext
+func (s *Parser) atom() error {
+	state := 0
+	for {
+		if state == 0 {
+			if !s.isAtext(s.next()) {
+				return errors.New("parse error")
+			} else {
+				s.accept.WriteByte(s.ch)
+				state = 1
+				continue
+			}
+		}
+		if state == 1 {
+			if !s.isAtext(s.next()) {
+				return nil
+			} else {
+				s.accept.WriteByte(s.ch)
+			}
+		}
+	}
+}
+
+/*
+
+Dot-string     = Atom *("."  Atom)
+
+Atom           = 1*atext
+
+atext           =       ALPHA / DIGIT / ; Any character except controls,
+                        "!" / "#" /     ;  SP, and specials.
+                        "$" / "%" /     ;  Used for atoms
+                        "&" / "'" /
+                        "*" / "+" /
+                        "-" / "/" /
+                        "=" / "?" /
+                        "^" / "_" /
+                        "`" / "{" /
+                        "|" / "}" /
+                        "~"
+
+*/
+
+func (s *Parser) isAtext(c byte) bool {
+	if ('0' <= c && c <= '9') ||
+		('A' <= c && c <= 'z') ||
+		c == '!' || c == '#' ||
+		c == '$' || c == '%' ||
+		c == '&' || c == '\'' ||
+		c == '*' || c == '+' ||
+		c == '-' || c == '/' ||
+		c == '=' || c == '?' ||
+		c == '^' || c == '_' ||
+		c == '`' || c == '{' ||
+		c == '|' || c == '}' ||
+		c == '~' {
+		return true
+	}
+	return false
+}
+
+func isLetDig(c byte) bool {
+	if ('0' <= c && c <= '9') ||
+		('A' <= c && c <= 'z') {
+		return true
+	}
+	return false
+}

+ 583 - 0
mail/rfc5321/parse_test.go

@@ -0,0 +1,583 @@
+package rfc5321
+
+import (
+	"strings"
+	"testing"
+)
+
+func TestParseParam(t *testing.T) {
+	s := NewParser([]byte("SIZE=2000"))
+	params, err := s.param()
+	if strings.Compare(params[0], "SIZE") != 0 {
+		t.Error("SIZE ecpected")
+	}
+	if strings.Compare(params[1], "2000") != 0 {
+		t.Error("2000 ecpected")
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("SI--ZE=2000 BODY=8BITMIME"))
+	tup, err := s.parameters()
+	if strings.Compare(tup[0][0], "SI--ZE") != 0 {
+		t.Error("SI--ZE ecpected")
+	}
+	if strings.Compare(tup[0][1], "2000") != 0 {
+		t.Error("2000 ecpected")
+	}
+	if strings.Compare(tup[1][0], "BODY") != 0 {
+		t.Error("BODY expected", err)
+	}
+	if strings.Compare(tup[1][1], "8BITMIME") != 0 {
+		t.Error("8BITMIME expected", err)
+	}
+
+	s = NewParser([]byte("SI--ZE-=2000 BODY=8BITMIME")) // illegal - after ZE
+	tup, err = s.parameters()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+}
+
+func TestParseRcptTo(t *testing.T) {
+	var s Parser
+	err := s.RcptTo([]byte("<Postmaster>"))
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	err = s.RcptTo([]byte("<[email protected]>"))
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+	if s.LocalPart != "Postmaster" {
+		t.Error("s.LocalPart should be: Postmaster")
+	}
+
+	err = s.RcptTo([]byte("<[email protected]> NOTIFY=SUCCESS,FAILURE"))
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	//
+}
+
+func TestParseForwardPath(t *testing.T) {
+	s := NewParser([]byte("<@a,@b:user@[227.0.0.1>")) // missing ]
+	err := s.forwardPath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	s = NewParser([]byte("<@a,@b:user@[527.0.0.1>")) // ip out of range
+	err = s.forwardPath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	// with a 'size' estmp param
+	s = NewParser([]byte("<[email protected]> NOTIFY=FAILURE ORCPT=rfc822;[email protected]"))
+	err = s.forwardPath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	// tolerate a space at the front
+	s = NewParser([]byte(" <[email protected]>"))
+	err = s.forwardPath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	// tolerate a space at the front, invalid
+	s = NewParser([]byte(" <"))
+	err = s.forwardPath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	// tolerate a space at the front, invalid
+	s = NewParser([]byte(" "))
+	err = s.forwardPath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	// empty
+	s = NewParser([]byte(""))
+	err = s.forwardPath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+}
+
+func TestParseReversePath(t *testing.T) {
+
+	s := NewParser([]byte("<@a,@b:user@d>"))
+	err := s.reversePath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("<@a,@b:user@d> param=some-value")) // includes a mail parameter
+	err = s.reversePath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("<@a,@b:user@[227.0.0.1]>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("<>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte(""))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error  expected ", err)
+	}
+
+	s = NewParser([]byte("[email protected]"))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	s = NewParser([]byte("<@ghg;$7@65"))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error  expected ", err)
+	}
+
+	// tolerate a space at the front
+	s = NewParser([]byte(" <>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	// tolerate a space at the front, invalid
+	s = NewParser([]byte(" <"))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	// tolerate a space at the front, invalid
+	s = NewParser([]byte(" "))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	// empty
+	s = NewParser([]byte(" "))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+}
+
+func TestParseIpv6Address(t *testing.T) {
+	s := NewParser([]byte("2001:0000:3238:DFE1:0063:0000:0000:FEFB"))
+	err := s.ipv6AddressLiteral()
+	if s.accept.String() != "2001:0000:3238:DFE1:0063:0000:0000:FEFB" {
+		t.Error("expected 2001:0000:3238:DFE1:0063:0000:0000:FEFB, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+	s = NewParser([]byte("2001:3238:DFE1:6323:FEFB:2536:1.2.3.2"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "2001:3238:DFE1:6323:FEFB:2536:1.2.3.2" {
+		t.Error("expected 2001:3238:DFE1:6323:FEFB:2536:1.2.3.2, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("2001:0000:3238:DFE1:63:0000:0000:FEFB"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "2001:0000:3238:DFE1:63:0000:0000:FEFB" {
+		t.Error("expected 2001:0000:3238:DFE1:63:0000:0000:FEFB, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("2001:0000:3238:DFE1:63::FEFB"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "2001:0000:3238:DFE1:63::FEFB" {
+		t.Error("expected 2001:0000:3238:DFE1:63::FEFB, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("2001:0:3238:DFE1:63::FEFB"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "2001:0:3238:DFE1:63::FEFB" {
+		t.Error("expected 2001:0:3238:DFE1:63::FEFB, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("g001:0:3238:DFE1:63::FEFB"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "" {
+		t.Error("expected \"\", got:", s.accept.String())
+	}
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+	s = NewParser([]byte("g001:0:3238:DFE1::63::FEFB"))
+	err = s.ipv6AddressLiteral()
+	if s.accept.String() != "" {
+		t.Error("expected \"\", got:", s.accept.String())
+	}
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+}
+
+func TestParseIpv4Address(t *testing.T) {
+	s := NewParser([]byte("0.0.0.255"))
+	err := s.ipv4AddressLiteral()
+	if s.accept.String() != "0.0.0.255" {
+		t.Error("expected 0.0.0.255, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ", err)
+	}
+
+	s = NewParser([]byte("0.0.0.256"))
+	err = s.ipv4AddressLiteral()
+	if s.accept.String() != "0.0.0.256" {
+		t.Error("expected 0.0.0.256, got:", s.accept.String())
+	}
+	if err == nil {
+		t.Error("error expected ", err)
+	}
+
+}
+
+func TestParseMailBoxBad(t *testing.T) {
+
+	// must be quoted
+	s := NewParser([]byte("Abc\\@[email protected]"))
+	err := s.mailbox()
+
+	if err == nil {
+		t.Error("error expected")
+	}
+
+	// must be quoted
+	s = NewParser([]byte("Fred\\ [email protected]"))
+	err = s.mailbox()
+
+	if err == nil {
+		t.Error("error expected")
+	}
+}
+
+func TestParseMailbox(t *testing.T) {
+
+	s := NewParser([]byte("jsmith@[IPv6:2001:db8::1]"))
+	err := s.mailbox()
+	if s.Domain != "2001:db8::1" {
+		t.Error("expected domain:2001:db8::1, got:", s.Domain)
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("\"qu\\{oted\"@test.com"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("\"qu\\{oted\"@[127.0.0.1]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("jsmith@[IPv6:2001:db8::1]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("Joe.\\[email protected]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("\"Abc@def\"@example.com"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("\"Fred Bloggs\"@example.com"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("customer/[email protected]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("[email protected]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("!def!xyz%[email protected]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("[email protected]"))
+	err = s.mailbox()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+}
+
+func TestParseLocalPart(t *testing.T) {
+	s := NewParser([]byte("\"qu\\{oted\""))
+	err := s.localPart()
+	if s.LocalPart != "qu\\{oted" {
+		t.Error("expected qu\\{oted, got:", s.LocalPart)
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("dot.string"))
+	err = s.localPart()
+	if s.LocalPart != "dot.string" {
+		t.Error("expected dot.string, got:", s.LocalPart)
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("dot.st!ring"))
+	err = s.localPart()
+	if s.LocalPart != "dot.st!ring" {
+		t.Error("expected dot.st!ring, got:", s.LocalPart)
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("dot..st!ring")) // fail
+	err = s.localPart()
+
+	if err == nil {
+		t.Error("error expected ")
+	}
+}
+
+func TestParseQuotedString(t *testing.T) {
+	s := NewParser([]byte("\"qu\\ oted\""))
+	err := s.quotedString()
+	if s.accept.String() != "qu\\ oted" {
+		t.Error("Expected qu\\ oted, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("\"@\""))
+	err = s.quotedString()
+	if s.accept.String() != "@" {
+		t.Error("Expected @, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+}
+
+func TestParseDotString(t *testing.T) {
+
+	s := NewParser([]byte("Joe..\\\\Blow"))
+	err := s.dotString()
+	if err == nil {
+		t.Error("error expected ")
+	}
+
+	s = NewParser([]byte("Joe.\\\\Blow"))
+	err = s.dotString()
+	if s.accept.String() != "Joe.\\\\Blow" {
+		t.Error("Expected Joe.\\\\Blow, got:", s.accept.String())
+	}
+	if err != nil {
+		t.Error("error not expected ")
+	}
+}
+
+func TestParsePath(t *testing.T) {
+	s := NewParser([]byte("<foo>")) // requires @
+	err := s.path()
+	if err == nil {
+		t.Error("error expected ")
+	}
+	s = NewParser([]byte("<@example.com,@test.com:[email protected]>"))
+	err = s.path()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("<@example.com>")) // no mailbox
+	err = s.path()
+	if err == nil {
+		t.Error("error expected ")
+	}
+
+	s = NewParser([]byte("<[email protected]	1")) // no closing >
+	err = s.path()
+	if err == nil {
+		t.Error("error expected ")
+	}
+}
+
+func TestParseADL(t *testing.T) {
+	s := NewParser([]byte("@example.com,@test.com"))
+	err := s.adl()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+}
+
+func TestParseAtDomain(t *testing.T) {
+	s := NewParser([]byte("@example.com"))
+	err := s.atDomain()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+}
+
+func TestParseDomain(t *testing.T) {
+
+	s := NewParser([]byte("a"))
+	err := s.domain()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("a.com.gov"))
+	err = s.domain()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+
+	s = NewParser([]byte("wrong-.com"))
+	err = s.domain()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+	s = NewParser([]byte("wrong."))
+	err = s.domain()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+}
+
+func TestParseSubDomain(t *testing.T) {
+
+	s := NewParser([]byte("a"))
+	err := s.subdomain()
+	if err != nil {
+		t.Error("error not expected ")
+	}
+	s = NewParser([]byte("-a"))
+	err = s.subdomain()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+	s = NewParser([]byte("a--"))
+	err = s.subdomain()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+	s = NewParser([]byte("a--"))
+	err = s.subdomain()
+	if err == nil {
+		t.Error("error was expected ")
+	}
+	s = NewParser([]byte("a--b"))
+	err = s.subdomain()
+	if err != nil {
+		t.Error("error was not expected ")
+	}
+
+	// although a---b looks like an illegal subdomain, it is rfc5321 grammar spec
+	s = NewParser([]byte("a---b"))
+	err = s.subdomain()
+	if err != nil {
+		t.Error("error was not expected ")
+	}
+
+	s = NewParser([]byte("abc"))
+	err = s.subdomain()
+	if err != nil {
+		t.Error("error was not expected ")
+	}
+
+	s = NewParser([]byte("a-b-c"))
+	err = s.subdomain()
+	if err != nil {
+		t.Error("error was not expected ")
+	}
+
+}
+func TestParse(t *testing.T) {
+
+	s := NewParser([]byte("<"))
+	err := s.reversePath()
+	if err == nil {
+		t.Error("< expected parse error")
+	}
+
+	// the @ needs to be quoted
+	s = NewParser([]byte("<@[email protected]>"))
+	err = s.reversePath()
+	if err == nil {
+		t.Error("expected parse error", err)
+	}
+
+	s = NewParser([]byte("<\"@m.conm\"@test.com>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("not expected parse error", err)
+	}
+
+	s = NewParser([]byte("<[email protected]>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("not expected parse error")
+	}
+
+	s = NewParser([]byte("<@test:[email protected]>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("not expected parse error")
+	}
+	s = NewParser([]byte("<@test,@test2:[email protected]>"))
+	err = s.reversePath()
+	if err != nil {
+		t.Error("not expected parse error")
+	}
+
+}

+ 3 - 3
mocks/conn_mock.go

@@ -31,11 +31,11 @@ type End struct {
 	Writer *io.PipeWriter
 }
 
-func (c End) Close() error {
-	if err := c.Writer.Close(); err != nil {
+func (e End) Close() error {
+	if err := e.Writer.Close(); err != nil {
 		return err
 	}
-	if err := c.Reader.Close(); err != nil {
+	if err := e.Reader.Close(); err != nil {
 		return err
 	}
 	return nil

+ 2 - 2
models.go

@@ -7,8 +7,8 @@ import (
 )
 
 var (
-	LineLimitExceeded   = errors.New("Maximum line length exceeded")
-	MessageSizeExceeded = errors.New("Maximum message size exceeded")
+	LineLimitExceeded   = errors.New("maximum line length exceeded")
+	MessageSizeExceeded = errors.New("maximum message size exceeded")
 )
 
 // we need to adjust the limit, so we embed io.LimitedReader

+ 9 - 4
pool.go

@@ -17,11 +17,12 @@ var (
 // a struct can be pooled if it has the following interface
 type Poolable interface {
 	// ability to set read/write timeout
-	setTimeout(t time.Duration)
+	setTimeout(t time.Duration) error
 	// set a new connection and client id
 	init(c net.Conn, clientID uint64, ep *mail.Pool)
 	// get a unique id
 	getID() uint64
+	kill()
 }
 
 // Pool holds Clients.
@@ -82,9 +83,11 @@ func (p *Pool) ShutdownState() {
 	p.isShuttingDownFlg.Store(true) // no more borrowing
 	p.ShutdownChan <- 1             // release any waiting p.sem
 
-	// set a low timeout
+	// set a low timeout (let the clients finish whatever the're doing)
 	p.activeClients.mapAll(func(p Poolable) {
-		p.setTimeout(time.Duration(int64(aVeryLowTimeout)))
+		if err := p.setTimeout(time.Duration(int64(aVeryLowTimeout))); err != nil {
+			p.kill()
+		}
 	})
 
 }
@@ -111,7 +114,9 @@ func (p *Pool) IsShuttingDown() bool {
 // set a timeout for all lent clients
 func (p *Pool) SetTimeout(duration time.Duration) {
 	p.activeClients.mapAll(func(p Poolable) {
-		p.setTimeout(duration)
+		if err := p.setTimeout(duration); err != nil {
+			p.kill()
+		}
 	})
 }
 

+ 109 - 93
response/enhanced.go

@@ -22,6 +22,9 @@ const (
 	ClassPermanentFailure = 5
 )
 
+// space char
+const SP = " "
+
 // class is a type for ClassSuccess, ClassTransientFailure and ClassPermanentFailure constants
 type class int
 
@@ -118,39 +121,39 @@ var (
 type Responses struct {
 
 	// The 500's
-	FailLineTooLong              string
-	FailNestedMailCmd            string
-	FailNoSenderDataCmd          string
-	FailNoRecipientsDataCmd      string
-	FailUnrecognizedCmd          string
-	FailMaxUnrecognizedCmd       string
-	FailReadLimitExceededDataCmd string
-	FailMessageSizeExceeded      string
-	FailReadErrorDataCmd         string
-	FailPathTooLong              string
-	FailInvalidAddress           string
-	FailLocalPartTooLong         string
-	FailDomainTooLong            string
-	FailBackendNotRunning        string
-	FailBackendTransaction       string
-	FailBackendTimeout           string
-	FailRcptCmd                  string
+	FailLineTooLong              *Response
+	FailNestedMailCmd            *Response
+	FailNoSenderDataCmd          *Response
+	FailNoRecipientsDataCmd      *Response
+	FailUnrecognizedCmd          *Response
+	FailMaxUnrecognizedCmd       *Response
+	FailReadLimitExceededDataCmd *Response
+	FailMessageSizeExceeded      *Response
+	FailReadErrorDataCmd         *Response
+	FailPathTooLong              *Response
+	FailInvalidAddress           *Response
+	FailLocalPartTooLong         *Response
+	FailDomainTooLong            *Response
+	FailBackendNotRunning        *Response
+	FailBackendTransaction       *Response
+	FailBackendTimeout           *Response
+	FailRcptCmd                  *Response
 
 	// The 400's
-	ErrorTooManyRecipients string
-	ErrorRelayDenied       string
-	ErrorShutdown          string
+	ErrorTooManyRecipients *Response
+	ErrorRelayDenied       *Response
+	ErrorShutdown          *Response
 
 	// The 200's
-	SuccessMailCmd       string
-	SuccessRcptCmd       string
-	SuccessResetCmd      string
-	SuccessVerifyCmd     string
-	SuccessNoopCmd       string
-	SuccessQuitCmd       string
-	SuccessDataCmd       string
-	SuccessStartTLSCmd   string
-	SuccessMessageQueued string
+	SuccessMailCmd       *Response
+	SuccessRcptCmd       *Response
+	SuccessResetCmd      *Response
+	SuccessVerifyCmd     *Response
+	SuccessNoopCmd       *Response
+	SuccessQuitCmd       *Response
+	SuccessDataCmd       *Response
+	SuccessStartTLSCmd   *Response
+	SuccessMessageQueued *Response
 }
 
 // Called automatically during package load to build up the Responses struct
@@ -158,191 +161,195 @@ func init() {
 
 	Canned = Responses{}
 
-	Canned.FailLineTooLong = (&Response{
+	Canned.FailLineTooLong = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Comment:      "Line too long.",
-	}).String()
+	}
 
-	Canned.FailNestedMailCmd = (&Response{
+	Canned.FailNestedMailCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: nested MAIL command",
-	}).String()
+	}
 
-	Canned.SuccessMailCmd = (&Response{
+	Canned.SuccessMailCmd = &Response{
 		EnhancedCode: OtherAddressStatus,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
-	Canned.SuccessRcptCmd = (&Response{
+	Canned.SuccessRcptCmd = &Response{
 		EnhancedCode: DestinationMailboxAddressValid,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
 	Canned.SuccessResetCmd = Canned.SuccessMailCmd
-	Canned.SuccessNoopCmd = (&Response{
+
+	Canned.SuccessNoopCmd = &Response{
 		EnhancedCode: OtherStatus,
 		Class:        ClassSuccess,
-	}).String()
+	}
 
-	Canned.SuccessVerifyCmd = (&Response{
+	Canned.SuccessVerifyCmd = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    252,
 		Class:        ClassSuccess,
 		Comment:      "Cannot verify user",
-	}).String()
+	}
 
-	Canned.ErrorTooManyRecipients = (&Response{
+	Canned.ErrorTooManyRecipients = &Response{
 		EnhancedCode: TooManyRecipients,
 		BasicCode:    452,
 		Class:        ClassTransientFailure,
 		Comment:      "Too many recipients",
-	}).String()
+	}
 
-	Canned.ErrorRelayDenied = (&Response{
+	Canned.ErrorRelayDenied = &Response{
 		EnhancedCode: BadDestinationMailboxAddress,
 		BasicCode:    454,
 		Class:        ClassTransientFailure,
-		Comment:      "Error: Relay access denied: ",
-	}).String()
+		Comment:      "Error: Relay access denied:",
+	}
 
-	Canned.SuccessQuitCmd = (&Response{
+	Canned.SuccessQuitCmd = &Response{
 		EnhancedCode: OtherStatus,
 		BasicCode:    221,
 		Class:        ClassSuccess,
 		Comment:      "Bye",
-	}).String()
+	}
 
-	Canned.FailNoSenderDataCmd = (&Response{
+	Canned.FailNoSenderDataCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: No sender",
-	}).String()
+	}
 
-	Canned.FailNoRecipientsDataCmd = (&Response{
+	Canned.FailNoRecipientsDataCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    503,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: No recipients",
-	}).String()
+	}
 
-	Canned.SuccessDataCmd = "354 Enter message, ending with '.' on a line by itself"
+	Canned.SuccessDataCmd = &Response{
+		BasicCode: 354,
+		Comment:   "354 Enter message, ending with '.' on a line by itself",
+	}
 
-	Canned.SuccessStartTLSCmd = (&Response{
+	Canned.SuccessStartTLSCmd = &Response{
 		EnhancedCode: OtherStatus,
 		BasicCode:    220,
 		Class:        ClassSuccess,
 		Comment:      "Ready to start TLS",
-	}).String()
+	}
 
-	Canned.FailUnrecognizedCmd = (&Response{
+	Canned.FailUnrecognizedCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Comment:      "Unrecognized command",
-	}).String()
+	}
 
-	Canned.FailMaxUnrecognizedCmd = (&Response{
+	Canned.FailMaxUnrecognizedCmd = &Response{
 		EnhancedCode: InvalidCommand,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Comment:      "Too many unrecognized commands",
-	}).String()
+	}
 
-	Canned.ErrorShutdown = (&Response{
+	Canned.ErrorShutdown = &Response{
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		BasicCode:    421,
 		Class:        ClassTransientFailure,
 		Comment:      "Server is shutting down. Please try again later. Sayonara!",
-	}).String()
+	}
 
-	Canned.FailReadLimitExceededDataCmd = (&Response{
+	Canned.FailReadLimitExceededDataCmd = &Response{
 		EnhancedCode: SyntaxError,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
-	Canned.FailMessageSizeExceeded = (&Response{
+	Canned.FailMessageSizeExceeded = &Response{
 		EnhancedCode: SyntaxError,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
-	Canned.FailReadErrorDataCmd = (&Response{
+	Canned.FailReadErrorDataCmd = &Response{
 		EnhancedCode: OtherOrUndefinedMailSystemStatus,
 		BasicCode:    451,
 		Class:        ClassTransientFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
-	Canned.FailPathTooLong = (&Response{
+	Canned.FailPathTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Comment:      "Path too long",
-	}).String()
+	}
 
-	Canned.FailInvalidAddress = (&Response{
+	Canned.FailInvalidAddress = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    501,
 		Class:        ClassPermanentFailure,
 		Comment:      "Invalid address",
-	}).String()
+	}
 
-	Canned.FailLocalPartTooLong = (&Response{
+	Canned.FailLocalPartTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Comment:      "Local part too long, cannot exceed 64 characters",
-	}).String()
+	}
 
-	Canned.FailDomainTooLong = (&Response{
+	Canned.FailDomainTooLong = &Response{
 		EnhancedCode: InvalidCommandArguments,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Comment:      "Domain cannot exceed 255 characters",
-	}).String()
+	}
 
-	Canned.FailBackendNotRunning = (&Response{
+	Canned.FailBackendNotRunning = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
-		Comment:      "Transaction failed - backend not running ",
-	}).String()
+		Comment:      "Transaction failed - backend not running",
+	}
 
-	Canned.FailBackendTransaction = (&Response{
+	Canned.FailBackendTransaction = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
-		Comment:      "Error: ",
-	}).String()
+		Comment:      "Error:",
+	}
 
-	Canned.SuccessMessageQueued = (&Response{
+	Canned.SuccessMessageQueued = &Response{
 		EnhancedCode: OtherStatus,
 		BasicCode:    250,
 		Class:        ClassSuccess,
-		Comment:      "OK : queued as ",
-	}).String()
+		Comment:      "OK: queued as",
+	}
 
-	Canned.FailBackendTimeout = (&Response{
+	Canned.FailBackendTimeout = &Response{
 		EnhancedCode: OtherOrUndefinedProtocolStatus,
 		BasicCode:    554,
 		Class:        ClassPermanentFailure,
 		Comment:      "Error: transaction timeout",
-	}).String()
+	}
 
-	Canned.FailRcptCmd = (&Response{
+	Canned.FailRcptCmd = &Response{
 		EnhancedCode: BadDestinationMailboxAddress,
 		BasicCode:    550,
 		Class:        ClassPermanentFailure,
 		Comment:      "User unknown in local recipient table",
-	}).String()
+	}
 
 }
 
@@ -409,6 +416,7 @@ type Response struct {
 	Class        class
 	// Comment is optional
 	Comment string
+	cached  string
 }
 
 // it looks like this ".5.4"
@@ -428,6 +436,14 @@ func (e EnhancedStatusCode) String() string {
 // String returns a custom Response as a string
 func (r *Response) String() string {
 
+	if r.cached != "" {
+		return r.cached
+	}
+	if r.EnhancedCode == "" {
+		r.cached = r.Comment
+		return r.Comment
+	}
+
 	basicCode := r.BasicCode
 	comment := r.Comment
 	if len(comment) == 0 && r.BasicCode == 0 {
@@ -447,8 +463,8 @@ func (r *Response) String() string {
 	if r.BasicCode == 0 {
 		basicCode = getBasicStatusCode(e)
 	}
-
-	return fmt.Sprintf("%d %s %s", basicCode, e.String(), comment)
+	r.cached = fmt.Sprintf("%d %s %s", basicCode, e.String(), comment)
+	return r.cached
 }
 
 // getBasicStatusCode gets the basic status code from codeMap, or fallback code if not mapped

+ 0 - 12
response/enhanced_test.go

@@ -4,18 +4,6 @@ import (
 	"testing"
 )
 
-func TestClass(t *testing.T) {
-	if ClassPermanentFailure != 5 {
-		t.Error("ClassPermanentFailure is not 5")
-	}
-	if ClassTransientFailure != 4 {
-		t.Error("ClassTransientFailure is not 4")
-	}
-	if ClassSuccess != 2 {
-		t.Error("ClassSuccess is not 2")
-	}
-}
-
 func TestGetBasicStatusCode(t *testing.T) {
 	// Known status code
 	a := getBasicStatusCode(EnhancedStatusCode{2, OtherOrUndefinedProtocolStatus})

+ 210 - 184
server.go

@@ -1,22 +1,25 @@
 package guerrilla
 
 import (
+	"bytes"
 	"crypto/rand"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
+	"path/filepath"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
 
-	"crypto/x509"
 	"github.com/flashmob/go-guerrilla/backends"
 	"github.com/flashmob/go-guerrilla/log"
 	"github.com/flashmob/go-guerrilla/mail"
+	"github.com/flashmob/go-guerrilla/mail/rfc5321"
 	"github.com/flashmob/go-guerrilla/response"
-	"io/ioutil"
 )
 
 const (
@@ -24,14 +27,6 @@ const (
 	CommandLineMaxLength = 1024
 	// Number of allowed unrecognized commands before we terminate the connection
 	MaxUnrecognizedCommands = 5
-	// The maximum total length of a reverse-path or forward-path is 256
-	RFC2821LimitPath = 256
-	// The maximum total length of a user name or other local-part is 64
-	RFC2832LimitLocalPart = 64
-	//The maximum total length of a domain name or number is 255
-	RFC2821LimitDomain = 255
-	// The minimum total number of recipients that must be buffered is 100
-	RFC2821LimitRecipients = 100
 )
 
 const (
@@ -54,7 +49,7 @@ type server struct {
 	clientPool      *Pool
 	wg              sync.WaitGroup // for waiting to shutdown
 	listener        net.Listener
-	closedListener  chan (bool)
+	closedListener  chan bool
 	hosts           allowedHosts // stores map[string]bool for faster lookup
 	state           int
 	// If log changed after a config reload, newLogStore stores the value here until it's safe to change it
@@ -66,14 +61,36 @@ type server struct {
 
 type allowedHosts struct {
 	table      map[string]bool // host lookup table
+	wildcards  []string        // host wildcard list (* is used as a wildcard)
 	sync.Mutex                 // guard access to the map
 }
 
+type command []byte
+
+var (
+	cmdHELO     command = []byte("HELO")
+	cmdEHLO     command = []byte("EHLO")
+	cmdHELP     command = []byte("HELP")
+	cmdXCLIENT  command = []byte("XCLIENT")
+	cmdMAIL     command = []byte("MAIL FROM:")
+	cmdRCPT     command = []byte("RCPT TO:")
+	cmdRSET     command = []byte("RSET")
+	cmdVRFY     command = []byte("VRFY")
+	cmdNOOP     command = []byte("NOOP")
+	cmdQUIT     command = []byte("QUIT")
+	cmdDATA     command = []byte("DATA")
+	cmdSTARTTLS command = []byte("STARTTLS")
+)
+
+func (c command) match(in []byte) bool {
+	return bytes.Index(in, []byte(c)) == 0
+}
+
 // Creates and returns a new ready-to-run Server from a configuration
 func newServer(sc *ServerConfig, b backends.Backend, l log.Logger) (*server, error) {
 	server := &server{
 		clientPool:      NewPool(sc.MaxClients),
-		closedListener:  make(chan (bool), 1),
+		closedListener:  make(chan bool, 1),
 		listenInterface: sc.ListenInterface,
 		state:           ServerStateNew,
 		envelopePool:    mail.NewPool(sc.MaxClients),
@@ -173,167 +190,176 @@ func (s *server) backend() backends.Backend {
 }
 
 // Set the timeout for the server and all clients
-func (server *server) setTimeout(seconds int) {
+func (s *server) setTimeout(seconds int) {
 	duration := time.Duration(int64(seconds))
-	server.clientPool.SetTimeout(duration)
-	server.timeout.Store(duration)
+	s.clientPool.SetTimeout(duration)
+	s.timeout.Store(duration)
 }
 
 // goroutine safe config store
-func (server *server) setConfig(sc *ServerConfig) {
-	server.configStore.Store(*sc)
+func (s *server) setConfig(sc *ServerConfig) {
+	s.configStore.Store(*sc)
 }
 
 // goroutine safe
-func (server *server) isEnabled() bool {
-	sc := server.configStore.Load().(ServerConfig)
+func (s *server) isEnabled() bool {
+	sc := s.configStore.Load().(ServerConfig)
 	return sc.IsEnabled
 }
 
 // Set the allowed hosts for the server
-func (server *server) setAllowedHosts(allowedHosts []string) {
-	server.hosts.Lock()
-	defer server.hosts.Unlock()
-	server.hosts.table = make(map[string]bool, len(allowedHosts))
+func (s *server) setAllowedHosts(allowedHosts []string) {
+	s.hosts.Lock()
+	defer s.hosts.Unlock()
+	s.hosts.table = make(map[string]bool, len(allowedHosts))
+	s.hosts.wildcards = nil
 	for _, h := range allowedHosts {
-		server.hosts.table[strings.ToLower(h)] = true
+		if strings.Index(h, "*") != -1 {
+			s.hosts.wildcards = append(s.hosts.wildcards, strings.ToLower(h))
+		} else {
+			s.hosts.table[strings.ToLower(h)] = true
+		}
 	}
 }
 
 // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
-func (server *server) Start(startWG *sync.WaitGroup) error {
+func (s *server) Start(startWG *sync.WaitGroup) error {
 	var clientID uint64
 	clientID = 0
 
-	listener, err := net.Listen("tcp", server.listenInterface)
-	server.listener = listener
+	listener, err := net.Listen("tcp", s.listenInterface)
+	s.listener = listener
 	if err != nil {
 		startWG.Done() // don't wait for me
-		server.state = ServerStateStartError
-		return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
+		s.state = ServerStateStartError
+		return fmt.Errorf("[%s] Cannot listen on port: %s ", s.listenInterface, err.Error())
 	}
 
-	server.log().Infof("Listening on TCP %s", server.listenInterface)
-	server.state = ServerStateRunning
+	s.log().Infof("Listening on TCP %s", s.listenInterface)
+	s.state = ServerStateRunning
 	startWG.Done() // start successful, don't wait for me
 
 	for {
-		server.log().Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
+		s.log().Debugf("[%s] Waiting for a new client. Next Client ID: %d", s.listenInterface, clientID+1)
 		conn, err := listener.Accept()
 		clientID++
 		if err != nil {
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
-				server.log().Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
+				s.log().Infof("Server [%s] has stopped accepting new clients", s.listenInterface)
 				// the listener has been closed, wait for clients to exit
-				server.log().Infof("shutting down pool [%s]", server.listenInterface)
-				server.clientPool.ShutdownState()
-				server.clientPool.ShutdownWait()
-				server.state = ServerStateStopped
-				server.closedListener <- true
+				s.log().Infof("shutting down pool [%s]", s.listenInterface)
+				s.clientPool.ShutdownState()
+				s.clientPool.ShutdownWait()
+				s.state = ServerStateStopped
+				s.closedListener <- true
 				return nil
 			}
-			server.mainlog().WithError(err).Info("Temporary error accepting client")
+			s.mainlog().WithError(err).Info("Temporary error accepting client")
 			continue
 		}
-		go func(p Poolable, borrow_err error) {
+		go func(p Poolable, borrowErr error) {
 			c := p.(*client)
-			if borrow_err == nil {
-				server.handleClient(c)
-				server.envelopePool.Return(c.Envelope)
-				server.clientPool.Return(c)
+			if borrowErr == nil {
+				s.handleClient(c)
+				s.envelopePool.Return(c.Envelope)
+				s.clientPool.Return(c)
 			} else {
-				server.log().WithError(borrow_err).Info("couldn't borrow a new client")
+				s.log().WithError(borrowErr).Info("couldn't borrow a new client")
 				// we could not get a client, so close the connection.
-				conn.Close()
+				_ = conn.Close()
 
 			}
 			// intentionally placed Borrow in args so that it's called in the
 			// same main goroutine.
-		}(server.clientPool.Borrow(conn, clientID, server.log(), server.envelopePool))
+		}(s.clientPool.Borrow(conn, clientID, s.log(), s.envelopePool))
 
 	}
 }
 
-func (server *server) Shutdown() {
-	if server.listener != nil {
+func (s *server) Shutdown() {
+	if s.listener != nil {
 		// This will cause Start function to return, by causing an error on listener.Accept
-		server.listener.Close()
+		_ = s.listener.Close()
 		// wait for the listener to listener.Accept
-		<-server.closedListener
+		<-s.closedListener
 		// At this point Start will exit and close down the pool
 	} else {
-		server.clientPool.ShutdownState()
+		s.clientPool.ShutdownState()
 		// listener already closed, wait for clients to exit
-		server.clientPool.ShutdownWait()
-		server.state = ServerStateStopped
+		s.clientPool.ShutdownWait()
+		s.state = ServerStateStopped
 	}
 }
 
-func (server *server) GetActiveClientsCount() int {
-	return server.clientPool.GetActiveClientsCount()
+func (s *server) GetActiveClientsCount() int {
+	return s.clientPool.GetActiveClientsCount()
 }
 
 // Verifies that the host is a valid recipient.
 // host checking turned off if there is a single entry and it's a dot.
-func (server *server) allowsHost(host string) bool {
-	server.hosts.Lock()
-	defer server.hosts.Unlock()
-	if len(server.hosts.table) == 1 {
-		if _, ok := server.hosts.table["."]; ok {
+func (s *server) allowsHost(host string) bool {
+	s.hosts.Lock()
+	defer s.hosts.Unlock()
+	// if hosts contains a single dot, further processing is skipped
+	if len(s.hosts.table) == 1 {
+		if _, ok := s.hosts.table["."]; ok {
 			return true
 		}
 	}
-	if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
+	if _, ok := s.hosts.table[strings.ToLower(host)]; ok {
 		return true
 	}
+	// check the wildcards
+	for _, w := range s.hosts.wildcards {
+		if matched, err := filepath.Match(w, strings.ToLower(host)); matched && err == nil {
+			return true
+		}
+	}
 	return false
 }
 
-// Reads from the client until a terminating sequence is encountered,
+const commandSuffix = "\r\n"
+
+// Reads from the client until a \n terminator is encountered,
 // or until a timeout occurs.
-func (server *server) readCommand(client *client, maxSize int64) (string, error) {
-	var input, reply string
+func (s *server) readCommand(client *client) ([]byte, error) {
+	//var input string
 	var err error
+	var bs []byte
 	// In command state, stop reading at line breaks
-	suffix := "\r\n"
-	for {
-		client.setTimeout(server.timeout.Load().(time.Duration))
-		reply, err = client.bufin.ReadString('\n')
-		input = input + reply
-		if err != nil {
-			break
-		}
-		if strings.HasSuffix(input, suffix) {
-			// discard the suffix and stop reading
-			input = input[0 : len(input)-len(suffix)]
-			break
-		}
+	bs, err = client.bufin.ReadSlice('\n')
+	if err != nil {
+		return bs, err
+	} else if bytes.HasSuffix(bs, []byte(commandSuffix)) {
+		return bs[:len(bs)-2], err
 	}
-	return input, err
+	return bs[:len(bs)-1], err
 }
 
 // flushResponse a response to the client. Flushes the client.bufout buffer to the connection
-func (server *server) flushResponse(client *client) error {
-	client.setTimeout(server.timeout.Load().(time.Duration))
+func (s *server) flushResponse(client *client) error {
+	if err := client.setTimeout(s.timeout.Load().(time.Duration)); err != nil {
+		return err
+	}
 	return client.bufout.Flush()
 }
 
-func (server *server) isShuttingDown() bool {
-	return server.clientPool.IsShuttingDown()
+func (s *server) isShuttingDown() bool {
+	return s.clientPool.IsShuttingDown()
 }
 
 // Handles an entire client SMTP exchange
-func (server *server) handleClient(client *client) {
+func (s *server) handleClient(client *client) {
 	defer func() {
-		server.log().WithFields(map[string]interface{}{
+		s.log().WithFields(map[string]interface{}{
 			"event": "disconnect",
 			"id":    client.ID,
 		}).Info("Disconnect client")
 		client.closeConn()
 	}()
 
-	sc := server.configStore.Load().(ServerConfig)
-	server.log().WithFields(map[string]interface{}{
+	sc := s.configStore.Load().(ServerConfig)
+	s.log().WithFields(map[string]interface{}{
 		"event": "connect",
 		"id":    client.ID,
 	}).Infof("Handle client [%s]", client.RemoteIP)
@@ -341,7 +367,7 @@ func (server *server) handleClient(client *client) {
 	// Initial greeting
 	greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s",
 		sc.Hostname, Version, client.ID,
-		server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339))
+		s.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339))
 
 	helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
 	// ehlo is a multi-line reply and need additional \r\n at the end
@@ -357,13 +383,13 @@ func (server *server) handleClient(client *client) {
 	help := "250 HELP"
 
 	if sc.TLS.AlwaysOn {
-		tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
+		tlsConfig, ok := s.tlsConfigStore.Load().(*tls.Config)
 		if !ok {
-			server.mainlog().Error("Failed to load *tls.Config")
+			s.mainlog().Error("Failed to load *tls.Config")
 		} else if err := client.upgradeToTLS(tlsConfig); err == nil {
 			advertiseTLS = ""
 		} else {
-			server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
+			s.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
 			// server requires TLS, but can't handshake
 			client.kill()
 		}
@@ -372,7 +398,7 @@ func (server *server) handleClient(client *client) {
 		// STARTTLS turned off, don't advertise it
 		advertiseTLS = ""
 	}
-
+	r := response.Canned
 	for client.isAlive() {
 		switch client.state {
 		case ClientGreeting:
@@ -380,42 +406,41 @@ func (server *server) handleClient(client *client) {
 			client.state = ClientCmd
 		case ClientCmd:
 			client.bufin.setLimit(CommandLineMaxLength)
-			input, err := server.readCommand(client, sc.MaxSize)
-			server.log().Debugf("Client sent: %s", input)
+			input, err := s.readCommand(client)
+			s.log().Debugf("Client sent: %s", input)
 			if err == io.EOF {
-				server.log().WithError(err).Warnf("Client closed the connection: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Client closed the connection: %s", client.RemoteIP)
 				return
 			} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
-				server.log().WithError(err).Warnf("Timeout: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Timeout: %s", client.RemoteIP)
 				return
 			} else if err == LineLimitExceeded {
-				client.sendResponse(response.Canned.FailLineTooLong)
+				client.sendResponse(r.FailLineTooLong)
 				client.kill()
 				break
 			} else if err != nil {
-				server.log().WithError(err).Warnf("Read error: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Read error: %s", client.RemoteIP)
 				client.kill()
 				break
 			}
-			if server.isShuttingDown() {
+			if s.isShuttingDown() {
 				client.state = ClientShutdown
 				continue
 			}
 
-			input = strings.Trim(input, " \r\n")
 			cmdLen := len(input)
 			if cmdLen > CommandVerbMaxLength {
 				cmdLen = CommandVerbMaxLength
 			}
-			cmd := strings.ToUpper(input[:cmdLen])
+			cmd := bytes.ToUpper(input[:cmdLen])
 			switch {
-			case strings.Index(cmd, "HELO") == 0:
-				client.Helo = strings.Trim(input[4:], " ")
+			case cmdHELO.match(cmd):
+				client.Helo = string(bytes.Trim(input[4:], " "))
 				client.resetTransaction()
 				client.sendResponse(helo)
 
-			case strings.Index(cmd, "EHLO") == 0:
-				client.Helo = strings.Trim(input[4:], " ")
+			case cmdEHLO.match(cmd):
+				client.Helo = string(bytes.Trim(input[4:], " "))
 				client.resetTransaction()
 				client.sendResponse(ehlo,
 					messageSize,
@@ -424,113 +449,109 @@ func (server *server) handleClient(client *client) {
 					advertiseEnhancedStatusCodes,
 					help)
 
-			case strings.Index(cmd, "HELP") == 0:
+			case cmdHELP.match(cmd):
 				quote := response.GetQuote()
-				client.sendResponse("214-OK\r\n" + quote)
+				client.sendResponse("214-OK\r\n", quote)
 
-			case sc.XClientOn && strings.Index(cmd, "XCLIENT ") == 0:
-				if toks := strings.Split(input[8:], " "); len(toks) > 0 {
+			case sc.XClientOn && cmdXCLIENT.match(cmd):
+				if toks := bytes.Split(input[8:], []byte{' '}); len(toks) > 0 {
 					for i := range toks {
-						if vals := strings.Split(toks[i], "="); len(vals) == 2 {
-							if vals[1] == "[UNAVAILABLE]" {
+						if vals := bytes.Split(toks[i], []byte{'='}); len(vals) == 2 {
+							if bytes.Compare(vals[1], []byte("[UNAVAILABLE]")) == 0 {
 								// skip
 								continue
 							}
-							if vals[0] == "ADDR" {
-								client.RemoteIP = vals[1]
+							if bytes.Compare(vals[0], []byte("ADDR")) == 0 {
+								client.RemoteIP = string(vals[1])
 							}
-							if vals[0] == "HELO" {
-								client.Helo = vals[1]
+							if bytes.Compare(vals[0], []byte("HELO")) == 0 {
+								client.Helo = string(vals[1])
 							}
 						}
 					}
 				}
-				client.sendResponse(response.Canned.SuccessMailCmd)
-			case strings.Index(cmd, "MAIL FROM:") == 0:
+				client.sendResponse(r.SuccessMailCmd)
+			case cmdMAIL.match(cmd):
 				if client.isInTransaction() {
-					client.sendResponse(response.Canned.FailNestedMailCmd)
+					client.sendResponse(r.FailNestedMailCmd)
 					break
 				}
-				addr := input[10:]
-				if !(strings.Index(addr, "<>") == 0) &&
-					!(strings.Index(addr, " <>") == 0) {
-					// Not Bounce, extract mail.
-					if from, err := extractEmail(addr); err != nil {
-						client.sendResponse(err)
-						break
-					} else {
-						client.MailFrom = from
-						server.log().WithFields(map[string]interface{}{
-							"event":   "mailfrom",
-							"helo":    client.Helo,
-							"domain":  from.Host,
-							"address": getRemoteAddr(client.conn),
-							"id":      client.ID,
-						}).Info("Mail from")
-					}
-
-				} else {
+				client.MailFrom, err = client.parsePath([]byte(input[10:]), client.parser.MailFrom)
+				if err != nil {
+					s.log().WithError(err).Error("MAIL parse error", "["+string(input[10:])+"]")
+					client.sendResponse(err)
+					break
+				} else if client.parser.NullPath {
 					// bounce has empty from address
 					client.MailFrom = mail.Address{}
+				} else {
+					s.log().WithFields(map[string]interface{}{
+						"event":   "mailfrom",
+						"helo":    client.Helo,
+						"domain":  client.MailFrom.Host,
+						"address": getRemoteAddr(client.conn),
+						"id":      client.ID,
+					}).Info("Mail from")
 				}
-				client.sendResponse(response.Canned.SuccessMailCmd)
+				client.sendResponse(r.SuccessMailCmd)
 
-			case strings.Index(cmd, "RCPT TO:") == 0:
-				if len(client.RcptTo) > RFC2821LimitRecipients {
-					client.sendResponse(response.Canned.ErrorTooManyRecipients)
+			case cmdRCPT.match(cmd):
+				if len(client.RcptTo) > rfc5321.LimitRecipients {
+					client.sendResponse(r.ErrorTooManyRecipients)
 					break
 				}
-				to, err := extractEmail(input[8:])
+				to, err := client.parsePath([]byte(input[8:]), client.parser.RcptTo)
 				if err != nil {
+					s.log().WithError(err).Error("RCPT parse error", "["+string(input[8:])+"]")
 					client.sendResponse(err.Error())
+					break
+				}
+				if !s.allowsHost(to.Host) {
+					client.sendResponse(r.ErrorRelayDenied, " ", to.Host)
 				} else {
-					if !server.allowsHost(to.Host) {
-						client.sendResponse(response.Canned.ErrorRelayDenied, to.Host)
+					client.PushRcpt(to)
+					rcptError := s.backend().ValidateRcpt(client.Envelope)
+					if rcptError != nil {
+						client.PopRcpt()
+						client.sendResponse(r.FailRcptCmd, " ", rcptError.Error())
 					} else {
-						client.PushRcpt(to)
-						rcptError := server.backend().ValidateRcpt(client.Envelope)
-						if rcptError != nil {
-							client.PopRcpt()
-							client.sendResponse(response.Canned.FailRcptCmd + " " + rcptError.Error())
-						} else {
-							client.sendResponse(response.Canned.SuccessRcptCmd)
-						}
+						client.sendResponse(r.SuccessRcptCmd)
 					}
 				}
 
-			case strings.Index(cmd, "RSET") == 0:
+			case cmdRSET.match(cmd):
 				client.resetTransaction()
-				client.sendResponse(response.Canned.SuccessResetCmd)
+				client.sendResponse(r.SuccessResetCmd)
 
-			case strings.Index(cmd, "VRFY") == 0:
-				client.sendResponse(response.Canned.SuccessVerifyCmd)
+			case cmdVRFY.match(cmd):
+				client.sendResponse(r.SuccessVerifyCmd)
 
-			case strings.Index(cmd, "NOOP") == 0:
-				client.sendResponse(response.Canned.SuccessNoopCmd)
+			case cmdNOOP.match(cmd):
+				client.sendResponse(r.SuccessNoopCmd)
 
-			case strings.Index(cmd, "QUIT") == 0:
-				client.sendResponse(response.Canned.SuccessQuitCmd)
+			case cmdQUIT.match(cmd):
+				client.sendResponse(r.SuccessQuitCmd)
 				client.kill()
 
-			case strings.Index(cmd, "DATA") == 0:
+			case cmdDATA.match(cmd):
 				if len(client.RcptTo) == 0 {
-					client.sendResponse(response.Canned.FailNoRecipientsDataCmd)
+					client.sendResponse(r.FailNoRecipientsDataCmd)
 					break
 				}
-				client.sendResponse(response.Canned.SuccessDataCmd)
+				client.sendResponse(r.SuccessDataCmd)
 				client.state = ClientData
 
-			case sc.TLS.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
+			case sc.TLS.StartTLSOn && cmdSTARTTLS.match(cmd):
 
-				client.sendResponse(response.Canned.SuccessStartTLSCmd)
+				client.sendResponse(r.SuccessStartTLSCmd)
 				client.state = ClientStartTLS
 			default:
 				client.errors++
 				if client.errors >= MaxUnrecognizedCommands {
-					client.sendResponse(response.Canned.FailMaxUnrecognizedCmd)
+					client.sendResponse(r.FailMaxUnrecognizedCmd)
 					client.kill()
 				} else {
-					client.sendResponse(response.Canned.FailUnrecognizedCmd)
+					client.sendResponse(r.FailUnrecognizedCmd)
 				}
 			}
 
@@ -542,50 +563,50 @@ func (server *server) handleClient(client *client) {
 
 			n, err := client.Data.ReadFrom(client.smtpReader.DotReader())
 			if n > sc.MaxSize {
-				err = fmt.Errorf("Maximum DATA size exceeded (%d)", sc.MaxSize)
+				err = fmt.Errorf("maximum DATA size exceeded (%d)", sc.MaxSize)
 			}
 			if err != nil {
 				if err == LineLimitExceeded {
-					client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, LineLimitExceeded.Error())
+					client.sendResponse(r.FailReadLimitExceededDataCmd, " ", LineLimitExceeded.Error())
 					client.kill()
 				} else if err == MessageSizeExceeded {
-					client.sendResponse(response.Canned.FailMessageSizeExceeded, MessageSizeExceeded.Error())
+					client.sendResponse(r.FailMessageSizeExceeded, " ", MessageSizeExceeded.Error())
 					client.kill()
 				} else {
-					client.sendResponse(response.Canned.FailReadErrorDataCmd, err.Error())
+					client.sendResponse(r.FailReadErrorDataCmd, " ", err.Error())
 					client.kill()
 				}
-				server.log().WithError(err).Warn("Error reading data")
+				s.log().WithError(err).Warn("Error reading data")
 				client.resetTransaction()
 				break
 			}
 
-			res := server.backend().Process(client.Envelope)
+			res := s.backend().Process(client.Envelope)
 			if res.Code() < 300 {
 				client.messagesSent++
-				server.log().WithFields(map[string]interface{}{
+				s.log().WithFields(map[string]interface{}{
 					"helo":          client.Helo,
 					"remoteAddress": getRemoteAddr(client.conn),
 					"success":       true,
 				}).Info("Received message")
 			}
-			client.sendResponse(res.String())
+			client.sendResponse(res)
 			client.state = ClientCmd
-			if server.isShuttingDown() {
+			if s.isShuttingDown() {
 				client.state = ClientShutdown
 			}
 			client.resetTransaction()
 
 		case ClientStartTLS:
 			if !client.TLS && sc.TLS.StartTLSOn {
-				tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
+				tlsConfig, ok := s.tlsConfigStore.Load().(*tls.Config)
 				if !ok {
-					server.mainlog().Error("Failed to load *tls.Config")
+					s.mainlog().Error("Failed to load *tls.Config")
 				} else if err := client.upgradeToTLS(tlsConfig); err == nil {
 					advertiseTLS = ""
 					client.resetTransaction()
 				} else {
-					server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
+					s.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
 					// Don't disconnect, let the client decide if it wants to continue
 				}
 			}
@@ -593,17 +614,22 @@ func (server *server) handleClient(client *client) {
 			client.state = ClientCmd
 		case ClientShutdown:
 			// shutdown state
-			client.sendResponse(response.Canned.ErrorShutdown)
+			client.sendResponse(r.ErrorShutdown)
 			client.kill()
 		}
 
+		if client.bufErr != nil {
+			s.log().WithError(client.bufErr).Debug("client could not buffer a response")
+			return
+		}
+		// flush the response buffer
 		if client.bufout.Buffered() > 0 {
-			if server.log().IsDebug() {
-				server.log().Debugf("Writing response to client: \n%s", client.response.String())
+			if s.log().IsDebug() {
+				s.log().Debugf("Writing response to client: \n%s", client.response.String())
 			}
-			err := server.flushResponse(client)
+			err := s.flushResponse(client)
 			if err != nil {
-				server.log().WithError(err).Debug("Error writing response")
+				s.log().WithError(err).Debug("error writing response")
 				return
 			}
 		}

+ 212 - 45
server_test.go

@@ -1,6 +1,7 @@
 package guerrilla
 
 import (
+	"os"
 	"testing"
 
 	"bufio"
@@ -16,7 +17,6 @@ import (
 	"github.com/flashmob/go-guerrilla/mocks"
 	"io/ioutil"
 	"net"
-	"os"
 )
 
 // getMockServerConfig gets a mock ServerConfig struct used for creating a new server
@@ -139,8 +139,62 @@ jDGZARZqGyrPeXi+RNe1cMvZCxAFy7gqEtWFLWWrp0gYNPvxkHhhQBrUcF+8T/Nf
 ug8tR8eSL1vGleONtFRBUVG7NbtjhBf9FhvPZcSRR10od/vWHku9E01i4xg=
 -----END CERTIFICATE-----`
 
+func truncateIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Truncate(filename, 0)
+	}
+	return nil
+}
+func deleteIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Remove(filename)
+	}
+	return nil
+}
+
+func cleanTestArtifacts(t *testing.T) {
+	if err := deleteIfExists("rootca.test.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("client.test.key"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("client.test.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/mail.guerrillamail.com.key.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/mail.guerrillamail.com.cert.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/different-go-guerrilla.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/go-guerrilla.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/go-guerrilla2.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/pidfile.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/pidfile2.pid"); err != nil {
+		t.Error(err)
+	}
+
+	if err := truncateIfExists("./tests/testlog"); err != nil {
+		t.Error(err)
+	}
+	if err := truncateIfExists("./tests/testlog2"); err != nil {
+		t.Error(err)
+	}
+}
+
 func TestTLSConfig(t *testing.T) {
 
+	defer cleanTestArtifacts(t)
 	if err := ioutil.WriteFile("rootca.test.pem", []byte(rootCAPK), 0644); err != nil {
 		t.Fatal("couldn't create rootca.test.pem file.", err)
 		return
@@ -167,7 +221,9 @@ func TestTLSConfig(t *testing.T) {
 			Protocols:      []string{"tls1.0", "tls1.2"},
 		},
 	})
-	s.configureSSL()
+	if err := s.configureSSL(); err != nil {
+		t.Error(err)
+	}
 
 	c := s.tlsConfigStore.Load().(*tls.Config)
 
@@ -203,15 +259,12 @@ func TestTLSConfig(t *testing.T) {
 		t.Error("PreferServerCipherSuites should be false")
 	}
 
-	os.Remove("rootca.test.pem")
-	os.Remove("client.test.key")
-	os.Remove("client.test.pem")
-
 }
 
 func TestHandleClient(t *testing.T) {
 	var mainlog log.Logger
 	var logOpenError error
+	defer cleanTestArtifacts(t)
 	sc := getMockServerConfig()
 	mainlog, logOpenError = log.GetLogger(sc.LogFile, "debug")
 	if logOpenError != nil {
@@ -231,10 +284,14 @@ func TestHandleClient(t *testing.T) {
 	line, _ := r.ReadLine()
 	//	fmt.Println(line)
 	w := textproto.NewWriter(bufio.NewWriter(conn.Client))
-	w.PrintfLine("HELO test.test.com")
+	if err := w.PrintfLine("HELO test.test.com"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println(line)
-	w.PrintfLine("QUIT")
+	if err := w.PrintfLine("QUIT"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println("line is:", line)
 	expected := "221 2.0.0 Bye"
@@ -247,6 +304,7 @@ func TestHandleClient(t *testing.T) {
 func TestXClient(t *testing.T) {
 	var mainlog log.Logger
 	var logOpenError error
+	defer cleanTestArtifacts(t)
 	sc := getMockServerConfig()
 	sc.XClientOn = true
 	mainlog, logOpenError = log.GetLogger(sc.LogFile, "debug")
@@ -267,10 +325,14 @@ func TestXClient(t *testing.T) {
 	line, _ := r.ReadLine()
 	//	fmt.Println(line)
 	w := textproto.NewWriter(bufio.NewWriter(conn.Client))
-	w.PrintfLine("HELO test.test.com")
+	if err := w.PrintfLine("HELO test.test.com"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println(line)
-	w.PrintfLine("XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]")
+	if err := w.PrintfLine("XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 
 	if client.RemoteIP != "212.96.64.216" {
@@ -282,7 +344,9 @@ func TestXClient(t *testing.T) {
 	}
 
 	// try malformed input
-	w.PrintfLine("XCLIENT c")
+	if err := w.PrintfLine("XCLIENT c"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 
 	expected = "250 2.1.0 OK"
@@ -290,7 +354,9 @@ func TestXClient(t *testing.T) {
 		t.Error("expected", expected, "but got:", line)
 	}
 
-	w.PrintfLine("QUIT")
+	if err := w.PrintfLine("QUIT"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	wg.Wait() // wait for handleClient to exit
 }
@@ -299,7 +365,7 @@ func TestXClient(t *testing.T) {
 // The transaction should wait until finished, and then test to see if we can do
 // a second transaction
 func TestGatewayTimeout(t *testing.T) {
-
+	defer cleanTestArtifacts(t)
 	bcfg := backends.BackendConfig{
 		"save_workers_size":   1,
 		"save_process":        "HeadersParser|Debugger",
@@ -330,24 +396,51 @@ func TestGatewayTimeout(t *testing.T) {
 		}
 		in := bufio.NewReader(conn)
 		str, err := in.ReadString('\n')
-		fmt.Fprint(conn, "HELO host\r\n")
+		if err != nil {
+			t.Error(err)
+		}
+		if _, err := fmt.Fprint(conn, "HELO host\r\n"); err != nil {
+			t.Error(err)
+		}
 		str, err = in.ReadString('\n')
 		// perform 2 transactions
 		// both should panic.
 		for i := 0; i < 2; i++ {
-			fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "RCPT TO:[email protected]\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "DATA\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "Subject: Test subject\r\n")
-			fmt.Fprint(conn, "\r\n")
-			fmt.Fprint(conn, "A an email body\r\n")
-			fmt.Fprint(conn, ".\r\n")
+			if _, err := fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "DATA\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "Subject: Test subject\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "A an email body\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, ".\r\n"); err != nil {
+				t.Error(err)
+			}
 			str, err = in.ReadString('\n')
 			expect := "transaction timeout"
-			if strings.Index(str, expect) == -1 {
+			if err != nil {
+				t.Error(err)
+			} else if strings.Index(str, expect) == -1 {
 				t.Error("Expected the reply to have'", expect, "'but got", str)
 			}
 		}
@@ -359,6 +452,7 @@ func TestGatewayTimeout(t *testing.T) {
 
 // The processor will panic and gateway should recover from it
 func TestGatewayPanic(t *testing.T) {
+	defer cleanTestArtifacts(t)
 	bcfg := backends.BackendConfig{
 		"save_workers_size":   1,
 		"save_process":        "HeadersParser|Debugger",
@@ -388,35 +482,108 @@ func TestGatewayPanic(t *testing.T) {
 			return
 		}
 		in := bufio.NewReader(conn)
-		str, err := in.ReadString('\n')
-		fmt.Fprint(conn, "HELO host\r\n")
-		str, err = in.ReadString('\n')
+		if _, err := in.ReadString('\n'); err != nil {
+			t.Error(err)
+		}
+		if _, err := fmt.Fprint(conn, "HELO host\r\n"); err != nil {
+			t.Error(err)
+		}
+		if _, err = in.ReadString('\n'); err != nil {
+			t.Error(err)
+		}
 		// perform 2 transactions
 		// both should timeout. The reason why 2 is because we want to make
 		// sure that the client waits until processing finishes, and the
 		// timeout event is captured.
 		for i := 0; i < 2; i++ {
-			fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "RCPT TO:[email protected]\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "DATA\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "Subject: Test subject\r\n")
-			fmt.Fprint(conn, "\r\n")
-			fmt.Fprint(conn, "A an email body\r\n")
-			fmt.Fprint(conn, ".\r\n")
-			str, err = in.ReadString('\n')
-			expect := "storage failed"
-			if strings.Index(str, expect) == -1 {
-				t.Error("Expected the reply to have'", expect, "'but got", str)
+			if _, err := fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "DATA\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "Subject: Test subject\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "A an email body\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, ".\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err := in.ReadString('\n'); err != nil {
+				t.Error(err)
+			} else {
+				expect := "storage failed"
+				if strings.Index(str, expect) == -1 {
+					t.Error("Expected the reply to have'", expect, "'but got", str)
+				}
 			}
 		}
-		_ = str
 		d.Shutdown()
 	}
 
 }
 
-// TODO
-// - test github issue #44 and #42
+func TestAllowsHosts(t *testing.T) {
+	defer cleanTestArtifacts(t)
+	s := server{}
+	allowedHosts := []string{
+		"spam4.me",
+		"grr.la",
+		"newhost.com",
+		"example.*",
+		"*.test",
+		"wild*.card",
+		"multiple*wild*cards.*",
+	}
+	s.setAllowedHosts(allowedHosts)
+
+	testTable := map[string]bool{
+		"spam4.me":                true,
+		"dont.match":              false,
+		"example.com":             true,
+		"another.example.com":     false,
+		"anything.test":           true,
+		"wild.card":               true,
+		"wild.card.com":           false,
+		"multipleXwildXcards.com": true,
+	}
+
+	for host, allows := range testTable {
+		if res := s.allowsHost(host); res != allows {
+			t.Error(host, ": expected", allows, "but got", res)
+		}
+	}
+
+	// only wildcard - should match anything
+	s.setAllowedHosts([]string{"*"})
+	if !s.allowsHost("match.me") {
+		t.Error("match.me: expected true but got false")
+	}
+
+	// turns off
+	s.setAllowedHosts([]string{"."})
+	if !s.allowsHost("match.me") {
+		t.Error("match.me: expected true but got false")
+	}
+
+	// no wilcards
+	s.setAllowedHosts([]string{"grr.la", "example.com"})
+
+}

+ 3 - 1
tests/client.go

@@ -32,7 +32,9 @@ func Connect(serverConfig guerrilla.ServerConfig, deadline time.Duration) (net.C
 	bufin = bufio.NewReader(conn)
 
 	// should be ample time to complete the test
-	conn.SetDeadline(time.Now().Add(time.Duration(time.Second * deadline)))
+	if err = conn.SetDeadline(time.Now().Add(time.Duration(time.Second * deadline))); err != nil {
+		return conn, bufin, err
+	}
 	// read greeting, ignore it
 	_, err = bufin.ReadString('\n')
 	return conn, bufin, err

+ 124 - 88
tests/guerrilla_test.go

@@ -16,6 +16,7 @@ package test
 
 import (
 	"encoding/json"
+	"github.com/flashmob/go-guerrilla/mail/rfc5321"
 	"testing"
 
 	"time"
@@ -61,10 +62,13 @@ func init() {
 	if err := json.Unmarshal([]byte(configJson), config); err != nil {
 		initErr = errors.New("Could not Unmarshal config," + err.Error())
 	} else {
-		setupCerts(config)
 		logger, _ = log.GetLogger(config.LogFile, "debug")
+		initErr = setupCerts(config)
+		if err != nil {
+			return
+		}
 		backend, _ := getBackend(config.BackendConfig, logger)
-		app, _ = guerrilla.New(&config.AppConfig, backend, logger)
+		app, initErr = guerrilla.New(&config.AppConfig, backend, logger)
 	}
 
 }
@@ -90,8 +94,8 @@ var configJson = `
             "max_clients": 2,
             "log_file" : "",
 			"tls" : {
-				"private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
-            	"public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+				"private_key_file":"/this/will/be/ignored/guerrillamail.com.key.pem",
+            	"public_key_file":"/this/will/be/ignored//guerrillamail.com.crt",
 				"start_tls_on":true,
             	"tls_always_on":false
 			}
@@ -106,8 +110,8 @@ var configJson = `
             "max_clients":1,
             "log_file" : "",
 			"tls" : {
-				"private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
-            	"public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+				"private_key_file":"/this/will/be/ignored/guerrillamail.com.key.pem",
+            	"public_key_file":"/this/will/be/ignored/guerrillamail.com.crt",
 				"start_tls_on":false,
             	"tls_always_on":true
 			}
@@ -125,12 +129,56 @@ func getBackend(backendConfig map[string]interface{}, l log.Logger) (backends.Ba
 	return b, err
 }
 
-func setupCerts(c *TestConfig) {
+func setupCerts(c *TestConfig) error {
 	for i := range c.Servers {
-		testcert.GenerateCert(c.Servers[i].Hostname, "", 365*24*time.Hour, false, 2048, "P256", "./")
+		err := testcert.GenerateCert(c.Servers[i].Hostname, "", 365*24*time.Hour, false, 2048, "P256", "./")
+		if err != nil {
+			return err
+		}
 		c.Servers[i].TLS.PrivateKeyFile = c.Servers[i].Hostname + ".key.pem"
 		c.Servers[i].TLS.PublicKeyFile = c.Servers[i].Hostname + ".cert.pem"
 	}
+	return nil
+}
+func truncateIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Truncate(filename, 0)
+	}
+	return nil
+}
+func deleteIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Remove(filename)
+	}
+	return nil
+}
+func cleanTestArtifacts(t *testing.T) {
+
+	if err := truncateIfExists("./testlog"); err != nil {
+		t.Error("could not clean tests/testlog:", err)
+	}
+
+	letters := []byte{'A', 'B', 'C', 'D', 'E'}
+	for _, l := range letters {
+		if err := deleteIfExists("configJson" + string(l) + ".json"); err != nil {
+			t.Error("could not delete configJson"+string(l)+".json:", err)
+		}
+	}
+
+	if err := deleteIfExists("./go-guerrilla.pid"); err != nil {
+		t.Error("could not delete ./guerrilla", err)
+	}
+	if err := deleteIfExists("./go-guerrilla2.pid"); err != nil {
+		t.Error("could not delete ./go-guerrilla2.pid", err)
+	}
+
+	if err := deleteIfExists("./mail.guerrillamail.com.cert.pem"); err != nil {
+		t.Error("could not delete ./mail.guerrillamail.com.cert.pem", err)
+	}
+	if err := deleteIfExists("./mail.guerrillamail.com.key.pem"); err != nil {
+		t.Error("could not delete ./mail.guerrillamail.com.key.pem", err)
+	}
+
 }
 
 // Testing start and stop of server
@@ -139,6 +187,7 @@ func TestStart(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors != nil {
 		t.Error(startErrors)
 		t.FailNow()
@@ -183,18 +232,16 @@ func TestStart(t *testing.T) {
 		}
 
 	}
-	// don't forget to reset
 
-	os.Truncate("./testlog", 0)
 }
 
 // Simple smoke-test to see if the server can listen & issues a greeting on connect
 func TestGreeting(t *testing.T) {
-	//log.SetOutput(os.Stdout)
 	if initErr != nil {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		// 1. plaintext connection
 		conn, err := net.Dial("tcp", config.Servers[0].ListenInterface)
@@ -202,9 +249,10 @@ func TestGreeting(t *testing.T) {
 			// handle error
 			t.Error("Cannot dial server", config.Servers[0].ListenInterface)
 		}
-		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		if err := conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500))); err != nil {
+			t.Error(err)
+		}
 		greeting, err := bufio.NewReader(conn).ReadString('\n')
-		//fmt.Println(greeting)
 		if err != nil {
 			t.Error(err)
 			t.FailNow()
@@ -214,7 +262,7 @@ func TestGreeting(t *testing.T) {
 				t.Error("Server[1] did not have the expected greeting prefix", expected)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 
 		// 2. tls connection
 		//	roots, err := x509.SystemCertPool()
@@ -228,9 +276,10 @@ func TestGreeting(t *testing.T) {
 			t.Error(err, "Cannot dial server (TLS)", config.Servers[1].ListenInterface)
 			t.FailNow()
 		}
-		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		if err := conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500))); err != nil {
+			t.Error(err)
+		}
 		greeting, err = bufio.NewReader(conn).ReadString('\n')
-		//fmt.Println(greeting)
 		if err != nil {
 			t.Error(err)
 			t.FailNow()
@@ -240,7 +289,7 @@ func TestGreeting(t *testing.T) {
 				t.Error("Server[2] (TLS) did not have the expected greeting prefix", expected)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 
 	} else {
 		fmt.Println("Nope", startErrors)
@@ -252,13 +301,10 @@ func TestGreeting(t *testing.T) {
 	app.Shutdown()
 	if read, err := ioutil.ReadFile("./testlog"); err == nil {
 		logOutput := string(read)
-		//fmt.Println(logOutput)
 		if i := strings.Index(logOutput, "Handle client [127.0.0.1"); i < 0 {
 			t.Error("Server did not handle any clients")
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 
 }
 
@@ -271,6 +317,7 @@ func TestShutDown(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -299,7 +346,7 @@ func TestShutDown(t *testing.T) {
 			time.Sleep(time.Millisecond * 250) // let server to close
 		}
 
-		conn.Close()
+		_ = conn.Close()
 
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -316,8 +363,6 @@ func TestShutDown(t *testing.T) {
 			t.Error("Server did not handle any clients")
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 
 }
 
@@ -327,6 +372,7 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -341,12 +387,12 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 
 			for i := 0; i < 101; i++ {
 				//fmt.Println(fmt.Sprintf("RCPT TO:test%[email protected]", i))
-				if _, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:test%[email protected]", i)); err != nil {
+				if _, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:<test%[email protected]>", i)); err != nil {
 					t.Error("RCPT TO", err.Error())
 					break
 				}
 			}
-			response, err := Command(conn, bufin, "RCPT TO:[email protected]")
+			response, err := Command(conn, bufin, "RCPT TO:<[email protected]>")
 			if err != nil {
 				t.Error("rcpt command failed", err.Error())
 			}
@@ -356,7 +402,7 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 			}
 		}
 
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 
 	} else {
@@ -367,8 +413,6 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 		}
 	}
 
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // RCPT TO & MAIL FROM with 64 chars in local part, it should fail at 65
@@ -377,7 +421,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -390,7 +434,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 			// repeat > 64 characters in local part
-			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:%[email protected]", strings.Repeat("a", 65)))
+			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:<%[email protected]>", strings.Repeat("a", rfc5321.LimitLocalPart+1)))
 			if err != nil {
 				t.Error("rcpt command failed", err.Error())
 			}
@@ -400,7 +444,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 			}
 			// what about if it's exactly 64?
 			// repeat > 64 characters in local part
-			response, err = Command(conn, bufin, fmt.Sprintf("RCPT TO:%[email protected]", strings.Repeat("a", 64)))
+			response, err = Command(conn, bufin, fmt.Sprintf("RCPT TO:<%[email protected]>", strings.Repeat("a", rfc5321.LimitLocalPart-1)))
 			if err != nil {
 				t.Error("rcpt command failed", err.Error())
 			}
@@ -410,7 +454,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 			}
 		}
 
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 
 	} else {
@@ -421,8 +465,6 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 		}
 	}
 
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 //RFC2821LimitPath fail if path > 256 but different error if below
@@ -444,7 +486,7 @@ func TestRFC2821LimitPath(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 			// repeat > 256 characters in local part
-			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:%[email protected]", strings.Repeat("a", 257-7)))
+			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:<%[email protected]>", strings.Repeat("a", 257-7)))
 			if err != nil {
 				t.Error("rcpt command failed", err.Error())
 			}
@@ -454,7 +496,7 @@ func TestRFC2821LimitPath(t *testing.T) {
 			}
 			// what about if it's exactly 256?
 			response, err = Command(conn, bufin,
-				fmt.Sprintf("RCPT TO:%s@%s.la", strings.Repeat("a", 64), strings.Repeat("b", 257-5-64)))
+				fmt.Sprintf("RCPT TO:<%s@%s.la>", strings.Repeat("a", 64), strings.Repeat("b", 186)))
 			if err != nil {
 				t.Error("rcpt command failed", err.Error())
 			}
@@ -463,7 +505,7 @@ func TestRFC2821LimitPath(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -472,8 +514,6 @@ func TestRFC2821LimitPath(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // RFC2821LimitDomain 501 Domain cannot exceed 255 characters
@@ -482,6 +522,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -494,7 +535,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 			// repeat > 64 characters in local part
-			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:a@%s.l", strings.Repeat("a", 255-2)))
+			response, err := Command(conn, bufin, fmt.Sprintf("RCPT TO:<a@%s.l>", strings.Repeat("a", 255-2)))
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -504,7 +545,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 			}
 			// what about if it's exactly 255?
 			response, err = Command(conn, bufin,
-				fmt.Sprintf("RCPT TO:a@%s.la", strings.Repeat("b", 255-4)))
+				fmt.Sprintf("RCPT TO:<a@%s.la>", strings.Repeat("b", 255-6)))
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -513,7 +554,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -522,8 +563,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 // Test several different inputs to MAIL FROM command
@@ -532,6 +572,7 @@ func TestMailFromCmd(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -544,7 +585,7 @@ func TestMailFromCmd(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 			// Basic valid address
-			response, err := Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err := Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -703,15 +744,17 @@ func TestMailFromCmd(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 
-			// SMTPUTF8 not implemented for now, currently still accepted
-			response, err = Command(conn, bufin, "MAIL FROM:<anö[email protected]>")
-			if err != nil {
-				t.Error("command failed", err.Error())
-			}
-			expected = "250 2.1.0 OK"
-			if strings.Index(response, expected) != 0 {
-				t.Error("Server did not respond with", expected, ", it said:"+response)
-			}
+			/*
+				// todo SMTPUTF8 not implemented for now,
+				response, err = Command(conn, bufin, "MAIL FROM:<anö[email protected]>")
+				if err != nil {
+					t.Error("command failed", err.Error())
+				}
+				expected = "250 2.1.0 OK"
+				if strings.Index(response, expected) != 0 {
+					t.Error("Server did not respond with", expected, ", it said:"+response)
+				}
+			*/
 
 			// Reset
 			response, err = Command(conn, bufin, "RSET")
@@ -764,7 +807,7 @@ func TestMailFromCmd(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -782,6 +825,7 @@ func TestHeloEhlo(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		hostname := config.Servers[0].Hostname
@@ -838,7 +882,7 @@ func TestHeloEhlo(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -856,6 +900,7 @@ func TestNestedMailCmd(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -868,11 +913,11 @@ func TestNestedMailCmd(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 			// repeat > 64 characters in local part
-			response, err := Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err := Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
-			response, err = Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err = Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -884,7 +929,7 @@ func TestNestedMailCmd(t *testing.T) {
 			if _, err := Command(conn, bufin, "HELO localtester"); err != nil {
 				t.Error("Hello command failed", err.Error())
 			}
-			response, err = Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err = Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -902,7 +947,7 @@ func TestNestedMailCmd(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 
-			response, err = Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err = Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -912,7 +957,7 @@ func TestNestedMailCmd(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -921,8 +966,6 @@ func TestNestedMailCmd(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // It should error on a very long command line, exceeding CommandLineMaxLength 1024
@@ -931,7 +974,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -955,7 +998,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -964,8 +1007,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 // It should error on a very long message, exceeding servers config value
@@ -974,7 +1016,7 @@ func TestDataMaxLength(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -992,7 +1034,7 @@ func TestDataMaxLength(t *testing.T) {
 				t.Error("command failed", err.Error())
 			}
 			//fmt.Println(response)
-			response, err = Command(conn, bufin, "RCPT TO:[email protected]")
+			response, err = Command(conn, bufin, "RCPT TO:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -1009,13 +1051,13 @@ func TestDataMaxLength(t *testing.T) {
 					strings.Repeat("n", int(config.Servers[0].MaxSize-20))))
 
 			//expected := "500 Line too long"
-			expected := "451 4.3.0 Error: Maximum DATA size exceeded"
+			expected := "451 4.3.0 Error: maximum DATA size exceeded"
 			if strings.Index(response, expected) != 0 {
-				t.Error("Server did not respond with", expected, ", it said:"+response, err)
+				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1024,8 +1066,7 @@ func TestDataMaxLength(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 func TestDataCommand(t *testing.T) {
@@ -1033,7 +1074,7 @@ func TestDataCommand(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	testHeader :=
 		"Subject: =?Shift_JIS?B?W4NYg06DRYNGg0GBRYNHg2qDYoNOg1ggg0GDSoNFg5ODZ12DQYNKg0WDk4Nn?=\r\n" +
 			"\t=?Shift_JIS?B?k2+YXoqul7mCzIKokm2C54K5?=\r\n"
@@ -1079,12 +1120,12 @@ func TestDataCommand(t *testing.T) {
 				t.Error("Hello command failed", err.Error())
 			}
 
-			response, err := Command(conn, bufin, "MAIL FROM:[email protected]")
+			response, err := Command(conn, bufin, "MAIL FROM:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
 			//fmt.Println(response)
-			response, err = Command(conn, bufin, "RCPT TO:[email protected]")
+			response, err = Command(conn, bufin, "RCPT TO:<[email protected]>")
 			if err != nil {
 				t.Error("command failed", err.Error())
 			}
@@ -1105,13 +1146,13 @@ func TestDataCommand(t *testing.T) {
 				bufin,
 				email+"\r\n.\r\n")
 			//expected := "500 Line too long"
-			expected := "250 2.0.0 OK : queued as "
+			expected := "250 2.0.0 OK: queued as "
 			if strings.Index(response, expected) != 0 {
 				t.Error("Server did not respond with", expected, ", it said:"+response, err)
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1120,8 +1161,6 @@ func TestDataCommand(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // Fuzzer crashed the server by submitting "DATA\r\n" as the first command
@@ -1130,7 +1169,7 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -1149,7 +1188,7 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1158,8 +1197,6 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // Appears to hang the fuzz test, but not server.
@@ -1168,6 +1205,7 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	str := "X_\r\nMAIL FROM:<u\xfd\xfdrU" +
 		"\x10c22695140\xfd727235530" +
 		" Walter Sobchak\x1a\tDon" +
@@ -1243,7 +1281,7 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1252,6 +1290,4 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }

+ 35 - 17
tests/testcert/generate_cert.go

@@ -13,6 +13,7 @@ import (
 	"crypto/x509/pkix"
 	"encoding/pem"
 
+	"errors"
 	"fmt"
 	"log"
 	"math/big"
@@ -44,32 +45,30 @@ func publicKey(priv interface{}) interface{} {
 	}
 }
 
-func pemBlockForKey(priv interface{}) *pem.Block {
+func pemBlockForKey(priv interface{}) (*pem.Block, error) {
 	switch k := priv.(type) {
 	case *rsa.PrivateKey:
-		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
+		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
 	case *ecdsa.PrivateKey:
 		b, err := x509.MarshalECPrivateKey(k)
 		if err != nil {
-			fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
-			os.Exit(2)
+			err = errors.New(fmt.Sprintf("Unable to marshal ECDSA private key: %v", err))
 		}
-		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
+		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, err
 	default:
-		return nil
+		return nil, errors.New("not a private key")
 	}
 }
 
 // validFrom - Creation date formatted as Jan 1 15:04:05 2011 or ""
 
-func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bool, rsaBits int, ecdsaCurve string, dirPrefix string) {
+func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bool, rsaBits int, ecdsaCurve string, dirPrefix string) (err error) {
 
 	if len(host) == 0 {
 		log.Fatalf("Missing required --host parameter")
 	}
 
 	var priv interface{}
-	var err error
 	switch ecdsaCurve {
 	case "":
 		priv, err = rsa.GenerateKey(rand.Reader, rsaBits)
@@ -82,11 +81,11 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	case "P521":
 		priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
 	default:
-		fmt.Fprintf(os.Stderr, "Unrecognized elliptic curve: %q", ecdsaCurve)
-		os.Exit(1)
+		err = errors.New(fmt.Sprintf("Unrecognized elliptic curve: %q", ecdsaCurve))
 	}
 	if err != nil {
 		log.Fatalf("failed to generate private key: %s", err)
+		return
 	}
 
 	var notBefore time.Time
@@ -95,8 +94,8 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	} else {
 		notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom)
 		if err != nil {
-			fmt.Fprintf(os.Stderr, "Failed to parse creation date: %s\n", err)
-			os.Exit(1)
+			err = errors.New(fmt.Sprintf("Failed to parse creation date: %s\n", err))
+			return
 		}
 	}
 
@@ -144,16 +143,35 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	if err != nil {
 		log.Fatalf("failed to open cert.pem for writing: %s", err)
 	}
-	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
-	certOut.Close()
+	err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	if err != nil {
+		return
+	}
+	if err = certOut.Sync(); err != nil {
+		return
+	}
+	if err = certOut.Close(); err != nil {
+		return
+	}
 
 	keyOut, err := os.OpenFile(dirPrefix+host+".key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
 	if err != nil {
 		log.Print("failed to open key.pem for writing:", err)
 		return
 	}
-	pem.Encode(keyOut, pemBlockForKey(priv))
-	keyOut.Sync()
-	keyOut.Close()
+	var block *pem.Block
+	if block, err = pemBlockForKey(priv); err != nil {
+		return err
+	}
+	if err = pem.Encode(keyOut, block); err != nil {
+		return err
+	}
+	if err = keyOut.Sync(); err != nil {
+		return err
+	}
+	if err = keyOut.Close(); err != nil {
+		return err
+	}
+	return
 
 }

+ 0 - 46
util.go

@@ -1,46 +0,0 @@
-package guerrilla
-
-import (
-	"errors"
-	"regexp"
-	"strings"
-
-	"github.com/flashmob/go-guerrilla/mail"
-	"github.com/flashmob/go-guerrilla/response"
-)
-
-var extractEmailRegex, _ = regexp.Compile(`<(.+?)@(.+?)>`) // go home regex, you're drunk!
-
-func extractEmail(str string) (mail.Address, error) {
-	email := mail.Address{}
-	var err error
-	if len(str) > RFC2821LimitPath {
-		return email, errors.New(response.Canned.FailPathTooLong)
-	}
-	if matched := extractEmailRegex.FindStringSubmatch(str); len(matched) > 2 {
-		email.User = matched[1]
-		email.Host = validHost(matched[2])
-	} else if res := strings.Split(str, "@"); len(res) > 1 {
-		email.User = strings.TrimSpace(res[0])
-		email.Host = validHost(res[1])
-	}
-	err = nil
-	if email.User == "" || email.Host == "" {
-		err = errors.New(response.Canned.FailInvalidAddress)
-	} else if len(email.User) > RFC2832LimitLocalPart {
-		err = errors.New(response.Canned.FailLocalPartTooLong)
-	} else if len(email.Host) > RFC2821LimitDomain {
-		err = errors.New(response.Canned.FailDomainTooLong)
-	}
-	return email, err
-}
-
-var validhostRegex, _ = regexp.Compile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
-
-func validHost(host string) string {
-	host = strings.Trim(host, " ")
-	if validhostRegex.MatchString(host) {
-		return host
-	}
-	return ""
-}

Some files were not shown because too many files changed in this diff