ソースを参照

Code style fixes (#131)

* Codebase: fix code style issues, error checking where possible, fix #96 
* Tests: fix race conditions in tests, which caused random failures in the past, fix #96 
* Cross-platform compatibility: Open-file-limit test, capture os.kill signal for maximal compatibility
* Header parsing: increase max header size to 4kb
* Header parsing: avoid extra copy when parsing headers
Flashmob 6 年 前
コミット
3493a860b0

+ 2 - 4
api.go

@@ -29,8 +29,6 @@ type deferredSub struct {
 	fn    interface{}
 }
 
-const defaultInterface = "127.0.0.1:2525"
-
 // AddProcessor adds a processor constructor to the backend.
 // name is the identifier to be used in the config. See backends docs for more info.
 func (d *Daemon) AddProcessor(name string, pc backends.ProcessorConstructor) {
@@ -64,7 +62,7 @@ func (d *Daemon) Start() (err error) {
 			return err
 		}
 		for i := range d.subs {
-			d.Subscribe(d.subs[i].topic, d.subs[i].fn)
+			_ = d.Subscribe(d.subs[i].topic, d.subs[i].fn)
 
 		}
 		d.subs = make([]deferredSub, 0)
@@ -164,10 +162,10 @@ func (d *Daemon) ReopenLogs() error {
 // Subscribe for subscribing to config change events
 func (d *Daemon) Subscribe(topic Event, fn interface{}) error {
 	if d.g == nil {
+		// defer the subscription until the daemon is started
 		d.subs = append(d.subs, deferredSub{topic, fn})
 		return nil
 	}
-
 	return d.g.Subscribe(topic, fn)
 }
 

+ 131 - 35
api_test.go

@@ -212,7 +212,9 @@ func TestSMTPLoadFile(t *testing.T) {
 			return
 		}
 
-		d.ReloadConfigFile("goguerrilla.conf.api")
+		if err = d.ReloadConfigFile("goguerrilla.conf.api"); err != nil {
+			t.Error(err)
+		}
 
 		if d.Config.LogFile != "./tests/testlog2" {
 			t.Error("d.Config.LogFile != \"./tests/testlog\"")
@@ -227,7 +229,9 @@ func TestSMTPLoadFile(t *testing.T) {
 }
 
 func TestReopenLog(t *testing.T) {
-	os.Truncate("test/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := &AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -240,7 +244,9 @@ func TestReopenLog(t *testing.T) {
 	if err != nil {
 		t.Error("start error", err)
 	} else {
-		d.ReopenLogs()
+		if err = d.ReopenLogs(); err != nil {
+			t.Error(err)
+		}
 		time.Sleep(time.Second * 2)
 
 		d.Shutdown()
@@ -261,7 +267,9 @@ func TestReopenLog(t *testing.T) {
 
 func TestSetConfig(t *testing.T) {
 
-	os.Truncate("test/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -305,7 +313,9 @@ func TestSetConfig(t *testing.T) {
 
 func TestSetConfigError(t *testing.T) {
 
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := AppConfig{LogFile: "tests/testlog"}
 	sc := ServerConfig{
 		ListenInterface: "127.0.0.1:2526",
@@ -369,7 +379,9 @@ var funkyLogger = func() backends.Decorator {
 
 // How about a custom processor?
 func TestSetAddProcessor(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := &AppConfig{
 		LogFile:      "tests/testlog",
 		AllowedHosts: []string{"grr.la"},
@@ -381,9 +393,13 @@ func TestSetAddProcessor(t *testing.T) {
 	d := Daemon{Config: cfg}
 	d.AddProcessor("FunkyLogger", funkyLogger)
 
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	// lets have a talk with the server
-	talkToServer("127.0.0.1:2525")
+	if err := talkToServer("127.0.0.1:2525"); err != nil {
+		t.Error(err)
+	}
 
 	d.Shutdown()
 
@@ -410,38 +426,87 @@ func TestSetAddProcessor(t *testing.T) {
 
 }
 
-func talkToServer(address string) {
+func talkToServer(address string) (err error) {
 
 	conn, err := net.Dial("tcp", address)
 	if err != nil {
-
 		return
 	}
 	in := bufio.NewReader(conn)
 	str, err := in.ReadString('\n')
-	fmt.Fprint(conn, "HELO maildiranasaurustester\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "HELO maildiranasaurustester\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
+	if err != nil {
+		return err
+	}
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "DATA\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "DATA\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
-	fmt.Fprint(conn, "Subject: Test subject\r\n")
-	fmt.Fprint(conn, "\r\n")
-	fmt.Fprint(conn, "A an email body\r\n")
-	fmt.Fprint(conn, ".\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "Subject: Test subject\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, "A an email body\r\n")
+	if err != nil {
+		return err
+	}
+	_, err = fmt.Fprint(conn, ".\r\n")
+	if err != nil {
+		return err
+	}
 	str, err = in.ReadString('\n')
+	if err != nil {
+		return err
+	}
 	_ = str
+	return nil
 }
 
 // Test hot config reload
 // Here we forgot to add FunkyLogger so backend will fail to init
 
 func TestReloadConfig(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	d := Daemon{}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	defer d.Shutdown()
 	cfg := AppConfig{
 		LogFile:      "tests/testlog",
@@ -452,16 +517,22 @@ func TestReloadConfig(t *testing.T) {
 		},
 	}
 	// Look mom, reloading the config without shutting down!
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
 }
 
 func TestPubSubAPI(t *testing.T) {
 
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 
 	d := Daemon{Config: &AppConfig{LogFile: "tests/testlog"}}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 	defer d.Shutdown()
 	// new config
 	cfg := AppConfig{
@@ -482,14 +553,22 @@ func TestPubSubAPI(t *testing.T) {
 		}
 		d.Logger.Info("number", i)
 	}
-	d.Subscribe(EventConfigPidFile, pidEvHandler)
+	if err := d.Subscribe(EventConfigPidFile, pidEvHandler); err != nil {
+		t.Error(err)
+	}
 
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
-	d.Unsubscribe(EventConfigPidFile, pidEvHandler)
+	if err := d.Unsubscribe(EventConfigPidFile, pidEvHandler); err != nil {
+		t.Error(err)
+	}
 	cfg.PidFile = "tests/pidfile2.pid"
 	d.Publish(EventConfigPidFile, &cfg)
-	d.ReloadConfig(cfg)
+	if err := d.ReloadConfig(cfg); err != nil {
+		t.Error(err)
+	}
 
 	b, err := ioutil.ReadFile("tests/testlog")
 	if err != nil {
@@ -504,7 +583,9 @@ func TestPubSubAPI(t *testing.T) {
 }
 
 func TestAPILog(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	d := Daemon{}
 	l := d.Log()
 	l.Info("logtest1") // to stderr
@@ -540,7 +621,9 @@ func TestSkipAllowsHost(t *testing.T) {
 	defer d.Shutdown()
 	// setting the allowed hosts to a single entry with a dot will let any host through
 	d.Config = &AppConfig{AllowedHosts: []string{"."}, LogFile: "off"}
-	d.Start()
+	if err := d.Start(); err != nil {
+		t.Error(err)
+	}
 
 	conn, err := net.Dial("tcp", d.Config.Servers[0].ListenInterface)
 	if err != nil {
@@ -548,10 +631,19 @@ func TestSkipAllowsHost(t *testing.T) {
 		return
 	}
 	in := bufio.NewReader(conn)
-	fmt.Fprint(conn, "HELO test\r\n")
-	fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
-	in.ReadString('\n')
-	in.ReadString('\n')
+	if _, err := fmt.Fprint(conn, "HELO test\r\n"); err != nil {
+		t.Error(err)
+	}
+	if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+		t.Error(err)
+	}
+
+	if _, err := in.ReadString('\n'); err != nil {
+		t.Error(err)
+	}
+	if _, err := in.ReadString('\n'); err != nil {
+		t.Error(err)
+	}
 	str, _ := in.ReadString('\n')
 	if strings.Index(str, "250") != 0 {
 		t.Error("expected 250 reply, got:", str)
@@ -578,7 +670,9 @@ var customBackend2 = func() backends.Decorator {
 
 // Test a custom backend response
 func TestCustomBackendResult(t *testing.T) {
-	os.Truncate("tests/testlog", 0)
+	if err := os.Truncate("tests/testlog", 0); err != nil {
+		t.Error(err)
+	}
 	cfg := &AppConfig{
 		LogFile:      "tests/testlog",
 		AllowedHosts: []string{"grr.la"},
@@ -594,7 +688,9 @@ func TestCustomBackendResult(t *testing.T) {
 		t.Error(err)
 	}
 	// lets have a talk with the server
-	talkToServer("127.0.0.1:2525")
+	if err := talkToServer("127.0.0.1:2525"); err != nil {
+		t.Error(err)
+	}
 
 	d.Shutdown()
 

+ 16 - 15
backends/backend.go

@@ -69,6 +69,7 @@ type Result interface {
 
 // Internal implementation of BackendResult for use by backend implementations.
 type result struct {
+	// we're going to use a bytes.Buffer for building a string
 	bytes.Buffer
 }
 
@@ -95,11 +96,11 @@ func NewResult(r ...interface{}) Result {
 	for _, item := range r {
 		switch v := item.(type) {
 		case error:
-			buf.WriteString(v.Error())
+			_, _ = buf.WriteString(v.Error())
 		case fmt.Stringer:
-			buf.WriteString(v.String())
+			_, _ = buf.WriteString(v.String())
 		case string:
-			buf.WriteString(v)
+			_, _ = buf.WriteString(v)
 		}
 	}
 	return buf
@@ -255,13 +256,13 @@ func (s *service) ExtractConfig(configData BackendConfig, configType BaseConfig)
 	for i := 0; i < v.NumField(); i++ {
 		f := v.Field(i)
 		// read the tags of the config struct
-		field_name := t.Field(i).Tag.Get("json")
+		fieldName := t.Field(i).Tag.Get("json")
 		omitempty := false
-		if len(field_name) > 0 {
+		if len(fieldName) > 0 {
 			// parse the tag to
 			// get the field name from struct tag
-			split := strings.Split(field_name, ",")
-			field_name = split[0]
+			split := strings.Split(fieldName, ",")
+			fieldName = split[0]
 			if len(split) > 1 {
 				if split[1] == "omitempty" {
 					omitempty = true
@@ -270,30 +271,30 @@ func (s *service) ExtractConfig(configData BackendConfig, configType BaseConfig)
 		} else {
 			// could have no tag
 			// so use the reflected field name
-			field_name = typeOfT.Field(i).Name
+			fieldName = typeOfT.Field(i).Name
 		}
 		if f.Type().Name() == "int" {
 			// in json, there is no int, only floats...
-			if intVal, converted := configData[field_name].(float64); converted {
+			if intVal, converted := configData[fieldName].(float64); converted {
 				v.Field(i).SetInt(int64(intVal))
-			} else if intVal, converted := configData[field_name].(int); converted {
+			} else if intVal, converted := configData[fieldName].(int); converted {
 				v.Field(i).SetInt(int64(intVal))
 			} else if !omitempty {
-				return configType, convertError("property missing/invalid: '" + field_name + "' of expected type: " + f.Type().Name())
+				return configType, convertError("property missing/invalid: '" + fieldName + "' of expected type: " + f.Type().Name())
 			}
 		}
 		if f.Type().Name() == "string" {
-			if stringVal, converted := configData[field_name].(string); converted {
+			if stringVal, converted := configData[fieldName].(string); converted {
 				v.Field(i).SetString(stringVal)
 			} else if !omitempty {
-				return configType, convertError("missing/invalid: '" + field_name + "' of type: " + f.Type().Name())
+				return configType, convertError("missing/invalid: '" + fieldName + "' of type: " + f.Type().Name())
 			}
 		}
 		if f.Type().Name() == "bool" {
-			if boolVal, converted := configData[field_name].(bool); converted {
+			if boolVal, converted := configData[fieldName].(bool); converted {
 				v.Field(i).SetBool(boolVal)
 			} else if !omitempty {
-				return configType, convertError("missing/invalid: '" + field_name + "' of type: " + f.Type().Name())
+				return configType, convertError("missing/invalid: '" + fieldName + "' of type: " + f.Type().Name())
 			}
 		}
 	}

+ 3 - 3
backends/p_compressor.go

@@ -77,9 +77,9 @@ func (c *compressor) String() string {
 	var r *bytes.Reader
 	w, _ := zlib.NewWriterLevel(b, zlib.BestSpeed)
 	r = bytes.NewReader(c.extraHeaders)
-	io.Copy(w, r)
-	io.Copy(w, c.data)
-	w.Close()
+	_, _ = io.Copy(w, r)
+	_, _ = io.Copy(w, c.data)
+	_ = w.Close()
 	return b.String()
 }
 

+ 41 - 17
backends/p_guerrilla_db_redis.go

@@ -130,9 +130,9 @@ func (c *compressedData) String() string {
 	var r *bytes.Reader
 	w, _ := zlib.NewWriterLevel(b, zlib.BestSpeed)
 	r = bytes.NewReader(c.extraHeaders)
-	io.Copy(w, r)
-	io.Copy(w, c.data)
-	w.Close()
+	_, _ = io.Copy(w, r)
+	_, _ = io.Copy(w, c.data)
+	_ = w.Close()
 	return b.String()
 }
 
@@ -150,9 +150,25 @@ func (g *GuerrillaDBAndRedisBackend) prepareInsertQuery(rows int, db *sql.DB) *s
 	if g.cache[rows-1] != nil {
 		return g.cache[rows-1]
 	}
-	sqlstr := "INSERT INTO " + g.config.Table + " "
-	sqlstr += "(`date`, `to`, `from`, `subject`, `body`, `charset`, `mail`, `spam_score`, `hash`, `content_type`, `recipient`, `has_attach`, `ip_addr`, `return_path`, `is_tls`)"
-	sqlstr += " values "
+	sqlstr := "INSERT INTO " + g.config.Table + " " +
+		"(" +
+		"`date`, " +
+		"`to`, " +
+		"`from`, " +
+		"`subject`, " +
+		"`body`, " +
+		"`charset`, " +
+		"`mail`, " +
+		"`spam_score`, " +
+		"`hash`, " +
+		"`content_type`, " +
+		"`recipient`, " +
+		"`has_attach`, " +
+		"`ip_addr`, " +
+		"`return_path`, " +
+		"`is_tls`" +
+		")" +
+		" values "
 	values := "(NOW(), ?, ?, ?, ? , 'UTF-8' , ?, 0, ?, '', ?, 0, ?, ?, ?)"
 	// add more rows
 	comma := ""
@@ -346,12 +362,12 @@ func GuerrillaDbRedis() Decorator {
 	g := GuerrillaDBAndRedisBackend{}
 	redisClient := &redisClient{}
 
-	var db *sql.DB
-	var to, body string
-
-	var redisErr error
-
-	var feeders []feedChan
+	var (
+		db       *sql.DB
+		to, body string
+		redisErr error
+		feeders  []feedChan
+	)
 
 	g.batcherStoppers = make([]chan bool, 0)
 
@@ -387,11 +403,17 @@ func GuerrillaDbRedis() Decorator {
 	}))
 
 	Svc.AddShutdowner(ShutdownWith(func() error {
-		db.Close()
-		Log().Infof("closed sql")
+		if err := db.Close(); err != nil {
+			Log().WithError(err).Error("close mysql failed")
+		} else {
+			Log().Infof("closed mysql")
+		}
 		if redisClient.conn != nil {
-			Log().Infof("closed redis")
-			redisClient.conn.Close()
+			if err := redisClient.conn.Close(); err != nil {
+				Log().WithError(err).Error("close redis failed")
+			} else {
+				Log().Infof("closed redis")
+			}
 		}
 		// send a close signal to all query batchers to exit.
 		for i := range g.batcherStoppers {
@@ -413,7 +435,9 @@ func GuerrillaDbRedis() Decorator {
 				e.Helo = trimToLimit(e.Helo, 255)
 				e.RcptTo[0].Host = trimToLimit(e.RcptTo[0].Host, 255)
 				ts := fmt.Sprintf("%d", time.Now().UnixNano())
-				e.ParseHeaders()
+				if err := e.ParseHeaders(); err != nil {
+					Log().WithError(err).Error("failed to parse headers")
+				}
 				hash := MD5Hex(
 					to,
 					e.MailFrom.String(),

+ 6 - 2
backends/p_guerrilla_db_redis_test.go

@@ -19,12 +19,16 @@ func TestCompressedData(t *testing.T) {
 	cd.set([]byte(sbj), &b)
 
 	// compress
-	fmt.Fprint(&out, cd)
+	if _, err := fmt.Fprint(&out, cd); err != nil {
+		t.Error(err)
+	}
 
 	// decompress
 	var result bytes.Buffer
 	zReader, _ := zlib.NewReader(bytes.NewReader(out.Bytes()))
-	io.Copy(&result, zReader)
+	if _, err := io.Copy(&result, zReader); err != nil {
+		t.Error(err)
+	}
 	expect := sbj + str
 	if delta := strings.Compare(expect, result.String()); delta != 0 {
 		t.Error(delta, "compression did match, expected", expect, "but got", result.String())

+ 4 - 4
backends/p_hasher.go

@@ -38,13 +38,13 @@ func Hasher() Decorator {
 				// base hash, use subject from and timestamp-nano
 				h := md5.New()
 				ts := fmt.Sprintf("%d", time.Now().UnixNano())
-				io.Copy(h, strings.NewReader(e.MailFrom.String()))
-				io.Copy(h, strings.NewReader(e.Subject))
-				io.Copy(h, strings.NewReader(ts))
+				_, _ = io.Copy(h, strings.NewReader(e.MailFrom.String()))
+				_, _ = io.Copy(h, strings.NewReader(e.Subject))
+				_, _ = io.Copy(h, strings.NewReader(ts))
 				// using the base hash, calculate a unique hash for each recipient
 				for i := range e.RcptTo {
 					h2 := h
-					io.Copy(h2, strings.NewReader(e.RcptTo[i].String()))
+					_, _ = io.Copy(h2, strings.NewReader(e.RcptTo[i].String()))
 					sum := h2.Sum([]byte{})
 					e.Hashes = append(e.Hashes, fmt.Sprintf("%x", sum))
 				}

+ 3 - 1
backends/p_headers_parser.go

@@ -25,7 +25,9 @@ func HeadersParser() Decorator {
 	return func(p Processor) Processor {
 		return ProcessWith(func(e *mail.Envelope, task SelectTask) (Result, error) {
 			if task == TaskSaveMail {
-				e.ParseHeaders()
+				if err := e.ParseHeaders(); err != nil {
+					Log().WithError(err).Error("parse headers error")
+				}
 				// next processor
 				return p.Process(e, task)
 			} else {

+ 1 - 1
backends/p_redis.go

@@ -67,7 +67,7 @@ func Redis() Decorator {
 		}
 		config = bcfg.(*RedisProcessorConfig)
 		if redisErr := redisClient.redisConnection(config.RedisInterface); redisErr != nil {
-			err := fmt.Errorf("Redis cannot connect, check your settings: %s", redisErr)
+			err := fmt.Errorf("redis cannot connect, check your settings: %s", redisErr)
 			return err
 		}
 		return nil

+ 9 - 2
backends/p_redis_test.go

@@ -29,7 +29,12 @@ func TestRedisGeneric(t *testing.T) {
 		t.Error(err)
 		return
 	}
-	defer g.Shutdown()
+	defer func() {
+		err := g.Shutdown()
+		if err != nil {
+			t.Error(err)
+		}
+	}()
 	if gateway, ok := g.(*BackendGateway); ok {
 		r := gateway.Process(e)
 		if strings.Index(r.String(), "250 2.0.0 OK") == -1 {
@@ -50,6 +55,8 @@ func TestRedisGeneric(t *testing.T) {
 		}
 	}
 
-	os.Remove("./test_redis.log")
+	if err := os.Remove("./test_redis.log"); err != nil {
+		t.Error(err)
+	}
 
 }

+ 3 - 1
backends/p_sql_test.go

@@ -73,7 +73,9 @@ func findRows(hash string) ([]string, error) {
 	if err != nil {
 		return nil, err
 	}
-	defer db.Close()
+	defer func() {
+		_ = db.Close()
+	}()
 
 	stmt := fmt.Sprintf(`SELECT hash FROM %s WHERE hash = ?`, *mailTableFlag)
 	rows, err := db.Query(stmt, hash)

+ 3 - 3
backends/util.go

@@ -41,7 +41,7 @@ func MD5Hex(stringArguments ...string) string {
 	var r *strings.Reader
 	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(stringArguments[i])
-		io.Copy(h, r)
+		_, _ = io.Copy(h, r)
 	}
 	sum := h.Sum([]byte{})
 	return fmt.Sprintf("%x", sum)
@@ -54,8 +54,8 @@ func Compress(stringArguments ...string) string {
 	w, _ := zlib.NewWriterLevel(&b, zlib.BestSpeed)
 	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(stringArguments[i])
-		io.Copy(w, r)
+		_, _ = io.Copy(w, r)
 	}
-	w.Close()
+	_ = w.Close()
 	return b.String()
 }

+ 5 - 4
client.go

@@ -115,7 +115,7 @@ func (c *client) sendResponse(r ...interface{}) {
 // Transaction ends on:
 // -HELO/EHLO/REST command
 // -End of DATA command
-// TLS handhsake
+// TLS handshake
 func (c *client) resetTransaction() {
 	c.Envelope.ResetTransaction()
 }
@@ -141,19 +141,20 @@ func (c *client) isAlive() bool {
 }
 
 // setTimeout adjust the timeout on the connection, goroutine safe
-func (c *client) setTimeout(t time.Duration) {
+func (c *client) setTimeout(t time.Duration) (err error) {
 	defer c.connGuard.Unlock()
 	c.connGuard.Lock()
 	if c.conn != nil {
-		c.conn.SetDeadline(time.Now().Add(t * time.Second))
+		err = c.conn.SetDeadline(time.Now().Add(t * time.Second))
 	}
+	return
 }
 
 // closeConn closes a client connection, , goroutine safe
 func (c *client) closeConn() {
 	defer c.connGuard.Unlock()
 	c.connGuard.Lock()
-	c.conn.Close()
+	_ = c.conn.Close()
 	c.conn = nil
 }
 

+ 11 - 32
cmd/guerrillad/serve.go

@@ -3,10 +3,7 @@ package main
 import (
 	"fmt"
 	"os"
-	"os/exec"
 	"os/signal"
-	"strconv"
-	"strings"
 	"syscall"
 	"time"
 
@@ -72,17 +69,20 @@ func sigHandler() {
 		syscall.SIGINT,
 		syscall.SIGKILL,
 		syscall.SIGUSR1,
+		os.Kill,
 	)
 	for sig := range signalChannel {
 		if sig == syscall.SIGHUP {
 			if ac, err := readConfig(configPath, pidFile); err == nil {
-				d.ReloadConfig(*ac)
+				_ = d.ReloadConfig(*ac)
 			} else {
 				mainlog.WithError(err).Error("Could not reload config")
 			}
 		} else if sig == syscall.SIGUSR1 {
-			d.ReopenLogs()
-		} else if sig == syscall.SIGTERM || sig == syscall.SIGQUIT || sig == syscall.SIGINT {
+			if err := d.ReopenLogs(); err != nil {
+				mainlog.WithError(err).Error("reopening logs failed")
+			}
+		} else if sig == syscall.SIGTERM || sig == syscall.SIGQUIT || sig == syscall.SIGINT || sig == os.Kill {
 			mainlog.Infof("Shutdown signal caught")
 			go func() {
 				select {
@@ -105,24 +105,16 @@ func sigHandler() {
 func serve(cmd *cobra.Command, args []string) {
 	logVersion()
 	d = guerrilla.Daemon{Logger: mainlog}
-	ac, err := readConfig(configPath, pidFile)
+	c, err := readConfig(configPath, pidFile)
 	if err != nil {
 		mainlog.WithError(err).Fatal("Error while reading config")
 	}
-	d.SetConfig(*ac)
+	_ = d.SetConfig(*c)
 
 	// Check that max clients is not greater than system open file limit.
-	fileLimit := getFileLimit()
-
-	if fileLimit > 0 {
-		maxClients := 0
-		for _, s := range ac.Servers {
-			maxClients += s.MaxClients
-		}
-		if maxClients > fileLimit {
-			mainlog.Fatalf("Combined max clients for all servers (%d) is greater than open file limit (%d). "+
-				"Please increase your open file limit or decrease max clients.", maxClients, fileLimit)
-		}
+	if ok, maxClients, fileLimit := guerrilla.CheckFileLimit(c); !ok {
+		mainlog.Fatalf("Combined max clients for all servers (%d) is greater than open file limit (%d). "+
+			"Please increase your open file limit or decrease max clients.", maxClients, fileLimit)
 	}
 
 	err = d.Start()
@@ -155,16 +147,3 @@ func readConfig(path string, pidFile string) (*guerrilla.AppConfig, error) {
 	}
 	return &appConfig, nil
 }
-
-func getFileLimit() int {
-	cmd := exec.Command("ulimit", "-n")
-	out, err := cmd.Output()
-	if err != nil {
-		return -1
-	}
-	limit, err := strconv.Atoi(strings.TrimSpace(string(out)))
-	if err != nil {
-		return -1
-	}
-	return limit
-}

ファイルの差分が大きいため隠しています
+ 471 - 220
cmd/guerrillad/serve_test.go


+ 1 - 1
cmd/guerrillad/version.go

@@ -3,7 +3,7 @@ package main
 import (
 	"github.com/spf13/cobra"
 
-	guerrilla "github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla"
 )
 
 var versionCmd = &cobra.Command{

+ 33 - 17
config.go

@@ -10,6 +10,7 @@ import (
 	"os"
 	"reflect"
 	"strings"
+	"time"
 )
 
 // AppConfig is the holder of the configuration of the app
@@ -48,8 +49,8 @@ type ServerConfig struct {
 	// Listen interface specified in <ip>:<port> - defaults to 127.0.0.1:2525
 	ListenInterface string `json:"listen_interface"`
 
-	// MaxClients controls how many maxiumum clients we can handle at once.
-	// Defaults to 100
+	// MaxClients controls how many maximum clients we can handle at once.
+	// Defaults to defaultMaxClients
 	MaxClients int `json:"max_clients"`
 	// LogFile is where the logs go. Use path to file, or "stderr", "stdout" or "off".
 	// defaults to AppConfig.Log file setting
@@ -87,12 +88,12 @@ type ServerTLSConfig struct {
 	// Use Go's default if empty
 	ClientAuthType string `json:"client_auth_type,omitempty"`
 	// controls whether the server selects the
-	// client's most preferred ciphersuite
+	// client's most preferred cipher suite
 	PreferServerCipherSuites bool `json:"prefer_server_cipher_suites,omitempty"`
 
 	// The following used to watch certificate changes so that the TLS can be reloaded
-	_privateKeyFile_mtime int64
-	_publicKeyFile_mtime  int64
+	_privateKeyFileMtime int64
+	_publicKeyFileMtime  int64
 }
 
 // https://golang.org/pkg/crypto/tls/#pkg-constants
@@ -148,6 +149,11 @@ var TLSClientAuthTypes = map[string]tls.ClientAuthType{
 	"RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert,
 }
 
+const defaultMaxClients = 100
+const defaultTimeout = 30
+const defaultInterface = "127.0.0.1:2525"
+const defaultMaxSize = int64(10 << 20) // 10 Mebibytes
+
 // Unmarshalls json data into AppConfig struct and any other initialization of the struct
 // also does validation, returns error if validation failed or something went wrong
 func (c *AppConfig) Load(jsonBytes []byte) error {
@@ -171,7 +177,9 @@ func (c *AppConfig) Load(jsonBytes []byte) error {
 
 	// read the timestamps for the ssl keys, to determine if they need to be reloaded
 	for i := 0; i < len(c.Servers); i++ {
-		c.Servers[i].loadTlsKeyTimestamps()
+		if err := c.Servers[i].loadTlsKeyTimestamps(); err != nil {
+			return err
+		}
 	}
 	return nil
 }
@@ -218,8 +226,8 @@ func (c *AppConfig) EmitChangeEvents(oldConfig *AppConfig, app Guerrilla) {
 
 	}
 	// remove any servers that don't exist anymore
-	for _, oldserver := range oldServers {
-		app.Publish(EventConfigServerRemove, oldserver)
+	for _, oldServer := range oldServers {
+		app.Publish(EventConfigServerRemove, oldServer)
 	}
 }
 
@@ -275,9 +283,9 @@ func (c *AppConfig) setDefaults() error {
 		sc.ListenInterface = defaultInterface
 		sc.IsEnabled = true
 		sc.Hostname = h
-		sc.MaxClients = 100
-		sc.Timeout = 30
-		sc.MaxSize = 10 << 20 // 10 Mebibytes
+		sc.MaxClients = defaultMaxClients
+		sc.Timeout = defaultTimeout
+		sc.MaxSize = defaultMaxSize
 		c.Servers = append(c.Servers, sc)
 	} else {
 		// make sure each server has defaults correctly configured
@@ -286,13 +294,13 @@ func (c *AppConfig) setDefaults() error {
 				c.Servers[i].Hostname = h
 			}
 			if c.Servers[i].MaxClients == 0 {
-				c.Servers[i].MaxClients = 100
+				c.Servers[i].MaxClients = defaultMaxClients
 			}
 			if c.Servers[i].Timeout == 0 {
-				c.Servers[i].Timeout = 20
+				c.Servers[i].Timeout = defaultTimeout
 			}
 			if c.Servers[i].MaxSize == 0 {
-				c.Servers[i].MaxSize = 10 << 20 // 10 Mebibytes
+				c.Servers[i].MaxSize = defaultMaxSize // 10 Mebibytes
 			}
 			if c.Servers[i].ListenInterface == "" {
 				return errors.New(fmt.Sprintf("Listen interface not specified for server at index %d", i))
@@ -406,13 +414,21 @@ func (sc *ServerConfig) loadTlsKeyTimestamps() error {
 				iface,
 				err.Error()))
 	}
+	if sc.TLS.PrivateKeyFile == "" {
+		sc.TLS._privateKeyFileMtime = time.Now().Unix()
+		return nil
+	}
+	if sc.TLS.PublicKeyFile == "" {
+		sc.TLS._publicKeyFileMtime = time.Now().Unix()
+		return nil
+	}
 	if info, err := os.Stat(sc.TLS.PrivateKeyFile); err == nil {
-		sc.TLS._privateKeyFile_mtime = info.ModTime().Unix()
+		sc.TLS._privateKeyFileMtime = info.ModTime().Unix()
 	} else {
 		return statErr(sc.ListenInterface, err)
 	}
 	if info, err := os.Stat(sc.TLS.PublicKeyFile); err == nil {
-		sc.TLS._publicKeyFile_mtime = info.ModTime().Unix()
+		sc.TLS._publicKeyFileMtime = info.ModTime().Unix()
 	} else {
 		return statErr(sc.ListenInterface, err)
 	}
@@ -445,7 +461,7 @@ func (sc *ServerConfig) Validate() error {
 // Gets the timestamp of the TLS certificates. Returns a unix time of when they were last modified
 // when the config was read. We use this info to determine if TLS needs to be re-loaded.
 func (stc *ServerTLSConfig) getTlsKeyTimestamps() (int64, int64) {
-	return stc._privateKeyFile_mtime, stc._publicKeyFile_mtime
+	return stc._privateKeyFileMtime, stc._publicKeyFileMtime
 }
 
 // Returns value changes between struct a & struct b.

+ 47 - 18
config_test.go

@@ -11,10 +11,6 @@ import (
 	"time"
 )
 
-func init() {
-	testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/")
-}
-
 // a configuration file with a dummy backend
 
 //
@@ -114,8 +110,8 @@ var configJsonB = `
             "listen_interface":"127.0.0.1:2526",
             "max_clients": 3,
 			"tls" : {
- 				"private_key_file":"./config_test.go",
-            	"public_key_file":"./config_test.go",
+ 				"private_key_file":"./tests/mail2.guerrillamail.com.key.pem",
+            	"public_key_file": "./tests/mail2.guerrillamail.com.cert.pem",
 				"start_tls_on":false,
             	"tls_always_on":true
 			}
@@ -161,7 +157,7 @@ var configJsonB = `
 			"tls" : {
 				"private_key_file":"config_test.go",
             	"public_key_file":"config_test.go",
-				"start_tls_on":true,
+				"start_tls_on":false,
             	"tls_always_on":false
 			}
         }
@@ -170,6 +166,18 @@ var configJsonB = `
 `
 
 func TestConfigLoad(t *testing.T) {
+	if err := testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/"); err != nil {
+		t.Error(err)
+	}
+	defer func() {
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.cert.pem"); err != nil {
+			t.Error(err)
+		}
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.key.pem"); err != nil {
+			t.Error(err)
+		}
+	}()
+
 	ac := &AppConfig{}
 	if err := ac.Load([]byte(configJsonA)); err != nil {
 		t.Error("Cannot load config |", err)
@@ -181,8 +189,8 @@ func TestConfigLoad(t *testing.T) {
 		t.SkipNow()
 	}
 	// did we got the timestamps?
-	if ac.Servers[0].TLS._privateKeyFile_mtime <= 0 {
-		t.Error("failed to read timestamp for _privateKeyFile_mtime, got", ac.Servers[0].TLS._privateKeyFile_mtime)
+	if ac.Servers[0].TLS._privateKeyFileMtime <= 0 {
+		t.Error("failed to read timestamp for _privateKeyFileMtime, got", ac.Servers[0].TLS._privateKeyFileMtime)
 	}
 }
 
@@ -205,9 +213,22 @@ func TestSampleConfig(t *testing.T) {
 
 // make sure that we get all the config change events
 func TestConfigChangeEvents(t *testing.T) {
+	if err := testcert.GenerateCert("mail2.guerrillamail.com", "", 365*24*time.Hour, false, 2048, "P256", "./tests/"); err != nil {
+		t.Error(err)
+	}
+	defer func() {
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.cert.pem"); err != nil {
+			t.Error(err)
+		}
+		if err := deleteIfExists("../tests/mail2.guerrillamail.com.key.pem"); err != nil {
+			t.Error(err)
+		}
+	}()
 
 	oldconf := &AppConfig{}
-	oldconf.Load([]byte(configJsonA))
+	if err := oldconf.Load([]byte(configJsonA)); err != nil {
+		t.Error(err)
+	}
 	logger, _ := log.GetLogger(oldconf.LogFile, oldconf.LogLevel)
 	bcfg := backends.BackendConfig{"log_received_mails": true}
 	backend, err := backends.New(bcfg, logger)
@@ -221,10 +242,16 @@ func TestConfigChangeEvents(t *testing.T) {
 	// simulate timestamp change
 
 	time.Sleep(time.Second + time.Millisecond*500)
-	os.Chtimes(oldconf.Servers[1].TLS.PrivateKeyFile, time.Now(), time.Now())
-	os.Chtimes(oldconf.Servers[1].TLS.PublicKeyFile, time.Now(), time.Now())
+	if err := os.Chtimes(oldconf.Servers[1].TLS.PrivateKeyFile, time.Now(), time.Now()); err != nil {
+		t.Error(err)
+	}
+	if err := os.Chtimes(oldconf.Servers[1].TLS.PublicKeyFile, time.Now(), time.Now()); err != nil {
+		t.Error(err)
+	}
 	newconf := &AppConfig{}
-	newconf.Load([]byte(configJsonB))
+	if err := newconf.Load([]byte(configJsonB)); err != nil {
+		t.Error(err)
+	}
 	newconf.Servers[0].LogFile = log.OutputOff.String() // test for log file change
 	newconf.LogLevel = log.InfoLevel.String()
 	newconf.LogFile = "off"
@@ -253,14 +280,14 @@ func TestConfigChangeEvents(t *testing.T) {
 				f := func(c *AppConfig) {
 					expectedEvents[e] = true
 				}
-				app.Subscribe(event, f)
+				_ = app.Subscribe(event, f)
 				toUnsubscribe[event] = f
 			} else {
 				// must be a server config change then
 				f := func(c *ServerConfig) {
 					expectedEvents[e] = true
 				}
-				app.Subscribe(event, f)
+				_ = app.Subscribe(event, f)
 				toUnsubscribeSrv[event] = f
 			}
 
@@ -271,10 +298,10 @@ func TestConfigChangeEvents(t *testing.T) {
 	newconf.EmitChangeEvents(oldconf, app)
 	// unsubscribe
 	for unevent, unfun := range toUnsubscribe {
-		app.Unsubscribe(unevent, unfun)
+		_ = app.Unsubscribe(unevent, unfun)
 	}
 	for unevent, unfun := range toUnsubscribeSrv {
-		app.Unsubscribe(unevent, unfun)
+		_ = app.Unsubscribe(unevent, unfun)
 	}
 	for event, val := range expectedEvents {
 		if val == false {
@@ -285,5 +312,7 @@ func TestConfigChangeEvents(t *testing.T) {
 	}
 
 	// don't forget to reset
-	os.Truncate(oldconf.LogFile, 0)
+	if err := os.Truncate(oldconf.LogFile, 0); err != nil {
+		t.Error(err)
+	}
 }

+ 116 - 54
guerrilla.go

@@ -3,20 +3,21 @@ package guerrilla
 import (
 	"errors"
 	"fmt"
-	"github.com/flashmob/go-guerrilla/backends"
-	"github.com/flashmob/go-guerrilla/log"
 	"os"
 	"sync"
 	"sync/atomic"
+
+	"github.com/flashmob/go-guerrilla/backends"
+	"github.com/flashmob/go-guerrilla/log"
 )
 
 const (
-	// server has just been created
-	GuerrillaStateNew = iota
-	// Server has been started and is running
-	GuerrillaStateStarted
-	// Server has just been stopped
-	GuerrillaStateStopped
+	// all configured servers have just been created
+	daemonStateNew = iota
+	// ... been started and running
+	daemonStateStarted
+	// ... been stopped
+	daemonStateStopped
 )
 
 type Errors []error
@@ -62,6 +63,9 @@ type backendStore struct {
 	atomic.Value
 }
 
+type daemonEvent func(c *AppConfig)
+type serverEvent func(sc *ServerConfig)
+
 // Get loads the log.logger in an atomic operation. Returns a stderr logger if not able to load
 func (ls *logStore) mainlog() log.Logger {
 	if v, ok := ls.Load().(log.Logger); ok {
@@ -71,7 +75,7 @@ func (ls *logStore) mainlog() log.Logger {
 	return l
 }
 
-// storeMainlog stores the log value in an atomic operation
+// setMainlog stores the log value in an atomic operation
 func (ls *logStore) setMainlog(log log.Logger) {
 	ls.Store(log)
 }
@@ -92,17 +96,18 @@ func New(ac *AppConfig, b backends.Backend, l log.Logger) (Guerrilla, error) {
 			}
 		}
 	}
+	// Write the process id (pid) to a file
+	// we should still be able to continue even if we can't write the pid, error will be logged by writePid()
+	_ = g.writePid()
 
-	g.state = GuerrillaStateNew
+	g.state = daemonStateNew
 	err := g.makeServers()
 
 	// start backend for processing email
 	err = g.backend().Start()
-
 	if err != nil {
 		return g, err
 	}
-	g.writePid()
 
 	// subscribe for any events that may come in while running
 	g.subscribeEvents()
@@ -136,7 +141,7 @@ func (g *guerrilla) makeServers() error {
 		}
 	}
 	if len(g.servers) == 0 {
-		errs = append(errs, errors.New("There are no servers that can start, please check your config"))
+		errs = append(errs, errors.New("there are no servers that can start, please check your config"))
 	}
 	if len(errs) == 0 {
 		return nil
@@ -191,13 +196,13 @@ func (g *guerrilla) mapServers(callback func(*server)) map[string]*server {
 // subscribeEvents subscribes event handlers for configuration change events
 func (g *guerrilla) subscribeEvents() {
 
+	events := map[Event]interface{}{}
 	// main config changed
-	g.Subscribe(EventConfigNewConfig, func(c *AppConfig) {
+	events[EventConfigNewConfig] = daemonEvent(func(c *AppConfig) {
 		g.setConfig(c)
 	})
-
 	// allowed_hosts changed, set for all servers
-	g.Subscribe(EventConfigAllowedHosts, func(c *AppConfig) {
+	events[EventConfigAllowedHosts] = daemonEvent(func(c *AppConfig) {
 		g.mapServers(func(server *server) {
 			server.setAllowedHosts(c.AllowedHosts)
 		})
@@ -205,7 +210,7 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// the main log file changed
-	g.Subscribe(EventConfigLogFile, func(c *AppConfig) {
+	events[EventConfigLogFile] = daemonEvent(func(c *AppConfig) {
 		var err error
 		var l log.Logger
 		if l, err = log.GetLogger(c.LogFile, c.LogLevel); err == nil {
@@ -222,13 +227,17 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// re-open the main log file (file not changed)
-	g.Subscribe(EventConfigLogReopen, func(c *AppConfig) {
-		g.mainlog().Reopen()
+	events[EventConfigLogReopen] = daemonEvent(func(c *AppConfig) {
+		err := g.mainlog().Reopen()
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("main log file [%s] failed to re-open", c.LogFile)
+			return
+		}
 		g.mainlog().Infof("re-opened main log file [%s]", c.LogFile)
 	})
 
 	// when log level changes, apply to mainlog and server logs
-	g.Subscribe(EventConfigLogLevel, func(c *AppConfig) {
+	events[EventConfigLogLevel] = daemonEvent(func(c *AppConfig) {
 		l, err := log.GetLogger(g.mainlog().GetLogDest(), c.LogLevel)
 		if err == nil {
 			g.logStore.Store(l)
@@ -240,18 +249,18 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// write out our pid whenever the file name changes in the config
-	g.Subscribe(EventConfigPidFile, func(ac *AppConfig) {
-		g.writePid()
+	events[EventConfigPidFile] = daemonEvent(func(ac *AppConfig) {
+		_ = g.writePid()
 	})
 
 	// server config was updated
-	g.Subscribe(EventConfigServerConfig, func(sc *ServerConfig) {
+	events[EventConfigServerConfig] = serverEvent(func(sc *ServerConfig) {
 		g.setServerConfig(sc)
 		g.mainlog().Infof("server %s config change event, a new config has been saved", sc.ListenInterface)
 	})
 
 	// add a new server to the config & start
-	g.Subscribe(EventConfigServerNew, func(sc *ServerConfig) {
+	events[EventConfigServerNew] = serverEvent(func(sc *ServerConfig) {
 		g.mainlog().Debugf("event fired [%s] %s", EventConfigServerNew, sc.ListenInterface)
 		if _, err := g.findServer(sc.ListenInterface); err != nil {
 			// not found, lets add it
@@ -261,7 +270,7 @@ func (g *guerrilla) subscribeEvents() {
 				return
 			}
 			g.mainlog().Infof("New server added [%s]", sc.ListenInterface)
-			if g.state == GuerrillaStateStarted {
+			if g.state == daemonStateStarted {
 				err := g.Start()
 				if err != nil {
 					g.mainlog().WithError(err).Info("Event server_change:new_server returned errors when starting")
@@ -271,8 +280,9 @@ func (g *guerrilla) subscribeEvents() {
 			g.mainlog().Debugf("new event, but server already fund")
 		}
 	})
+
 	// start a server that already exists in the config and has been enabled
-	g.Subscribe(EventConfigServerStart, func(sc *ServerConfig) {
+	events[EventConfigServerStart] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if server.state == ServerStateStopped || server.state == ServerStateNew {
 				g.mainlog().Infof("Starting server [%s]", server.listenInterface)
@@ -283,8 +293,9 @@ func (g *guerrilla) subscribeEvents() {
 			}
 		}
 	})
+
 	// stop running a server
-	g.Subscribe(EventConfigServerStop, func(sc *ServerConfig) {
+	events[EventConfigServerStop] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if server.state == ServerStateRunning {
 				server.Shutdown()
@@ -292,8 +303,9 @@ func (g *guerrilla) subscribeEvents() {
 			}
 		}
 	})
+
 	// server was removed from config
-	g.Subscribe(EventConfigServerRemove, func(sc *ServerConfig) {
+	events[EventConfigServerRemove] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			server.Shutdown()
 			g.removeServer(sc.ListenInterface)
@@ -302,7 +314,7 @@ func (g *guerrilla) subscribeEvents() {
 	})
 
 	// TLS changes
-	g.Subscribe(EventConfigServerTLSConfig, func(sc *ServerConfig) {
+	events[EventConfigServerTLSConfig] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			if err := server.configureSSL(); err == nil {
 				g.mainlog().Infof("Server [%s] new TLS configuration loaded", sc.ListenInterface)
@@ -312,19 +324,19 @@ func (g *guerrilla) subscribeEvents() {
 		}
 	})
 	// when server's timeout change.
-	g.Subscribe(EventConfigServerTimeout, func(sc *ServerConfig) {
+	events[EventConfigServerTimeout] = serverEvent(func(sc *ServerConfig) {
 		g.mapServers(func(server *server) {
 			server.setTimeout(sc.Timeout)
 		})
 	})
 	// when server's max clients change.
-	g.Subscribe(EventConfigServerMaxClients, func(sc *ServerConfig) {
+	events[EventConfigServerMaxClients] = serverEvent(func(sc *ServerConfig) {
 		g.mapServers(func(server *server) {
 			// TODO resize the pool somehow
 		})
 	})
 	// when a server's log file changes
-	g.Subscribe(EventConfigServerLogFile, func(sc *ServerConfig) {
+	events[EventConfigServerLogFile] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
 			var err error
 			var l log.Logger
@@ -348,14 +360,17 @@ func (g *guerrilla) subscribeEvents() {
 		}
 	})
 	// when the daemon caught a sighup, event for individual server
-	g.Subscribe(EventConfigServerLogReopen, func(sc *ServerConfig) {
+	events[EventConfigServerLogReopen] = serverEvent(func(sc *ServerConfig) {
 		if server, err := g.findServer(sc.ListenInterface); err == nil {
-			server.log().Reopen()
+			if err = server.log().Reopen(); err != nil {
+				g.mainlog().WithError(err).Errorf("server [%s] log file [%s] failed to re-open", sc.ListenInterface, sc.LogFile)
+				return
+			}
 			g.mainlog().Infof("Server [%s] re-opened log file [%s]", sc.ListenInterface, sc.LogFile)
 		}
 	})
 	// when the backend changes
-	g.Subscribe(EventConfigBackendConfig, func(appConfig *AppConfig) {
+	events[EventConfigBackendConfig] = daemonEvent(func(appConfig *AppConfig) {
 		logger, _ := log.GetLogger(appConfig.LogFile, appConfig.LogLevel)
 		// shutdown the backend first.
 		var err error
@@ -386,6 +401,19 @@ func (g *guerrilla) subscribeEvents() {
 			g.storeBackend(newBackend)
 		}
 	})
+	var err error
+	for topic, fn := range events {
+		switch f := fn.(type) {
+		case daemonEvent:
+			err = g.Subscribe(topic, f)
+		case serverEvent:
+			err = g.Subscribe(topic, f)
+		}
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("failed to subscribe on topic [%s]", topic)
+			break
+		}
+	}
 
 }
 
@@ -408,16 +436,20 @@ func (g *guerrilla) Start() error {
 	var startErrors Errors
 	g.guard.Lock()
 	defer func() {
-		g.state = GuerrillaStateStarted
+		g.state = daemonStateStarted
 		g.guard.Unlock()
 	}()
 	if len(g.servers) == 0 {
-		return append(startErrors, errors.New("No servers to start, please check the config"))
+		return append(startErrors, errors.New("no servers to start, please check the config"))
 	}
-	if g.state == GuerrillaStateStopped {
+	if g.state == daemonStateStopped {
 		// when a backend is shutdown, we need to re-initialize before it can be started again
-		g.backend().Reinitialize()
-		g.backend().Start()
+		if err := g.backend().Reinitialize(); err != nil {
+			startErrors = append(startErrors, err)
+		}
+		if err := g.backend().Start(); err != nil {
+			startErrors = append(startErrors, err)
+		}
 	}
 	// channel for reading errors
 	errs := make(chan error, len(g.servers))
@@ -469,7 +501,7 @@ func (g *guerrilla) Shutdown() {
 
 	g.guard.Lock()
 	defer func() {
-		g.state = GuerrillaStateStopped
+		g.state = daemonStateStopped
 		defer g.guard.Unlock()
 	}()
 	if err := g.backend().Shutdown(); err != nil {
@@ -487,22 +519,52 @@ func (g *guerrilla) SetLogger(l log.Logger) {
 
 // writePid writes the pid (process id) to the file specified in the config.
 // Won't write anything if no file specified
-func (g *guerrilla) writePid() error {
-	if len(g.Config.PidFile) > 0 {
-		if f, err := os.Create(g.Config.PidFile); err == nil {
-			defer f.Close()
-			pid := os.Getpid()
-			if _, err := f.WriteString(fmt.Sprintf("%d", pid)); err == nil {
-				f.Sync()
-				g.mainlog().Infof("pid_file (%s) written with pid:%v", g.Config.PidFile, pid)
-			} else {
-				g.mainlog().WithError(err).Errorf("Error while writing pidFile (%s)", g.Config.PidFile)
-				return err
+func (g *guerrilla) writePid() (err error) {
+	var f *os.File
+	defer func() {
+		if f != nil {
+			if closeErr := f.Close(); closeErr != nil {
+				err = closeErr
 			}
-		} else {
-			g.mainlog().WithError(err).Errorf("Error while creating pidFile (%s)", g.Config.PidFile)
+		}
+		if err != nil {
+			g.mainlog().WithError(err).Errorf("error while writing pidFile (%s)", g.Config.PidFile)
+		}
+	}()
+	if len(g.Config.PidFile) > 0 {
+		if f, err = os.Create(g.Config.PidFile); err != nil {
+			return err
+		}
+		pid := os.Getpid()
+		if _, err := f.WriteString(fmt.Sprintf("%d", pid)); err != nil {
+			return err
+		}
+		if err = f.Sync(); err != nil {
 			return err
 		}
+		g.mainlog().Infof("pid_file (%s) written with pid:%v", g.Config.PidFile, pid)
 	}
 	return nil
 }
+
+// CheckFileLimit checks the number of files we can open (works on OS'es that support the ulimit command)
+func CheckFileLimit(c *AppConfig) (bool, int, uint64) {
+	fileLimit, err := getFileLimit()
+	maxClients := 0
+	if err != nil {
+		// since we can't get the limit, return true to indicate the check passed
+		return true, maxClients, fileLimit
+	}
+	if c.Servers == nil {
+		// no servers have been configured, assuming default
+		maxClients = defaultMaxClients
+	} else {
+		for _, s := range c.Servers {
+			maxClients += s.MaxClients
+		}
+	}
+	if uint64(maxClients) > fileLimit {
+		return false, maxClients, fileLimit
+	}
+	return true, maxClients, fileLimit
+}

+ 16 - 0
guerrilla_notunix.go

@@ -0,0 +1,16 @@
+// +build !darwin
+// +build !dragonfly
+// +build !freebsd
+// +build !linux
+// +build !netbsd
+// +build !openbsd
+
+package guerrilla
+
+import "errors"
+
+// getFileLimit checks how many files we can open
+// Don't know how to get that info (yet?), so returns false information & error
+func getFileLimit() (uint64, error) {
+	return 1000000, errors.New("syscall.RLIMIT_NOFILE not supported on your OS/platform")
+}

+ 15 - 0
guerrilla_unix.go

@@ -0,0 +1,15 @@
+// +build darwin dragonfly freebsd linux netbsd openbsd
+
+package guerrilla
+
+import "syscall"
+
+// getFileLimit checks how many files we can open
+func getFileLimit() (uint64, error) {
+	var rLimit syscall.Rlimit
+	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+	if err != nil {
+		return 0, err
+	}
+	return rLimit.Max, nil
+}

+ 1 - 1
log/hook.go

@@ -153,7 +153,7 @@ func (hook *LogrusHook) Fire(entry *log.Entry) error {
 				return err
 			}
 			if hook.fd != nil {
-				hook.fd.Sync()
+				err = hook.fd.Sync()
 			}
 		}
 		return err

+ 1 - 2
log/log.go

@@ -212,8 +212,7 @@ func (l *HookedLogger) GetLogDest() string {
 
 // WithConn extends logrus to be able to log with a net.Conn
 func (l *HookedLogger) WithConn(conn net.Conn) *log.Entry {
-	var addr string = "unknown"
-
+	var addr = "unknown"
 	if conn != nil {
 		addr = conn.RemoteAddr().String()
 	}

+ 7 - 10
mail/envelope.go

@@ -27,7 +27,7 @@ func init() {
 	Dec = mime.WordDecoder{}
 }
 
-const maxHeaderChunk = 1 + (3 << 10) // 3KB
+const maxHeaderChunk = 1 + (4 << 10) // 4KB
 
 // Address encodes an email address of the form `<user@host>`
 type Address struct {
@@ -122,18 +122,15 @@ func (e *Envelope) ParseHeaders() error {
 	if e.Header != nil {
 		return errors.New("headers already parsed")
 	}
-	buf := bytes.NewBuffer(e.Data.Bytes())
+	buf := e.Data.Bytes()
 	// find where the header ends, assuming that over 30 kb would be max
-	max := maxHeaderChunk
-	if buf.Len() < max {
-		max = buf.Len()
+	if len(buf) > maxHeaderChunk {
+		buf = buf[:maxHeaderChunk]
 	}
-	// read in the chunk which we'll scan for the header
-	chunk := make([]byte, max)
-	buf.Read(chunk)
-	headerEnd := strings.Index(string(chunk), "\n\n") // the first two new-lines chars are the End Of Header
+
+	headerEnd := bytes.Index(buf, []byte{'\n', '\n'}) // the first two new-lines chars are the End Of Header
 	if headerEnd > -1 {
-		header := chunk[0:headerEnd]
+		header := buf[0:headerEnd]
 		headerReader := textproto.NewReader(bufio.NewReader(bytes.NewBuffer(header)))
 		e.Header, err = headerReader.ReadMIMEHeader()
 		if err != nil {

+ 6 - 2
mail/envelope_test.go

@@ -1,6 +1,7 @@
 package mail
 
 import (
+	"io"
 	"io/ioutil"
 	"strings"
 	"testing"
@@ -61,9 +62,12 @@ func TestEnvelope(t *testing.T) {
 
 	data, _ := ioutil.ReadAll(r)
 	if len(data) != e.Len() {
-		t.Error("e.Len() is inccorrect, it shown ", e.Len(), " but we wanted ", len(data))
+		t.Error("e.Len() is incorrect, it shown ", e.Len(), " but we wanted ", len(data))
+	}
+	if err := e.ParseHeaders(); err != nil && err != io.EOF {
+		t.Error("cannot parse headers:", err)
+		return
 	}
-	e.ParseHeaders()
 	if e.Subject != "Test" {
 		t.Error("Subject expecting: Test, got:", e.Subject)
 	}

+ 6 - 11
mocks/client.go

@@ -3,7 +3,6 @@ package mocks
 import (
 	"fmt"
 	"net/smtp"
-	"time"
 )
 
 const (
@@ -16,21 +15,15 @@ func lastWords(message string, err error) {
 	// panic(err)
 }
 
-// Sends a single SMTP message, for testing.
-func main() {
-	for i := 0; i < 100; i++ {
-		go sendMail(i)
-	}
-	time.Sleep(time.Minute / 10)
-}
-
 func sendMail(i int) {
 	fmt.Printf("Sending %d mail\n", i)
 	c, err := smtp.Dial(URL)
 	if err != nil {
 		lastWords("Dial ", err)
 	}
-	defer c.Close()
+	defer func() {
+		_ = c.Close()
+	}()
 
 	from := "[email protected]"
 	to := "[email protected]"
@@ -47,7 +40,9 @@ func sendMail(i int) {
 	if err != nil {
 		lastWords("Data ", err)
 	}
-	defer wr.Close()
+	defer func() {
+		_ = wr.Close()
+	}()
 
 	msg := fmt.Sprint("Subject: something\n")
 	msg += "From: " + from + "\n"

+ 3 - 3
mocks/conn_mock.go

@@ -31,11 +31,11 @@ type End struct {
 	Writer *io.PipeWriter
 }
 
-func (c End) Close() error {
-	if err := c.Writer.Close(); err != nil {
+func (e End) Close() error {
+	if err := e.Writer.Close(); err != nil {
 		return err
 	}
-	if err := c.Reader.Close(); err != nil {
+	if err := e.Reader.Close(); err != nil {
 		return err
 	}
 	return nil

+ 2 - 2
models.go

@@ -7,8 +7,8 @@ import (
 )
 
 var (
-	LineLimitExceeded   = errors.New("Maximum line length exceeded")
-	MessageSizeExceeded = errors.New("Maximum message size exceeded")
+	LineLimitExceeded   = errors.New("maximum line length exceeded")
+	MessageSizeExceeded = errors.New("maximum message size exceeded")
 )
 
 // we need to adjust the limit, so we embed io.LimitedReader

+ 9 - 4
pool.go

@@ -17,11 +17,12 @@ var (
 // a struct can be pooled if it has the following interface
 type Poolable interface {
 	// ability to set read/write timeout
-	setTimeout(t time.Duration)
+	setTimeout(t time.Duration) error
 	// set a new connection and client id
 	init(c net.Conn, clientID uint64, ep *mail.Pool)
 	// get a unique id
 	getID() uint64
+	kill()
 }
 
 // Pool holds Clients.
@@ -82,9 +83,11 @@ func (p *Pool) ShutdownState() {
 	p.isShuttingDownFlg.Store(true) // no more borrowing
 	p.ShutdownChan <- 1             // release any waiting p.sem
 
-	// set a low timeout
+	// set a low timeout (let the clients finish whatever they're doing)
 	p.activeClients.mapAll(func(p Poolable) {
-		p.setTimeout(time.Duration(int64(aVeryLowTimeout)))
+		if err := p.setTimeout(time.Duration(int64(aVeryLowTimeout))); err != nil {
+			p.kill()
+		}
 	})
 
 }
@@ -111,7 +114,9 @@ func (p *Pool) IsShuttingDown() bool {
 // set a timeout for all lent clients
 func (p *Pool) SetTimeout(duration time.Duration) {
 	p.activeClients.mapAll(func(p Poolable) {
-		p.setTimeout(duration)
+		if err := p.setTimeout(duration); err != nil {
+			p.kill()
+		}
 	})
 }
 

+ 0 - 12
response/enhanced_test.go

@@ -4,18 +4,6 @@ import (
 	"testing"
 )
 
-func TestClass(t *testing.T) {
-	if ClassPermanentFailure != 5 {
-		t.Error("ClassPermanentFailure is not 5")
-	}
-	if ClassTransientFailure != 4 {
-		t.Error("ClassTransientFailure is not 4")
-	}
-	if ClassSuccess != 2 {
-		t.Error("ClassSuccess is not 2")
-	}
-}
-
 func TestGetBasicStatusCode(t *testing.T) {
 	// Known status code
 	a := getBasicStatusCode(EnhancedStatusCode{2, OtherOrUndefinedProtocolStatus})

+ 113 - 111
server.go

@@ -90,7 +90,7 @@ func (c command) match(in []byte) bool {
 func newServer(sc *ServerConfig, b backends.Backend, l log.Logger) (*server, error) {
 	server := &server{
 		clientPool:      NewPool(sc.MaxClients),
-		closedListener:  make(chan (bool), 1),
+		closedListener:  make(chan bool, 1),
 		listenInterface: sc.ListenInterface,
 		state:           ServerStateNew,
 		envelopePool:    mail.NewPool(sc.MaxClients),
@@ -190,127 +190,127 @@ func (s *server) backend() backends.Backend {
 }
 
 // Set the timeout for the server and all clients
-func (server *server) setTimeout(seconds int) {
+func (s *server) setTimeout(seconds int) {
 	duration := time.Duration(int64(seconds))
-	server.clientPool.SetTimeout(duration)
-	server.timeout.Store(duration)
+	s.clientPool.SetTimeout(duration)
+	s.timeout.Store(duration)
 }
 
 // goroutine safe config store
-func (server *server) setConfig(sc *ServerConfig) {
-	server.configStore.Store(*sc)
+func (s *server) setConfig(sc *ServerConfig) {
+	s.configStore.Store(*sc)
 }
 
 // goroutine safe
-func (server *server) isEnabled() bool {
-	sc := server.configStore.Load().(ServerConfig)
+func (s *server) isEnabled() bool {
+	sc := s.configStore.Load().(ServerConfig)
 	return sc.IsEnabled
 }
 
 // Set the allowed hosts for the server
-func (server *server) setAllowedHosts(allowedHosts []string) {
-	server.hosts.Lock()
-	defer server.hosts.Unlock()
-	server.hosts.table = make(map[string]bool, len(allowedHosts))
-	server.hosts.wildcards = nil
+func (s *server) setAllowedHosts(allowedHosts []string) {
+	s.hosts.Lock()
+	defer s.hosts.Unlock()
+	s.hosts.table = make(map[string]bool, len(allowedHosts))
+	s.hosts.wildcards = nil
 	for _, h := range allowedHosts {
 		if strings.Index(h, "*") != -1 {
-			server.hosts.wildcards = append(server.hosts.wildcards, strings.ToLower(h))
+			s.hosts.wildcards = append(s.hosts.wildcards, strings.ToLower(h))
 		} else {
-			server.hosts.table[strings.ToLower(h)] = true
+			s.hosts.table[strings.ToLower(h)] = true
 		}
 	}
 }
 
 // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
-func (server *server) Start(startWG *sync.WaitGroup) error {
+func (s *server) Start(startWG *sync.WaitGroup) error {
 	var clientID uint64
 	clientID = 0
 
-	listener, err := net.Listen("tcp", server.listenInterface)
-	server.listener = listener
+	listener, err := net.Listen("tcp", s.listenInterface)
+	s.listener = listener
 	if err != nil {
 		startWG.Done() // don't wait for me
-		server.state = ServerStateStartError
-		return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
+		s.state = ServerStateStartError
+		return fmt.Errorf("[%s] Cannot listen on port: %s ", s.listenInterface, err.Error())
 	}
 
-	server.log().Infof("Listening on TCP %s", server.listenInterface)
-	server.state = ServerStateRunning
+	s.log().Infof("Listening on TCP %s", s.listenInterface)
+	s.state = ServerStateRunning
 	startWG.Done() // start successful, don't wait for me
 
 	for {
-		server.log().Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
+		s.log().Debugf("[%s] Waiting for a new client. Next Client ID: %d", s.listenInterface, clientID+1)
 		conn, err := listener.Accept()
 		clientID++
 		if err != nil {
 			if e, ok := err.(net.Error); ok && !e.Temporary() {
-				server.log().Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
+				s.log().Infof("Server [%s] has stopped accepting new clients", s.listenInterface)
 				// the listener has been closed, wait for clients to exit
-				server.log().Infof("shutting down pool [%s]", server.listenInterface)
-				server.clientPool.ShutdownState()
-				server.clientPool.ShutdownWait()
-				server.state = ServerStateStopped
-				server.closedListener <- true
+				s.log().Infof("shutting down pool [%s]", s.listenInterface)
+				s.clientPool.ShutdownState()
+				s.clientPool.ShutdownWait()
+				s.state = ServerStateStopped
+				s.closedListener <- true
 				return nil
 			}
-			server.mainlog().WithError(err).Info("Temporary error accepting client")
+			s.mainlog().WithError(err).Info("Temporary error accepting client")
 			continue
 		}
-		go func(p Poolable, borrow_err error) {
+		go func(p Poolable, borrowErr error) {
 			c := p.(*client)
-			if borrow_err == nil {
-				server.handleClient(c)
-				server.envelopePool.Return(c.Envelope)
-				server.clientPool.Return(c)
+			if borrowErr == nil {
+				s.handleClient(c)
+				s.envelopePool.Return(c.Envelope)
+				s.clientPool.Return(c)
 			} else {
-				server.log().WithError(borrow_err).Info("couldn't borrow a new client")
+				s.log().WithError(borrowErr).Info("couldn't borrow a new client")
 				// we could not get a client, so close the connection.
-				conn.Close()
+				_ = conn.Close()
 
 			}
 			// intentionally placed Borrow in args so that it's called in the
 			// same main goroutine.
-		}(server.clientPool.Borrow(conn, clientID, server.log(), server.envelopePool))
+		}(s.clientPool.Borrow(conn, clientID, s.log(), s.envelopePool))
 
 	}
 }
 
-func (server *server) Shutdown() {
-	if server.listener != nil {
+func (s *server) Shutdown() {
+	if s.listener != nil {
 		// This will cause Start function to return, by causing an error on listener.Accept
-		server.listener.Close()
+		_ = s.listener.Close()
 		// wait for the listener to listener.Accept
-		<-server.closedListener
+		<-s.closedListener
 		// At this point Start will exit and close down the pool
 	} else {
-		server.clientPool.ShutdownState()
+		s.clientPool.ShutdownState()
 		// listener already closed, wait for clients to exit
-		server.clientPool.ShutdownWait()
-		server.state = ServerStateStopped
+		s.clientPool.ShutdownWait()
+		s.state = ServerStateStopped
 	}
 }
 
-func (server *server) GetActiveClientsCount() int {
-	return server.clientPool.GetActiveClientsCount()
+func (s *server) GetActiveClientsCount() int {
+	return s.clientPool.GetActiveClientsCount()
 }
 
 // Verifies that the host is a valid recipient.
 // host checking turned off if there is a single entry and it's a dot.
-func (server *server) allowsHost(host string) bool {
-	server.hosts.Lock()
-	defer server.hosts.Unlock()
+func (s *server) allowsHost(host string) bool {
+	s.hosts.Lock()
+	defer s.hosts.Unlock()
 	// if hosts contains a single dot, further processing is skipped
-	if len(server.hosts.table) == 1 {
-		if _, ok := server.hosts.table["."]; ok {
+	if len(s.hosts.table) == 1 {
+		if _, ok := s.hosts.table["."]; ok {
 			return true
 		}
 	}
-	if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
+	if _, ok := s.hosts.table[strings.ToLower(host)]; ok {
 		return true
 	}
-	// check the willdcards
-	for _, w := range server.hosts.wildcards {
+	// check the wildcards
+	for _, w := range s.hosts.wildcards {
 		if matched, err := filepath.Match(w, strings.ToLower(host)); matched && err == nil {
 			return true
 		}
@@ -322,7 +322,7 @@ const commandSuffix = "\r\n"
 
 // Reads from the client until a \n terminator is encountered,
 // or until a timeout occurs.
-func (server *server) readCommand(client *client) ([]byte, error) {
+func (s *server) readCommand(client *client) ([]byte, error) {
 	//var input string
 	var err error
 	var bs []byte
@@ -337,25 +337,27 @@ func (server *server) readCommand(client *client) ([]byte, error) {
 }
 
 // flushResponse a response to the client. Flushes the client.bufout buffer to the connection
-func (server *server) flushResponse(client *client) error {
-	client.setTimeout(server.timeout.Load().(time.Duration))
+func (s *server) flushResponse(client *client) error {
+	if err := client.setTimeout(s.timeout.Load().(time.Duration)); err != nil {
+		return err
+	}
 	return client.bufout.Flush()
 }
 
-func (server *server) isShuttingDown() bool {
-	return server.clientPool.IsShuttingDown()
+func (s *server) isShuttingDown() bool {
+	return s.clientPool.IsShuttingDown()
 }
 
 // Handles an entire client SMTP exchange
-func (server *server) handleClient(client *client) {
+func (s *server) handleClient(client *client) {
 	defer client.closeConn()
-	sc := server.configStore.Load().(ServerConfig)
-	server.log().Infof("Handle client [%s], id: %d", client.RemoteIP, client.ID)
+	sc := s.configStore.Load().(ServerConfig)
+	s.log().Infof("Handle client [%s], id: %d", client.RemoteIP, client.ID)
 
 	// Initial greeting
 	greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s",
 		sc.Hostname, Version, client.ID,
-		server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339))
+		s.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339))
 
 	helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
 	// ehlo is a multi-line reply and need additional \r\n at the end
@@ -371,13 +373,13 @@ func (server *server) handleClient(client *client) {
 	help := "250 HELP"
 
 	if sc.TLS.AlwaysOn {
-		tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
+		tlsConfig, ok := s.tlsConfigStore.Load().(*tls.Config)
 		if !ok {
-			server.mainlog().Error("Failed to load *tls.Config")
+			s.mainlog().Error("Failed to load *tls.Config")
 		} else if err := client.upgradeToTLS(tlsConfig); err == nil {
 			advertiseTLS = ""
 		} else {
-			server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
+			s.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
 			// server requires TLS, but can't handshake
 			client.kill()
 		}
@@ -386,7 +388,7 @@ func (server *server) handleClient(client *client) {
 		// STARTTLS turned off, don't advertise it
 		advertiseTLS = ""
 	}
-
+	r := response.Canned
 	for client.isAlive() {
 		switch client.state {
 		case ClientGreeting:
@@ -394,24 +396,24 @@ func (server *server) handleClient(client *client) {
 			client.state = ClientCmd
 		case ClientCmd:
 			client.bufin.setLimit(CommandLineMaxLength)
-			input, err := server.readCommand(client)
-			server.log().Debugf("Client sent: %s", input)
+			input, err := s.readCommand(client)
+			s.log().Debugf("Client sent: %s", input)
 			if err == io.EOF {
-				server.log().WithError(err).Warnf("Client closed the connection: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Client closed the connection: %s", client.RemoteIP)
 				return
 			} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
-				server.log().WithError(err).Warnf("Timeout: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Timeout: %s", client.RemoteIP)
 				return
 			} else if err == LineLimitExceeded {
-				client.sendResponse(response.Canned.FailLineTooLong)
+				client.sendResponse(r.FailLineTooLong)
 				client.kill()
 				break
 			} else if err != nil {
-				server.log().WithError(err).Warnf("Read error: %s", client.RemoteIP)
+				s.log().WithError(err).Warnf("Read error: %s", client.RemoteIP)
 				client.kill()
 				break
 			}
-			if server.isShuttingDown() {
+			if s.isShuttingDown() {
 				client.state = ClientShutdown
 				continue
 			}
@@ -458,80 +460,80 @@ func (server *server) handleClient(client *client) {
 						}
 					}
 				}
-				client.sendResponse(response.Canned.SuccessMailCmd)
+				client.sendResponse(r.SuccessMailCmd)
 			case cmdMAIL.match(cmd):
 				if client.isInTransaction() {
-					client.sendResponse(response.Canned.FailNestedMailCmd)
+					client.sendResponse(r.FailNestedMailCmd)
 					break
 				}
 				client.MailFrom, err = client.parsePath([]byte(input[10:]), client.parser.MailFrom)
 				if err != nil {
-					server.log().WithError(err).Error("MAIL parse error", "["+string(input[10:])+"]")
+					s.log().WithError(err).Error("MAIL parse error", "["+string(input[10:])+"]")
 					client.sendResponse(err)
 					break
 				} else if client.parser.NullPath {
 					// bounce has empty from address
 					client.MailFrom = mail.Address{}
 				}
-				client.sendResponse(response.Canned.SuccessMailCmd)
+				client.sendResponse(r.SuccessMailCmd)
 
 			case cmdRCPT.match(cmd):
 				if len(client.RcptTo) > rfc5321.LimitRecipients {
-					client.sendResponse(response.Canned.ErrorTooManyRecipients)
+					client.sendResponse(r.ErrorTooManyRecipients)
 					break
 				}
 				to, err := client.parsePath([]byte(input[8:]), client.parser.RcptTo)
 				if err != nil {
-					server.log().WithError(err).Error("RCPT parse error", "["+string(input[8:])+"]")
+					s.log().WithError(err).Error("RCPT parse error", "["+string(input[8:])+"]")
 					client.sendResponse(err.Error())
 					break
 				}
-				if !server.allowsHost(to.Host) {
-					client.sendResponse(response.Canned.ErrorRelayDenied, " ", to.Host)
+				if !s.allowsHost(to.Host) {
+					client.sendResponse(r.ErrorRelayDenied, " ", to.Host)
 				} else {
 					client.PushRcpt(to)
-					rcptError := server.backend().ValidateRcpt(client.Envelope)
+					rcptError := s.backend().ValidateRcpt(client.Envelope)
 					if rcptError != nil {
 						client.PopRcpt()
-						client.sendResponse(response.Canned.FailRcptCmd, " ", rcptError.Error())
+						client.sendResponse(r.FailRcptCmd, " ", rcptError.Error())
 					} else {
-						client.sendResponse(response.Canned.SuccessRcptCmd)
+						client.sendResponse(r.SuccessRcptCmd)
 					}
 				}
 
 			case cmdRSET.match(cmd):
 				client.resetTransaction()
-				client.sendResponse(response.Canned.SuccessResetCmd)
+				client.sendResponse(r.SuccessResetCmd)
 
 			case cmdVRFY.match(cmd):
-				client.sendResponse(response.Canned.SuccessVerifyCmd)
+				client.sendResponse(r.SuccessVerifyCmd)
 
 			case cmdNOOP.match(cmd):
-				client.sendResponse(response.Canned.SuccessNoopCmd)
+				client.sendResponse(r.SuccessNoopCmd)
 
 			case cmdQUIT.match(cmd):
-				client.sendResponse(response.Canned.SuccessQuitCmd)
+				client.sendResponse(r.SuccessQuitCmd)
 				client.kill()
 
 			case cmdDATA.match(cmd):
 				if len(client.RcptTo) == 0 {
-					client.sendResponse(response.Canned.FailNoRecipientsDataCmd)
+					client.sendResponse(r.FailNoRecipientsDataCmd)
 					break
 				}
-				client.sendResponse(response.Canned.SuccessDataCmd)
+				client.sendResponse(r.SuccessDataCmd)
 				client.state = ClientData
 
 			case sc.TLS.StartTLSOn && cmdSTARTTLS.match(cmd):
 
-				client.sendResponse(response.Canned.SuccessStartTLSCmd)
+				client.sendResponse(r.SuccessStartTLSCmd)
 				client.state = ClientStartTLS
 			default:
 				client.errors++
 				if client.errors >= MaxUnrecognizedCommands {
-					client.sendResponse(response.Canned.FailMaxUnrecognizedCmd)
+					client.sendResponse(r.FailMaxUnrecognizedCmd)
 					client.kill()
 				} else {
-					client.sendResponse(response.Canned.FailUnrecognizedCmd)
+					client.sendResponse(r.FailUnrecognizedCmd)
 				}
 			}
 
@@ -543,45 +545,45 @@ func (server *server) handleClient(client *client) {
 
 			n, err := client.Data.ReadFrom(client.smtpReader.DotReader())
 			if n > sc.MaxSize {
-				err = fmt.Errorf("Maximum DATA size exceeded (%d)", sc.MaxSize)
+				err = fmt.Errorf("maximum DATA size exceeded (%d)", sc.MaxSize)
 			}
 			if err != nil {
 				if err == LineLimitExceeded {
-					client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, " ", LineLimitExceeded.Error())
+					client.sendResponse(r.FailReadLimitExceededDataCmd, " ", LineLimitExceeded.Error())
 					client.kill()
 				} else if err == MessageSizeExceeded {
-					client.sendResponse(response.Canned.FailMessageSizeExceeded, " ", MessageSizeExceeded.Error())
+					client.sendResponse(r.FailMessageSizeExceeded, " ", MessageSizeExceeded.Error())
 					client.kill()
 				} else {
-					client.sendResponse(response.Canned.FailReadErrorDataCmd, " ", err.Error())
+					client.sendResponse(r.FailReadErrorDataCmd, " ", err.Error())
 					client.kill()
 				}
-				server.log().WithError(err).Warn("Error reading data")
+				s.log().WithError(err).Warn("Error reading data")
 				client.resetTransaction()
 				break
 			}
 
-			res := server.backend().Process(client.Envelope)
+			res := s.backend().Process(client.Envelope)
 			if res.Code() < 300 {
 				client.messagesSent++
 			}
 			client.sendResponse(res)
 			client.state = ClientCmd
-			if server.isShuttingDown() {
+			if s.isShuttingDown() {
 				client.state = ClientShutdown
 			}
 			client.resetTransaction()
 
 		case ClientStartTLS:
 			if !client.TLS && sc.TLS.StartTLSOn {
-				tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
+				tlsConfig, ok := s.tlsConfigStore.Load().(*tls.Config)
 				if !ok {
-					server.mainlog().Error("Failed to load *tls.Config")
+					s.mainlog().Error("Failed to load *tls.Config")
 				} else if err := client.upgradeToTLS(tlsConfig); err == nil {
 					advertiseTLS = ""
 					client.resetTransaction()
 				} else {
-					server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
+					s.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
 					// Don't disconnect, let the client decide if it wants to continue
 				}
 			}
@@ -589,22 +591,22 @@ func (server *server) handleClient(client *client) {
 			client.state = ClientCmd
 		case ClientShutdown:
 			// shutdown state
-			client.sendResponse(response.Canned.ErrorShutdown)
+			client.sendResponse(r.ErrorShutdown)
 			client.kill()
 		}
 
 		if client.bufErr != nil {
-			server.log().WithError(client.bufErr).Debug("client could not buffer a response")
+			s.log().WithError(client.bufErr).Debug("client could not buffer a response")
 			return
 		}
 		// flush the response buffer
 		if client.bufout.Buffered() > 0 {
-			if server.log().IsDebug() {
-				server.log().Debugf("Writing response to client: \n%s", client.response.String())
+			if s.log().IsDebug() {
+				s.log().Debugf("Writing response to client: \n%s", client.response.String())
 			}
-			err := server.flushResponse(client)
+			err := s.flushResponse(client)
 			if err != nil {
-				server.log().WithError(err).Debug("error writing response")
+				s.log().WithError(err).Debug("error writing response")
 				return
 			}
 		}

+ 166 - 46
server_test.go

@@ -1,6 +1,7 @@
 package guerrilla
 
 import (
+	"os"
 	"testing"
 
 	"bufio"
@@ -16,7 +17,6 @@ import (
 	"github.com/flashmob/go-guerrilla/mocks"
 	"io/ioutil"
 	"net"
-	"os"
 )
 
 // getMockServerConfig gets a mock ServerConfig struct used for creating a new server
@@ -139,8 +139,62 @@ jDGZARZqGyrPeXi+RNe1cMvZCxAFy7gqEtWFLWWrp0gYNPvxkHhhQBrUcF+8T/Nf
 ug8tR8eSL1vGleONtFRBUVG7NbtjhBf9FhvPZcSRR10od/vWHku9E01i4xg=
 -----END CERTIFICATE-----`
 
+func truncateIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Truncate(filename, 0)
+	}
+	return nil
+}
+func deleteIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Remove(filename)
+	}
+	return nil
+}
+
+func cleanTestArtifacts(t *testing.T) {
+	if err := deleteIfExists("rootca.test.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("client.test.key"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("client.test.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/mail.guerrillamail.com.key.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/mail.guerrillamail.com.cert.pem"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/different-go-guerrilla.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/go-guerrilla.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/go-guerrilla2.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/pidfile.pid"); err != nil {
+		t.Error(err)
+	}
+	if err := deleteIfExists("./tests/pidfile2.pid"); err != nil {
+		t.Error(err)
+	}
+
+	if err := truncateIfExists("./tests/testlog"); err != nil {
+		t.Error(err)
+	}
+	if err := truncateIfExists("./tests/testlog2"); err != nil {
+		t.Error(err)
+	}
+}
+
 func TestTLSConfig(t *testing.T) {
 
+	defer cleanTestArtifacts(t)
 	if err := ioutil.WriteFile("rootca.test.pem", []byte(rootCAPK), 0644); err != nil {
 		t.Fatal("couldn't create rootca.test.pem file.", err)
 		return
@@ -167,7 +221,9 @@ func TestTLSConfig(t *testing.T) {
 			Protocols:      []string{"tls1.0", "tls1.2"},
 		},
 	})
-	s.configureSSL()
+	if err := s.configureSSL(); err != nil {
+		t.Error(err)
+	}
 
 	c := s.tlsConfigStore.Load().(*tls.Config)
 
@@ -203,15 +259,12 @@ func TestTLSConfig(t *testing.T) {
 		t.Error("PreferServerCipherSuites should be false")
 	}
 
-	os.Remove("rootca.test.pem")
-	os.Remove("client.test.key")
-	os.Remove("client.test.pem")
-
 }
 
 func TestHandleClient(t *testing.T) {
 	var mainlog log.Logger
 	var logOpenError error
+	defer cleanTestArtifacts(t)
 	sc := getMockServerConfig()
 	mainlog, logOpenError = log.GetLogger(sc.LogFile, "debug")
 	if logOpenError != nil {
@@ -231,10 +284,14 @@ func TestHandleClient(t *testing.T) {
 	line, _ := r.ReadLine()
 	//	fmt.Println(line)
 	w := textproto.NewWriter(bufio.NewWriter(conn.Client))
-	w.PrintfLine("HELO test.test.com")
+	if err := w.PrintfLine("HELO test.test.com"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println(line)
-	w.PrintfLine("QUIT")
+	if err := w.PrintfLine("QUIT"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println("line is:", line)
 	expected := "221 2.0.0 Bye"
@@ -247,6 +304,7 @@ func TestHandleClient(t *testing.T) {
 func TestXClient(t *testing.T) {
 	var mainlog log.Logger
 	var logOpenError error
+	defer cleanTestArtifacts(t)
 	sc := getMockServerConfig()
 	sc.XClientOn = true
 	mainlog, logOpenError = log.GetLogger(sc.LogFile, "debug")
@@ -267,10 +325,14 @@ func TestXClient(t *testing.T) {
 	line, _ := r.ReadLine()
 	//	fmt.Println(line)
 	w := textproto.NewWriter(bufio.NewWriter(conn.Client))
-	w.PrintfLine("HELO test.test.com")
+	if err := w.PrintfLine("HELO test.test.com"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	//fmt.Println(line)
-	w.PrintfLine("XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]")
+	if err := w.PrintfLine("XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 
 	if client.RemoteIP != "212.96.64.216" {
@@ -282,7 +344,9 @@ func TestXClient(t *testing.T) {
 	}
 
 	// try malformed input
-	w.PrintfLine("XCLIENT c")
+	if err := w.PrintfLine("XCLIENT c"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 
 	expected = "250 2.1.0 OK"
@@ -290,7 +354,9 @@ func TestXClient(t *testing.T) {
 		t.Error("expected", expected, "but got:", line)
 	}
 
-	w.PrintfLine("QUIT")
+	if err := w.PrintfLine("QUIT"); err != nil {
+		t.Error(err)
+	}
 	line, _ = r.ReadLine()
 	wg.Wait() // wait for handleClient to exit
 }
@@ -299,7 +365,7 @@ func TestXClient(t *testing.T) {
 // The transaction should wait until finished, and then test to see if we can do
 // a second transaction
 func TestGatewayTimeout(t *testing.T) {
-
+	defer cleanTestArtifacts(t)
 	bcfg := backends.BackendConfig{
 		"save_workers_size":   1,
 		"save_process":        "HeadersParser|Debugger",
@@ -330,24 +396,51 @@ func TestGatewayTimeout(t *testing.T) {
 		}
 		in := bufio.NewReader(conn)
 		str, err := in.ReadString('\n')
-		fmt.Fprint(conn, "HELO host\r\n")
+		if err != nil {
+			t.Error(err)
+		}
+		if _, err := fmt.Fprint(conn, "HELO host\r\n"); err != nil {
+			t.Error(err)
+		}
 		str, err = in.ReadString('\n')
 		// perform 2 transactions
 		// both should panic.
 		for i := 0; i < 2; i++ {
-			fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "DATA\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "Subject: Test subject\r\n")
-			fmt.Fprint(conn, "\r\n")
-			fmt.Fprint(conn, "A an email body\r\n")
-			fmt.Fprint(conn, ".\r\n")
+			if _, err := fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "DATA\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "Subject: Test subject\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "A an email body\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, ".\r\n"); err != nil {
+				t.Error(err)
+			}
 			str, err = in.ReadString('\n')
 			expect := "transaction timeout"
-			if strings.Index(str, expect) == -1 {
+			if err != nil {
+				t.Error(err)
+			} else if strings.Index(str, expect) == -1 {
 				t.Error("Expected the reply to have'", expect, "'but got", str)
 			}
 		}
@@ -359,6 +452,7 @@ func TestGatewayTimeout(t *testing.T) {
 
 // The processor will panic and gateway should recover from it
 func TestGatewayPanic(t *testing.T) {
+	defer cleanTestArtifacts(t)
 	bcfg := backends.BackendConfig{
 		"save_workers_size":   1,
 		"save_process":        "HeadersParser|Debugger",
@@ -388,37 +482,66 @@ func TestGatewayPanic(t *testing.T) {
 			return
 		}
 		in := bufio.NewReader(conn)
-		str, err := in.ReadString('\n')
-		fmt.Fprint(conn, "HELO host\r\n")
-		str, err = in.ReadString('\n')
+		if _, err := in.ReadString('\n'); err != nil {
+			t.Error(err)
+		}
+		if _, err := fmt.Fprint(conn, "HELO host\r\n"); err != nil {
+			t.Error(err)
+		}
+		if _, err = in.ReadString('\n'); err != nil {
+			t.Error(err)
+		}
 		// perform 2 transactions
 		// both should timeout. The reason why 2 is because we want to make
 		// sure that the client waits until processing finishes, and the
 		// timeout event is captured.
 		for i := 0; i < 2; i++ {
-			fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "DATA\r\n")
-			str, err = in.ReadString('\n')
-			fmt.Fprint(conn, "Subject: Test subject\r\n")
-			fmt.Fprint(conn, "\r\n")
-			fmt.Fprint(conn, "A an email body\r\n")
-			fmt.Fprint(conn, ".\r\n")
-			str, err = in.ReadString('\n')
-			expect := "storage failed"
-			if strings.Index(str, expect) == -1 {
-				t.Error("Expected the reply to have'", expect, "'but got", str)
+			if _, err := fmt.Fprint(conn, "MAIL FROM:<[email protected]>r\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "RCPT TO:<[email protected]>\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "DATA\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err = in.ReadString('\n'); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "Subject: Test subject\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, "A an email body\r\n"); err != nil {
+				t.Error(err)
+			}
+			if _, err := fmt.Fprint(conn, ".\r\n"); err != nil {
+				t.Error(err)
+			}
+			if str, err := in.ReadString('\n'); err != nil {
+				t.Error(err)
+			} else {
+				expect := "storage failed"
+				if strings.Index(str, expect) == -1 {
+					t.Error("Expected the reply to have'", expect, "'but got", str)
+				}
 			}
 		}
-		_ = str
 		d.Shutdown()
 	}
 
 }
 
 func TestAllowsHosts(t *testing.T) {
+	defer cleanTestArtifacts(t)
 	s := server{}
 	allowedHosts := []string{
 		"spam4.me",
@@ -464,6 +587,3 @@ func TestAllowsHosts(t *testing.T) {
 	s.setAllowedHosts([]string{"grr.la", "example.com"})
 
 }
-
-// TODO
-// - test github issue #44 and #42

+ 3 - 1
tests/client.go

@@ -32,7 +32,9 @@ func Connect(serverConfig guerrilla.ServerConfig, deadline time.Duration) (net.C
 	bufin = bufio.NewReader(conn)
 
 	// should be ample time to complete the test
-	conn.SetDeadline(time.Now().Add(time.Duration(time.Second * deadline)))
+	if err = conn.SetDeadline(time.Now().Add(time.Duration(time.Second * deadline))); err != nil {
+		return conn, bufin, err
+	}
 	// read greeting, ignore it
 	_, err = bufin.ReadString('\n')
 	return conn, bufin, err

+ 94 - 61
tests/guerrilla_test.go

@@ -62,10 +62,13 @@ func init() {
 	if err := json.Unmarshal([]byte(configJson), config); err != nil {
 		initErr = errors.New("Could not Unmarshal config," + err.Error())
 	} else {
-		setupCerts(config)
 		logger, _ = log.GetLogger(config.LogFile, "debug")
+		initErr = setupCerts(config)
+		if initErr != nil {
+			return
+		}
 		backend, _ := getBackend(config.BackendConfig, logger)
-		app, _ = guerrilla.New(&config.AppConfig, backend, logger)
+		app, initErr = guerrilla.New(&config.AppConfig, backend, logger)
 	}
 
 }
@@ -91,8 +94,8 @@ var configJson = `
             "max_clients": 2,
             "log_file" : "",
 			"tls" : {
-				"private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
-            	"public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+				"private_key_file":"/this/will/be/ignored/guerrillamail.com.key.pem",
+            	"public_key_file":"/this/will/be/ignored/guerrillamail.com.crt",
 				"start_tls_on":true,
             	"tls_always_on":false
 			}
@@ -107,8 +110,8 @@ var configJson = `
             "max_clients":1,
             "log_file" : "",
 			"tls" : {
-				"private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
-            	"public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+				"private_key_file":"/this/will/be/ignored/guerrillamail.com.key.pem",
+            	"public_key_file":"/this/will/be/ignored/guerrillamail.com.crt",
 				"start_tls_on":false,
             	"tls_always_on":true
 			}
@@ -126,12 +129,56 @@ func getBackend(backendConfig map[string]interface{}, l log.Logger) (backends.Ba
 	return b, err
 }
 
-func setupCerts(c *TestConfig) {
+func setupCerts(c *TestConfig) error {
 	for i := range c.Servers {
-		testcert.GenerateCert(c.Servers[i].Hostname, "", 365*24*time.Hour, false, 2048, "P256", "./")
+		err := testcert.GenerateCert(c.Servers[i].Hostname, "", 365*24*time.Hour, false, 2048, "P256", "./")
+		if err != nil {
+			return err
+		}
 		c.Servers[i].TLS.PrivateKeyFile = c.Servers[i].Hostname + ".key.pem"
 		c.Servers[i].TLS.PublicKeyFile = c.Servers[i].Hostname + ".cert.pem"
 	}
+	return nil
+}
+func truncateIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Truncate(filename, 0)
+	}
+	return nil
+}
+func deleteIfExists(filename string) error {
+	if _, err := os.Stat(filename); !os.IsNotExist(err) {
+		return os.Remove(filename)
+	}
+	return nil
+}
+func cleanTestArtifacts(t *testing.T) {
+
+	if err := truncateIfExists("./testlog"); err != nil {
+		t.Error("could not truncate ./testlog:", err)
+	}
+
+	letters := []byte{'A', 'B', 'C', 'D', 'E'}
+	for _, l := range letters {
+		if err := deleteIfExists("configJson" + string(l) + ".json"); err != nil {
+			t.Error("could not delete configJson"+string(l)+".json:", err)
+		}
+	}
+
+	if err := deleteIfExists("./go-guerrilla.pid"); err != nil {
+		t.Error("could not delete ./go-guerrilla.pid", err)
+	}
+	if err := deleteIfExists("./go-guerrilla2.pid"); err != nil {
+		t.Error("could not delete ./go-guerrilla2.pid", err)
+	}
+
+	if err := deleteIfExists("./mail.guerrillamail.com.cert.pem"); err != nil {
+		t.Error("could not delete ./mail.guerrillamail.com.cert.pem", err)
+	}
+	if err := deleteIfExists("./mail.guerrillamail.com.key.pem"); err != nil {
+		t.Error("could not delete ./mail.guerrillamail.com.key.pem", err)
+	}
+
 }
 
 // Testing start and stop of server
@@ -140,6 +187,7 @@ func TestStart(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors != nil {
 		t.Error(startErrors)
 		t.FailNow()
@@ -184,18 +232,16 @@ func TestStart(t *testing.T) {
 		}
 
 	}
-	// don't forget to reset
 
-	os.Truncate("./testlog", 0)
 }
 
 // Simple smoke-test to see if the server can listen & issues a greeting on connect
 func TestGreeting(t *testing.T) {
-	//log.SetOutput(os.Stdout)
 	if initErr != nil {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		// 1. plaintext connection
 		conn, err := net.Dial("tcp", config.Servers[0].ListenInterface)
@@ -203,9 +249,10 @@ func TestGreeting(t *testing.T) {
 			// handle error
 			t.Error("Cannot dial server", config.Servers[0].ListenInterface)
 		}
-		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		if err := conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500))); err != nil {
+			t.Error(err)
+		}
 		greeting, err := bufio.NewReader(conn).ReadString('\n')
-		//fmt.Println(greeting)
 		if err != nil {
 			t.Error(err)
 			t.FailNow()
@@ -215,7 +262,7 @@ func TestGreeting(t *testing.T) {
 				t.Error("Server[1] did not have the expected greeting prefix", expected)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 
 		// 2. tls connection
 		//	roots, err := x509.SystemCertPool()
@@ -229,9 +276,10 @@ func TestGreeting(t *testing.T) {
 			t.Error(err, "Cannot dial server (TLS)", config.Servers[1].ListenInterface)
 			t.FailNow()
 		}
-		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		if err := conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500))); err != nil {
+			t.Error(err)
+		}
 		greeting, err = bufio.NewReader(conn).ReadString('\n')
-		//fmt.Println(greeting)
 		if err != nil {
 			t.Error(err)
 			t.FailNow()
@@ -241,7 +289,7 @@ func TestGreeting(t *testing.T) {
 				t.Error("Server[2] (TLS) did not have the expected greeting prefix", expected)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 
 	} else {
 		fmt.Println("Nope", startErrors)
@@ -253,13 +301,10 @@ func TestGreeting(t *testing.T) {
 	app.Shutdown()
 	if read, err := ioutil.ReadFile("./testlog"); err == nil {
 		logOutput := string(read)
-		//fmt.Println(logOutput)
 		if i := strings.Index(logOutput, "Handle client [127.0.0.1"); i < 0 {
 			t.Error("Server did not handle any clients")
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 
 }
 
@@ -272,6 +317,7 @@ func TestShutDown(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -300,7 +346,7 @@ func TestShutDown(t *testing.T) {
 			time.Sleep(time.Millisecond * 250) // let server to close
 		}
 
-		conn.Close()
+		_ = conn.Close()
 
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -317,8 +363,6 @@ func TestShutDown(t *testing.T) {
 			t.Error("Server did not handle any clients")
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 
 }
 
@@ -328,6 +372,7 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -357,7 +402,7 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 			}
 		}
 
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 
 	} else {
@@ -368,8 +413,6 @@ func TestRFC2821LimitRecipients(t *testing.T) {
 		}
 	}
 
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // RCPT TO & MAIL FROM with 64 chars in local part, it should fail at 65
@@ -378,7 +421,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -411,7 +454,7 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 			}
 		}
 
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 
 	} else {
@@ -422,8 +465,6 @@ func TestRFC2832LimitLocalPart(t *testing.T) {
 		}
 	}
 
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 //RFC2821LimitPath fail if path > 256 but different error if below
@@ -464,7 +505,7 @@ func TestRFC2821LimitPath(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -473,8 +514,6 @@ func TestRFC2821LimitPath(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // RFC2821LimitDomain 501 Domain cannot exceed 255 characters
@@ -483,6 +522,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -514,7 +554,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -523,8 +563,7 @@ func TestRFC2821LimitDomain(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 // Test several different inputs to MAIL FROM command
@@ -533,6 +572,7 @@ func TestMailFromCmd(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -767,7 +807,7 @@ func TestMailFromCmd(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -785,6 +825,7 @@ func TestHeloEhlo(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		hostname := config.Servers[0].Hostname
@@ -841,7 +882,7 @@ func TestHeloEhlo(t *testing.T) {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -859,6 +900,7 @@ func TestNestedMailCmd(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -915,7 +957,7 @@ func TestNestedMailCmd(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -924,8 +966,6 @@ func TestNestedMailCmd(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // It should error on a very long command line, exceeding CommandLineMaxLength 1024
@@ -934,7 +974,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -958,7 +998,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -967,8 +1007,7 @@ func TestCommandLineMaxLength(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 // It should error on a very long message, exceeding servers config value
@@ -977,7 +1016,7 @@ func TestDataMaxLength(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -1012,13 +1051,13 @@ func TestDataMaxLength(t *testing.T) {
 					strings.Repeat("n", int(config.Servers[0].MaxSize-20))))
 
 			//expected := "500 Line too long"
-			expected := "451 4.3.0 Error: Maximum DATA size exceeded"
+			expected := "451 4.3.0 Error: maximum DATA size exceeded"
 			if strings.Index(response, expected) != 0 {
 				t.Error("Server did not respond with", expected, ", it said:"+response)
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1027,8 +1066,7 @@ func TestDataMaxLength(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
+
 }
 
 func TestDataCommand(t *testing.T) {
@@ -1036,7 +1074,7 @@ func TestDataCommand(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	testHeader :=
 		"Subject: =?Shift_JIS?B?W4NYg06DRYNGg0GBRYNHg2qDYoNOg1ggg0GDSoNFg5ODZ12DQYNKg0WDk4Nn?=\r\n" +
 			"\t=?Shift_JIS?B?k2+YXoqul7mCzIKokm2C54K5?=\r\n"
@@ -1114,7 +1152,7 @@ func TestDataCommand(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1123,8 +1161,6 @@ func TestDataCommand(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // Fuzzer crashed the server by submitting "DATA\r\n" as the first command
@@ -1133,7 +1169,7 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
-
+	defer cleanTestArtifacts(t)
 	if startErrors := app.Start(); startErrors == nil {
 		conn, bufin, err := Connect(config.Servers[0], 20)
 		if err != nil {
@@ -1152,7 +1188,7 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1161,8 +1197,6 @@ func TestFuzz86f25b86b09897aed8f6c2aa5b5ee1557358a6de(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }
 
 // Appears to hang the fuzz test, but not server.
@@ -1171,6 +1205,7 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 		t.Error(initErr)
 		t.FailNow()
 	}
+	defer cleanTestArtifacts(t)
 	str := "X_\r\nMAIL FROM:<u\xfd\xfdrU" +
 		"\x10c22695140\xfd727235530" +
 		" Walter Sobchak\x1a\tDon" +
@@ -1246,7 +1281,7 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 			}
 
 		}
-		conn.Close()
+		_ = conn.Close()
 		app.Shutdown()
 	} else {
 		if startErrors := app.Start(); startErrors != nil {
@@ -1255,6 +1290,4 @@ func TestFuzz21c56f89989d19c3bbbd81b288b2dae9e6dd2150(t *testing.T) {
 			t.FailNow()
 		}
 	}
-	// don't forget to reset
-	os.Truncate("./testlog", 0)
 }

+ 35 - 17
tests/testcert/generate_cert.go

@@ -13,6 +13,7 @@ import (
 	"crypto/x509/pkix"
 	"encoding/pem"
 
+	"errors"
 	"fmt"
 	"log"
 	"math/big"
@@ -44,32 +45,30 @@ func publicKey(priv interface{}) interface{} {
 	}
 }
 
-func pemBlockForKey(priv interface{}) *pem.Block {
+func pemBlockForKey(priv interface{}) (*pem.Block, error) {
 	switch k := priv.(type) {
 	case *rsa.PrivateKey:
-		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
+		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
 	case *ecdsa.PrivateKey:
 		b, err := x509.MarshalECPrivateKey(k)
 		if err != nil {
-			fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
-			os.Exit(2)
+		err = fmt.Errorf("unable to marshal ECDSA private key: %v", err)
 		}
-		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
+		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, err
 	default:
-		return nil
+		return nil, errors.New("not a private key")
 	}
 }
 
 // validFrom - Creation date formatted as Jan 1 15:04:05 2011 or ""
 
-func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bool, rsaBits int, ecdsaCurve string, dirPrefix string) {
+func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bool, rsaBits int, ecdsaCurve string, dirPrefix string) (err error) {
 
 	if len(host) == 0 {
 		log.Fatalf("Missing required --host parameter")
 	}
 
 	var priv interface{}
-	var err error
 	switch ecdsaCurve {
 	case "":
 		priv, err = rsa.GenerateKey(rand.Reader, rsaBits)
@@ -82,11 +81,11 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	case "P521":
 		priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
 	default:
-		fmt.Fprintf(os.Stderr, "Unrecognized elliptic curve: %q", ecdsaCurve)
-		os.Exit(1)
+		err = fmt.Errorf("unrecognized elliptic curve: %q", ecdsaCurve)
 	}
 	if err != nil {
 		log.Fatalf("failed to generate private key: %s", err)
+		return
 	}
 
 	var notBefore time.Time
@@ -95,8 +94,8 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	} else {
 		notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom)
 		if err != nil {
-			fmt.Fprintf(os.Stderr, "Failed to parse creation date: %s\n", err)
-			os.Exit(1)
+			err = fmt.Errorf("failed to parse creation date: %s", err)
+			return
 		}
 	}
 
@@ -144,16 +143,35 @@ func GenerateCert(host string, validFrom string, validFor time.Duration, isCA bo
 	if err != nil {
 		log.Fatalf("failed to open cert.pem for writing: %s", err)
 	}
-	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
-	certOut.Close()
+	err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	if err != nil {
+		return
+	}
+	if err = certOut.Sync(); err != nil {
+		return
+	}
+	if err = certOut.Close(); err != nil {
+		return
+	}
 
 	keyOut, err := os.OpenFile(dirPrefix+host+".key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
 	if err != nil {
 		log.Print("failed to open key.pem for writing:", err)
 		return
 	}
-	pem.Encode(keyOut, pemBlockForKey(priv))
-	keyOut.Sync()
-	keyOut.Close()
+	var block *pem.Block
+	if block, err = pemBlockForKey(priv); err != nil {
+		return err
+	}
+	if err = pem.Encode(keyOut, block); err != nil {
+		return err
+	}
+	if err = keyOut.Sync(); err != nil {
+		return err
+	}
+	if err = keyOut.Close(); err != nil {
+		return err
+	}
+	return
 
 }

Some files were not shown because too many files changed in this diff.