
Client pool (#31)

* added a basic pool + a shutdown function for the pool & a new API function for shutting down a server.

* - added a Shutdown method to the Backend interface
- the redis/mysql backend uses a wait group to shut down
- moved the backend struct out of Config; it is now passed into the guerrilla.New constructor (this may change later)
- deleted the old server.sem channel, as it has been replaced by the pool

* - close mysql & redis when exiting a saveMail worker

* - add Shutdown() to readme

* expand on the travis ci config

* fix lazy connection of the redis client

* add the shutdown state & switch to it when a shutdown occurs while in the command state.

* travis gofmt - ignore /vendor/*

* * add ClientShutdown state to the enum
* handle shutdown in data state

* change clientID to uint64

* - API's Start() function now returns a list of errors (see the usage sketch after this list)
- API's Start() function will block until all servers have started (so that it can return errors)
- API's Shutdown() will also shut down the backend
- []*ServerConfig changed to []ServerConfig
- Automated testing: Started working on automated testing framework. Added two simple tests that test startup, shutdown and greeting
- waiting for clients to exit is now managed by the pool
- pool tweaks so it's compatible with the API's Start() and Shutdown()
- improved the server's method for closing the socket listener

* - add tests command to travis

* add build status to readme

* change pool to work on a "Poolable" interface, making it reusable for other things in the future (see the sketch after the pool.go diff below)

* sighandler - added SIGINT (Ctrl+C), catch SIGKILL too

* tweak the shutdown sequence

* fix automated tests & add more tests
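
Taken together, the API bullets above describe the new lifecycle of the package. Below is a minimal usage sketch assembled from the README and test changes in this commit (the config values are placeholders, and the sketch is illustrative rather than code shipped by the commit):

```go
package main

import (
	"log"

	"github.com/flashmob/go-guerrilla"
	"github.com/flashmob/go-guerrilla/backends"
)

func main() {
	config := &guerrilla.AppConfig{
		Servers:      []guerrilla.ServerConfig{ /* ... */ },
		AllowedHosts: []string{"example.com"},
	}
	// the backend now travels as a constructor argument instead of an AppConfig field
	b := &backends.DummyBackend{}
	b.Initialize(map[string]interface{}{"log_received_mails": true})
	backend := guerrilla.Backend(b)
	app := guerrilla.New(config, &backend)

	// Start blocks until every enabled server is listening and
	// returns the errors of any servers that failed to start
	if startErrors := app.Start(); len(startErrors) > 0 {
		for _, err := range startErrors {
			log.Println(err)
		}
		return
	}

	// ... run until it is time to stop ...

	// Shutdown closes the listeners, waits for connected clients via the
	// pool, then shuts down the backend; it blocks until all of that is done
	app.Shutdown()
}
```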
Guerrilla Mail committed 8 years ago · commit 8f8f7ff648
15 changed files with 937 additions and 86 deletions:
  1. .travis.gofmt.sh (+7 −0)
  2. .travis.yml (+8 −1)
  3. README.md (+13 −3)
  4. backends/dummy.go (+4 −0)
  5. backends/guerrilla_db_redis.go (+26 −8)
  6. client.go (+54 −11)
  7. cmd/guerrillad/serve.go (+16 −7)
  8. config.go (+2 −3)
  9. guerrilla.go (+48 −11)
  10. models.go (+7 −0)
  11. pool.go (+163 −0)
  12. server.go (+90 −42)
  13. tests/.gitignore (+1 −0)
  14. tests/generate_cert.go (+159 −0)
  15. tests/guerrilla_test.go (+339 −0)

+ 7 - 0
.travis.gofmt.sh

@@ -0,0 +1,7 @@
+#!/bin/bash
+
+if [[ -n $(find . -path '*/vendor/*' -prune -o -name '*.go' -type f -exec gofmt -l {} \;) ]]; then
+    echo "Go code is not formatted:"
+    gofmt -d .
+    exit 1
+fi

+ 8 - 1
.travis.yml

@@ -1,11 +1,18 @@
 language: go
-
+sudo: false
 go:
   - 1.5
+  - 1.6
+  - 1.7
+  - master
 
 install:
   - export GO15VENDOREXPERIMENT=1
   - go get github.com/Masterminds/glide
   - go install github.com/Masterminds/glide
   - glide up
+
+script:
+  - ./.travis.gofmt.sh
   - make guerrillad
+  - go test ./tests

+ 13 - 3
README.md

@@ -1,4 +1,6 @@
 
+[![Build Status](https://travis-ci.org/flashmob/go-guerrilla.svg?branch=master)](https://travis-ci.org/flashmob/go-guerrilla)
+
 Go-Guerrilla SMTPd
 ====================
 
@@ -172,19 +174,27 @@ func (cb *CustomBackend) Process(c *guerrilla.Envelope) guerrilla.BackendResult
 See Configuration section below for setting configuration options.
 ```go
 config := &guerrilla.AppConfig{
-  Backend: &CustomBackend{...},
   Servers: []*guerrilla.ServerConfig{...},
   AllowedHosts: []string{...}
 }
-app := guerrilla.New(config)
+backend := &CustomBackend{...}
+app := guerrilla.New(config, backend)
 ```
 
 ## Start the app.
 `Start` is non-blocking, so make sure the main goroutine is kept busy
 ```go
-app.Start()
+app.Start() (startErrors []error)
 ```
 
+## Shutting down.
+`Shutdown` will do a graceful shutdown: close all the connections, close
+the ports, and gracefully shut down the backend. It will block until all
+operations are complete.
+
+```go
+app.Shutdown()
+```
 
 Configuration
 ============================================

+ 4 - 0
backends/dummy.go

@@ -27,6 +27,10 @@ func (b *DummyBackend) Initialize(config map[string]interface{}) {
 	b.loadConfig(config)
 }
 
+func (b *DummyBackend) Shutdown() error {
+	return nil
+}
+
 func (b *DummyBackend) Process(mail *guerrilla.Envelope) guerrilla.BackendResult {
 	if b.config.LogReceivedMails {
 		log.Infof("Mail from: %s / to: %v", mail.MailFrom.String(), mail.RcptTo)

+ 26 - 8
backends/guerrilla_db_redis.go

@@ -75,8 +75,8 @@ func (g *GuerrillaDBAndRedisBackend) Initialize(backendConfig map[string]interfa
 	return nil
 }
 
-func (g *GuerrillaDBAndRedisBackend) Finalize() error {
-	close(g.saveMailChan)
+func (g *GuerrillaDBAndRedisBackend) Shutdown() error {
+	close(g.saveMailChan) // workers will stop
 	g.wg.Wait()
 	return nil
 }
@@ -120,9 +120,9 @@ type saveStatus struct {
 }
 
 type redisClient struct {
-	count int
-	conn  redis.Conn
-	time  int
+	isConnected bool
+	conn        redis.Conn
+	time        int
 }
 
 func (g *GuerrillaDBAndRedisBackend) saveMail() {
@@ -131,6 +131,7 @@ func (g *GuerrillaDBAndRedisBackend) saveMail() {
 
 	var redisErr error
 	var length int
+
 	redisClient := &redisClient{}
 	db := autorc.New(
 		"tcp",
@@ -152,13 +153,27 @@ func (g *GuerrillaDBAndRedisBackend) saveMail() {
 	if sqlErr != nil {
 		log.WithError(sqlErr).Fatalf("failed while db.Prepare(UPDATE...)")
 	}
+	defer func() {
+		if r := recover(); r != nil {
+			// recover from a closed channel
+			fmt.Println("Recovered in saveMail", r)
+		}
+		if db.Raw != nil {
+			db.Raw.Close()
+		}
+		if redisClient.conn != nil {
+			log.Infof("closed redis")
+			redisClient.conn.Close()
+		}
+
+		g.wg.Done()
+	}()
 
 	//  receives values from the channel repeatedly until it is closed.
 	for {
 		payload := <-g.saveMailChan
 		if payload == nil {
-			log.Debug("No more payload")
-			g.wg.Done()
+			log.Debug("No more saveMailChan payload")
 			return
 		}
 		to = payload.recipient.User + "@" + g.config.PrimaryHost
@@ -220,13 +235,16 @@ func (g *GuerrillaDBAndRedisBackend) saveMail() {
 }
 
 func (c *redisClient) redisConnection(redisInterface string) (err error) {
-	if c.count == 0 {
+
+	if c.isConnected == false {
 		c.conn, err = redis.Dial("tcp", redisInterface)
 		if err != nil {
 			// handle error
 			return err
 		}
+		c.isConnected = true
 	}
+
 	return nil
 }
 

+ 54 - 11
client.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"net"
 	"strings"
+	"sync"
 	"time"
 )
 
@@ -19,11 +20,13 @@ const (
 	ClientData
 	// We have agreed with the client to secure the connection over TLS
 	ClientStartTLS
+	// Server will shutdown, client to shutdown on next command turn
+	ClientShutdown
 )
 
 type client struct {
 	*Envelope
-	ID          int64
+	ID          uint64
 	ConnectedAt time.Time
 	KilledAt    time.Time
 	// Number of errors encountered during session with this client
@@ -31,10 +34,11 @@ type client struct {
 	state        ClientState
 	messagesSent int
 	// Response to be written to the client
-	response string
-	conn     net.Conn
-	bufin    *smtpBufferedReader
-	bufout   *bufio.Writer
+	response  string
+	conn      net.Conn
+	bufin     *smtpBufferedReader
+	bufout    *bufio.Writer
+	timeoutMu sync.Mutex
 }
 
 // Email represents a single SMTP message.
@@ -42,14 +46,27 @@ type Envelope struct {
 	// Remote IP address
 	RemoteAddress string
 	// Message sent in EHLO command
-	Helo          string
+	Helo string
 	// Sender
-	MailFrom      *EmailAddress
+	MailFrom *EmailAddress
 	// Recipients
-	RcptTo        []*EmailAddress
-	Data          string
-	Subject       string
-	TLS           bool
+	RcptTo  []*EmailAddress
+	Data    string
+	Subject string
+	TLS     bool
+}
+
+func NewClient(conn net.Conn, clientID uint64) *client {
+	return &client{
+		conn: conn,
+		Envelope: &Envelope{
+			RemoteAddress: conn.RemoteAddr().String(),
+		},
+		ConnectedAt: time.Now(),
+		bufin:       newSMTPBufferedReader(conn),
+		bufout:      bufio.NewWriter(conn),
+		ID:          clientID,
+	}
 }
 
 func (c *client) responseAdd(r string) {
@@ -85,3 +102,29 @@ func (c *client) scanSubject(reply string) {
 		}
 	}
 }
+
+func (c *client) setTimeout(t time.Duration) {
+	c.timeoutMu.Lock()
+	defer c.timeoutMu.Unlock()
+	c.conn.SetDeadline(time.Now().Add(t * time.Second))
+}
+
+func (c *client) init(conn net.Conn, clientID uint64) {
+	c.conn = conn
+	// reset our reader & writer
+	c.bufout.Reset(conn)
+	c.bufin.Reset(conn)
+	// reset session data
+	c.state = 0
+	c.KilledAt = time.Time{}
+	c.ConnectedAt = time.Now()
+	c.ID = clientID
+	c.errors = 0
+	c.response = ""
+	// a fresh envelope, so no state leaks between pooled sessions
+	// (this also resets TLS, Helo and the remote address)
+	c.Envelope = &Envelope{RemoteAddress: conn.RemoteAddr().String()}
+}
+
+func (c *client) getID() uint64 {
+	return c.ID
+}

+ 16 - 7
cmd/guerrillad/serve.go

@@ -44,11 +44,12 @@ func init() {
 	rootCmd.AddCommand(serveCmd)
 }
 
-func sigHandler() {
+func sigHandler(app guerrilla.Guerrilla) {
 	// handle SIGHUP for reloading the configuration while running
-	signal.Notify(signalChannel, syscall.SIGHUP)
+	signal.Notify(signalChannel, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGINT, syscall.SIGKILL)
 
 	for sig := range signalChannel {
+
 		if sig == syscall.SIGHUP {
 			err := readConfig(configPath, verbose, &cmdConfig)
 			if err != nil {
@@ -57,6 +58,11 @@ func sigHandler() {
 				log.Infof("Configuration is reloaded at %s", guerrilla.ConfigLoadTime)
 			}
 			// TODO: reinitialize
+		} else if sig == syscall.SIGTERM || sig == syscall.SIGQUIT || sig == syscall.SIGINT {
+			log.Infof("Shutdown signal caught")
+			app.Shutdown()
+			log.Infof("Shutdown completed, exiting.")
+			os.Exit(0)
 		} else {
 			os.Exit(0)
 		}
@@ -98,26 +104,29 @@ func serve(cmd *cobra.Command, args []string) {
 			log.WithError(err).Fatalf("Error while creating pidFile (%s)", pidFile)
 		}
 	}
-
+	var backend guerrilla.Backend
 	switch cmdConfig.BackendName {
 	case "dummy":
 		b := &backends.DummyBackend{}
 		b.Initialize(cmdConfig.BackendConfig)
-		cmdConfig.Backend = b
+		backend = guerrilla.Backend(b)
 	case "guerrilla-db-redis":
 		b := &backends.GuerrillaDBAndRedisBackend{}
 		err = b.Initialize(cmdConfig.BackendConfig)
 		if err != nil {
 			log.WithError(err).Errorf("Initalization of %s backend failed", cmdConfig.BackendName)
 		}
+
+		backend = guerrilla.Backend(b)
 	default:
 		log.Fatalf("Unknown backend: %s", cmdConfig.BackendName)
 	}
 
-	app := guerrilla.New(&cmdConfig.AppConfig)
+	app := guerrilla.New(&cmdConfig.AppConfig, &backend)
 	go app.Start()
-	sigHandler()
+	sigHandler(app)
 }
 
 // Superset of `guerrilla.AppConfig` containing options specific

+ 2 - 3
config.go

@@ -2,9 +2,8 @@ package guerrilla
 
 // AppConfig is the holder of the configuration of the app
 type AppConfig struct {
-	Backend      Backend
-	Servers      []*ServerConfig `json:"servers"`
-	AllowedHosts []string        `json:"allowed_hosts"`
+	Servers      []ServerConfig `json:"servers"`
+	AllowedHosts []string       `json:"allowed_hosts"`
 }
 
 // ServerConfig specifies config options for a single server

+ 48 - 11
guerrilla.go

@@ -1,40 +1,77 @@
 package guerrilla
 
-import log "github.com/Sirupsen/logrus"
+import (
+	"errors"
+	log "github.com/Sirupsen/logrus"
+	"sync"
+)
 
 type Guerrilla interface {
-	Start()
+	Start() (startErrors []error)
+	Shutdown()
 }
 
 type guerrilla struct {
 	Config  *AppConfig
-	servers []*server
+	servers []server
+	backend *Backend
 }
 
 // Returns a new instance of Guerrilla with the given config, not yet running.
-func New(ac *AppConfig) Guerrilla {
-	g := &guerrilla{ac, []*server{}}
-
+func New(ac *AppConfig, b *Backend) Guerrilla {
+	g := &guerrilla{ac, []server{}, b}
 	// Instantiate servers
 	for _, sc := range ac.Servers {
 		if !sc.IsEnabled {
 			continue
 		}
-
 		// Add relevant app-wide config options to each server
 		sc.AllowedHosts = ac.AllowedHosts
-		server, err := newServer(sc, ac.Backend)
+		server, err := newServer(sc, b)
 		if err != nil {
 			log.WithError(err).Error("Failed to create server")
+		} else {
+			g.servers = append(g.servers, *server)
 		}
-		g.servers = append(g.servers, server)
 	}
 	return g
 }
 
 // Entry point for the application. Starts all servers.
-func (g *guerrilla) Start() {
+func (g *guerrilla) Start() (startErrors []error) {
+	if len(g.servers) == 0 {
+		return append(startErrors, errors.New("No servers to start, please check the config"))
+	}
+	// channel for reading errors
+	errs := make(chan error, len(g.servers))
+	var startWG sync.WaitGroup
+	startWG.Add(len(g.servers))
+	// start servers, send any errors back to errs channel
+	for i := 0; i < len(g.servers); i++ {
+		go func(s *server) {
+			if err := s.Start(&startWG); err != nil {
+				errs <- err
+				startWG.Done()
+			}
+		}(&g.servers[i])
+	}
+	// wait for all servers to start
+	startWG.Wait()
+
+	// close, then read any errors
+	close(errs)
+	for err := range errs {
+		if err != nil {
+			startErrors = append(startErrors, err)
+		}
+	}
+	return startErrors
+}
+
+func (g *guerrilla) Shutdown() {
 	for _, s := range g.servers {
-		go s.Start()
+		s.Shutdown()
+		log.Infof("shutdown completed for [%s]", s.config.ListenInterface)
 	}
+	(*g.backend).Shutdown()
+	log.Infof("Backend shutdown completed")
 }

+ 7 - 0
models.go

@@ -20,6 +20,7 @@ var (
 // whether the message was processed successfully.
 type Backend interface {
 	Process(*Envelope) BackendResult
+	Shutdown() error
 }
 
 // BackendResult represents a response to an SMTP client after receiving DATA.
@@ -109,6 +110,12 @@ func (sbr *smtpBufferedReader) setLimit(n int64) {
 	sbr.alr.setLimit(n)
 }
 
+// Set a new reader & use it to reset the underlying reader
+func (sbr *smtpBufferedReader) Reset(r io.Reader) {
+	sbr.alr = newAdjustableLimitedReader(r, CommandLineMaxLength)
+	sbr.Reader.Reset(sbr.alr)
+}
+
 // Allocate a new SMTPBufferedReader
 func newSMTPBufferedReader(rd io.Reader) *smtpBufferedReader {
 	alr := newAdjustableLimitedReader(rd, CommandLineMaxLength)

+ 163 - 0
pool.go

@@ -0,0 +1,163 @@
+package guerrilla
+
+import (
+	"errors"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+var (
+	ErrPoolShuttingDown = errors.New("server pool: shutting down")
+)
+
+// a struct can be pooled if it has the following interface
+type Poolable interface {
+	// ability to set read/write timeout
+	setTimeout(t time.Duration)
+	// set a new connection and client id
+	init(c net.Conn, clientID uint64)
+	// get a unique id
+	getID() uint64
+}
+
+// Pool holds Clients.
+type Pool struct {
+	// clients that are ready to be borrowed
+	pool chan Poolable
+	// semaphore to control number of maximum borrowed clients
+	sem chan bool
+	// book-keeping of clients that have been lent
+	activeClients     lentClients
+	isShuttingDownFlg atomic.Value
+	poolGuard         sync.Mutex
+	ShutdownChan      chan int
+}
+
+type lentClients struct {
+	m  map[uint64]Poolable
+	mu sync.Mutex // guards access to this struct
+	wg sync.WaitGroup
+}
+
+// NewPool creates a new pool of Clients.
+func NewPool(poolSize int) *Pool {
+	return &Pool{
+		pool:          make(chan Poolable, poolSize),
+		sem:           make(chan bool, poolSize),
+		activeClients: lentClients{m: make(map[uint64]Poolable, poolSize)},
+		ShutdownChan:  make(chan int, 1),
+	}
+}
+
+// Start resets the shutting-down flag so that clients may be borrowed
+func (p *Pool) Start() {
+	p.isShuttingDownFlg.Store(false)
+}
+
+// ShutdownState locks the pool so that nothing more can be borrowed, then
+// lowers each active client's timeout to 1 sec so that active clients
+// stop accepting commands and finish up quickly
+func (p *Pool) ShutdownState() {
+	const aVeryLowTimeout = 1
+	p.poolGuard.Lock() // ensure no other goroutine is borrowing right now
+	defer p.poolGuard.Unlock()
+	p.isShuttingDownFlg.Store(true) // no more borrowing
+	p.ShutdownChan <- 1             // unblock any Borrow waiting on p.sem
+
+	// set a low timeout on all lent clients (guard the map while iterating)
+	p.activeClients.mu.Lock()
+	defer p.activeClients.mu.Unlock()
+	for _, c := range p.activeClients.m {
+		c.setTimeout(time.Duration(int64(aVeryLowTimeout)))
+	}
+}
+
+func (p *Pool) ShutdownWait() {
+	p.poolGuard.Lock() // ensure no other goroutine is borrowing right now
+	defer p.poolGuard.Unlock()
+	p.activeClients.wg.Wait() // wait for clients to finish
+	if len(p.ShutdownChan) > 0 {
+		// drain
+		<-p.ShutdownChan
+	}
+	p.isShuttingDownFlg.Store(false)
+}
+
+// returns true if the pool is shutting down
+func (p *Pool) IsShuttingDown() bool {
+	if value, ok := p.isShuttingDownFlg.Load().(bool); ok {
+		return value
+	}
+	return false
+}
+
+// set a timeout for all lent clients
+func (p *Pool) SetTimeout(duration time.Duration) {
+	var client Poolable
+	p.activeClients.mu.Lock()
+	defer p.activeClients.mu.Unlock()
+	for _, client = range p.activeClients.m {
+		client.setTimeout(duration)
+	}
+}
+
+// Gets the number of active clients that are currently
+// out of the pool and busy serving
+func (p *Pool) GetActiveClientsCount() int {
+	return len(p.sem)
+}
+
+// Borrow a Client from the pool. Blocks if the number of busy clients has reached the pool size
+func (p *Pool) Borrow(conn net.Conn, clientID uint64) (Poolable, error) {
+	p.poolGuard.Lock()
+	defer p.poolGuard.Unlock()
+
+	var c Poolable
+	if yes, really := p.isShuttingDownFlg.Load().(bool); yes && really {
+		// pool is shutting down.
+		return c, ErrPoolShuttingDown
+	}
+	select {
+	case p.sem <- true: // block the client from serving until there is room
+		select {
+		case c = <-p.pool:
+			c.init(conn, clientID)
+		default:
+			c = NewClient(conn, clientID)
+		}
+		p.activeClientsAdd(c)
+
+	case <-p.ShutdownChan: // unblock p.sem when shutting down
+		// pool is shutting down.
+		return c, ErrPoolShuttingDown
+	}
+	return c, nil
+}
+
+// Return returns a Client back to the pool.
+func (p *Pool) Return(c Poolable) {
+	select {
+	case p.pool <- c:
+		// placed back in the pool; init() gives it a fresh state on the next Borrow
+	default:
+		// pool is full, discard it... hasta la vista, baby
+	}
+	p.activeClientsRemove(c)
+	<-p.sem // make room for the next serving client
+}
+
+func (p *Pool) activeClientsAdd(c Poolable) {
+	p.activeClients.mu.Lock()
+	p.activeClients.wg.Add(1)
+	p.activeClients.m[c.getID()] = c
+	p.activeClients.mu.Unlock()
+}
+
+func (p *Pool) activeClientsRemove(c Poolable) {
+	p.activeClients.mu.Lock()
+	p.activeClients.wg.Done()
+	delete(p.activeClients.m, c.getID())
+	p.activeClients.mu.Unlock()
+}
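
As referenced in the commit message, the Pool works on the "Poolable" interface. Before the server.go diff below, here is a condensed, self-contained sketch of the borrow/return cycle a caller is expected to drive; `acceptLoop` is a hypothetical stand-in for the real `server.Start()`, with error handling trimmed:

```go
package sketch

import (
	"net"

	"github.com/flashmob/go-guerrilla"
)

// acceptLoop drives the pool: borrow on accept, return when the session ends.
func acceptLoop(listener net.Listener, pool *guerrilla.Pool) {
	var clientID uint64
	for {
		conn, err := listener.Accept()
		if err != nil {
			return // the listener was closed, e.g. by a shutdown
		}
		clientID++
		// Borrow is evaluated in the accept goroutine (as an argument), so
		// borrows happen in connection order; the session runs concurrently.
		go func(p guerrilla.Poolable, borrowErr error) {
			if borrowErr != nil {
				conn.Close() // pool is shutting down, refuse the connection
				return
			}
			// ... serve the SMTP session on p ...
			pool.Return(p) // frees a slot so the next waiting client can be served
		}(pool.Borrow(conn, clientID))
	}
}
```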

+ 90 - 42
server.go

@@ -1,7 +1,6 @@
 package guerrilla
 
 import (
-	"bufio"
 	"crypto/rand"
 	"crypto/tls"
 	"fmt"
@@ -15,6 +14,8 @@ import (
 	"runtime"
 
 	log "github.com/Sirupsen/logrus"
+
+	"sync"
 )
 
 const (
@@ -26,29 +27,32 @@ const (
 
 // Server listens for SMTP clients on the port specified in its config
 type server struct {
-	config    *ServerConfig
-	backend   Backend
-	tlsConfig *tls.Config
-	maxSize   int64
-	timeout   time.Duration
-	sem       chan int
+	config         *ServerConfig
+	backend        Backend
+	tlsConfig      *tls.Config
+	maxSize        int64
+	timeout        time.Duration
+	clientPool     *Pool
+	wg             sync.WaitGroup // for waiting to shutdown
+	listener       net.Listener
+	closedListener chan (bool)
 }
 
 // Creates and returns a new ready-to-run Server from a configuration
-func newServer(sc *ServerConfig, b Backend) (*server, error) {
+func newServer(sc ServerConfig, b *Backend) (*server, error) {
 	server := &server{
-		config:  sc,
-		backend: b,
-		maxSize: sc.MaxSize,
-		sem:     make(chan int, sc.MaxClients),
+		config:         &sc,
+		backend:        *b,
+		maxSize:        sc.MaxSize,
+		timeout:        time.Duration(sc.Timeout),
+		clientPool:     NewPool(sc.MaxClients),
+		closedListener: make(chan (bool), 1),
 	}
-
 	if server.config.TLSAlwaysOn || server.config.StartTLSOn {
 		cert, err := tls.LoadX509KeyPair(server.config.PublicKeyFile, server.config.PrivateKeyFile)
 		if err != nil {
 			return nil, fmt.Errorf("Error loading TLS certificate: %s", err.Error())
 		}
-
 		server.tlsConfig = &tls.Config{
 			Certificates: []tls.Certificate{cert},
 			ClientAuth:   tls.VerifyClientCertIfGiven,
@@ -56,44 +60,74 @@ func newServer(sc *ServerConfig, b Backend) (*server, error) {
 			Rand:         rand.Reader,
 		}
 	}
-
-	server.timeout = time.Duration(server.config.Timeout) * time.Second
-
 	return server, nil
 }
 
 // Begin accepting SMTP clients
-func (server *server) Start() error {
+func (server *server) Start(startWG *sync.WaitGroup) error {
+	var clientID uint64
+	clientID = 0
+
 	listener, err := net.Listen("tcp", server.config.ListenInterface)
+	server.listener = listener
 	if err != nil {
-		return fmt.Errorf("Cannot listen on port: %s", err.Error())
+		return fmt.Errorf("[%s] Cannot listen on port: %s ", server.config.ListenInterface, err.Error())
 	}
 
 	log.Infof("Listening on TCP %s", server.config.ListenInterface)
-	var clientID int64
-	clientID = 1
+	startWG.Done() // start successful
+
 	for {
-		log.Debugf("Waiting for a new client. Client ID: %d", clientID)
+		log.Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.config.ListenInterface, clientID+1)
 		conn, err := listener.Accept()
+		clientID++
 		if err != nil {
-			log.WithError(err).Info("Error accepting client")
+			if e, ok := err.(net.Error); ok && !e.Temporary() {
+				log.Infof("Server [%s] has stopped accepting new clients", server.config.ListenInterface)
+				// the listener has been closed, wait for clients to exit
+				log.Infof("shutting down pool [%s]", server.config.ListenInterface)
+				server.clientPool.ShutdownWait()
+				server.closedListener <- true
+				return nil
+			}
+			log.WithError(err).Info("Temporary error accepting client")
 			continue
 		}
-		server.sem <- 1
-		go server.handleClient(&client{
-			Envelope: &Envelope{
-				RemoteAddress: conn.RemoteAddr().String(),
-			},
-			conn:        conn,
-			ConnectedAt: time.Now(),
-			bufin:       newSMTPBufferedReader(conn),
-			bufout:      bufio.NewWriter(conn),
-			ID:          clientID,
-		})
-		clientID++
+		go func(p Poolable, borrow_err error) {
+			if borrow_err != nil {
+				log.WithError(borrow_err).Info("couldn't borrow a new client")
+				// we could not get a client, so close the connection.
+				conn.Close()
+				return
+			}
+			// only assert after the error check: on error the interface is nil
+			c := p.(*client)
+			server.handleClient(c)
+			server.clientPool.Return(c)
+			// intentionally placed Borrow in args so that it's called in the
+			// same main goroutine.
+		}(server.clientPool.Borrow(conn, clientID))
+
 	}
 }
 
+func (server *server) Shutdown() {
+	server.clientPool.ShutdownState()
+	if server.listener != nil {
+		server.listener.Close()
+		// wait for the listener to close.
+		<-server.closedListener
+		// At this point Start will exit and close down the pool
+	} else {
+		// listener already closed, wait for clients to exit
+		server.clientPool.ShutdownWait()
+	}
+}
+
+func (server *server) GetActiveClientsCount() int {
+	return server.clientPool.GetActiveClientsCount()
+}
+
 // Verifies that the host is a valid recipient.
 func (server *server) allowsHost(host string) bool {
 	for _, allowed := range server.config.AllowedHosts {
@@ -113,8 +147,8 @@ func (server *server) upgradeToTLS(client *client) bool {
 		return false
 	}
 	client.conn = net.Conn(tlsConn)
-	client.bufin = newSMTPBufferedReader(client.conn)
-	client.bufout = bufio.NewWriter(client.conn)
+	client.bufout.Reset(client.conn)
+	client.bufin.Reset(client.conn)
 	client.TLS = true
 
 	return true
@@ -124,7 +158,6 @@ func (server *server) upgradeToTLS(client *client) bool {
 func (server *server) closeConn(client *client) {
 	client.conn.Close()
 	client.conn = nil
-	<-server.sem
 }
 
 // Reads from the client until a terminating sequence is encountered,
@@ -141,7 +174,7 @@ func (server *server) read(client *client) (string, error) {
 	}
 
 	for {
-		client.conn.SetDeadline(time.Now().Add(server.timeout))
+		client.setTimeout(server.timeout)
 		reply, err = client.bufin.ReadString('\n')
 		input = input + reply
 		if client.state == ClientData && reply != "" {
@@ -163,7 +196,7 @@ func (server *server) read(client *client) (string, error) {
 
 // Writes a response to the client.
 func (server *server) writeResponse(client *client) error {
-	client.conn.SetDeadline(time.Now().Add(server.timeout))
+	client.setTimeout(server.timeout)
 	size, err := client.bufout.WriteString(client.response)
 	if err != nil {
 		return err
@@ -176,15 +209,20 @@ func (server *server) writeResponse(client *client) error {
 	return nil
 }
 
+func (server *server) isShuttingDown() bool {
+	return server.clientPool.IsShuttingDown()
+}
+
 // Handles an entire client SMTP exchange
 func (server *server) handleClient(client *client) {
 	defer server.closeConn(client)
+
 	log.Infof("Handle client [%s], id: %d", client.RemoteAddress, client.ID)
 
 	// Initial greeting
 	greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s gr:%d",
 		server.config.Hostname, Version, client.ID,
-		len(server.sem), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
+		server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339), runtime.NumGoroutine())
 
 	helo := fmt.Sprintf("250 %s Hello", server.config.Hostname)
 	ehlo := fmt.Sprintf("250-%s Hello", server.config.Hostname)
@@ -212,7 +250,6 @@ func (server *server) handleClient(client *client) {
 		case ClientGreeting:
 			client.responseAdd(greeting)
 			client.state = ClientCmd
-
 		case ClientCmd:
 			client.bufin.setLimit(CommandLineMaxLength)
 			input, err := server.read(client)
@@ -232,6 +269,10 @@ func (server *server) handleClient(client *client) {
 				client.kill()
 				break
 			}
+			if server.isShuttingDown() {
+				client.state = ClientShutdown
+				continue
+			}
 
 			input = strings.Trim(input, " \r\n")
 			cmdLen := len(input)
@@ -322,6 +363,9 @@ func (server *server) handleClient(client *client) {
 				continue
 			}
 			client.state = ClientCmd
+			if server.isShuttingDown() {
+				client.state = ClientShutdown
+			}
 
 			if client.MailFrom.isEmpty() {
 				client.responseAdd("550 Error: No sender")
@@ -351,6 +395,10 @@ func (server *server) handleClient(client *client) {
 			}
 			// change to command state
 			client.state = ClientCmd
+		case ClientShutdown:
+			// shutdown state
+			client.responseAdd("421 Server is shutting down. Please try again later. Sayonara!")
+			client.kill()
 		}
 
 		if len(client.response) > 0 {

+ 1 - 0
tests/.gitignore

@@ -0,0 +1 @@
+*.pem

+ 159 - 0
tests/generate_cert.go

@@ -0,0 +1,159 @@
+// adapted from https://golang.org/src/crypto/tls/generate_cert.go?m=text
+
+// Generate a self-signed X.509 certificate for a TLS server. Outputs to
+// 'cert.pem' and 'key.pem' and will overwrite existing files.
+
+package test
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+
+	"fmt"
+	"log"
+	"math/big"
+	"net"
+	"os"
+	"strings"
+	"time"
+)
+
+/*
+var (
+	host       = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for")
+	validFrom  = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011")
+	validFor   = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for")
+	isCA       = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority")
+	rsaBits    = flag.Int("rsa-bits", 2048, "Size of RSA key to generate. Ignored if --ecdsa-curve is set")
+	ecdsaCurve = flag.String("ecdsa-curve", "", "ECDSA curve to use to generate a key. Valid values are P224, P256, P384, P521")
+)
+*/
+
+func publicKey(priv interface{}) interface{} {
+	switch k := priv.(type) {
+	case *rsa.PrivateKey:
+		return &k.PublicKey
+	case *ecdsa.PrivateKey:
+		return &k.PublicKey
+	default:
+		return nil
+	}
+}
+
+func pemBlockForKey(priv interface{}) *pem.Block {
+	switch k := priv.(type) {
+	case *rsa.PrivateKey:
+		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
+	case *ecdsa.PrivateKey:
+		b, err := x509.MarshalECPrivateKey(k)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
+			os.Exit(2)
+		}
+		return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
+	default:
+		return nil
+	}
+}
+
+// validFrom - Creation date formatted as Jan 1 15:04:05 2011 or ""
+
+func generateCert(host string, validFrom string, validFor time.Duration, isCA bool, rsaBits int, ecdsaCurve string) {
+
+	if len(host) == 0 {
+		log.Fatalf("Missing required --host parameter")
+	}
+
+	var priv interface{}
+	var err error
+	switch ecdsaCurve {
+	case "":
+		priv, err = rsa.GenerateKey(rand.Reader, rsaBits)
+	case "P224":
+		priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
+	case "P256":
+		priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	case "P384":
+		priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
+	case "P521":
+		priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+	default:
+		fmt.Fprintf(os.Stderr, "Unrecognized elliptic curve: %q", ecdsaCurve)
+		os.Exit(1)
+	}
+	if err != nil {
+		log.Fatalf("failed to generate private key: %s", err)
+	}
+
+	var notBefore time.Time
+	if len(validFrom) == 0 {
+		notBefore = time.Now()
+	} else {
+		notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom)
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "Failed to parse creation date: %s\n", err)
+			os.Exit(1)
+		}
+	}
+
+	notAfter := notBefore.Add(validFor)
+
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	if err != nil {
+		log.Fatalf("failed to generate serial number: %s", err)
+	}
+
+	template := x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject: pkix.Name{
+			Organization: []string{"Acme Co"},
+		},
+		NotBefore: notBefore,
+		NotAfter:  notAfter,
+
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+	}
+
+	hosts := strings.Split(host, ",")
+	for _, h := range hosts {
+		if ip := net.ParseIP(h); ip != nil {
+			template.IPAddresses = append(template.IPAddresses, ip)
+		} else {
+			template.DNSNames = append(template.DNSNames, h)
+		}
+	}
+
+	if isCA {
+		template.IsCA = true
+		template.KeyUsage |= x509.KeyUsageCertSign
+	}
+
+	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
+	if err != nil {
+		log.Fatalf("Failed to create certificate: %s", err)
+	}
+
+	certOut, err := os.Create("./" + host + ".cert.pem")
+	if err != nil {
+		log.Fatalf("failed to open cert.pem for writing: %s", err)
+	}
+	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	certOut.Close()
+
+	keyOut, err := os.OpenFile("./"+host+".key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+	if err != nil {
+		log.Print("failed to open key.pem for writing:", err)
+		return
+	}
+	pem.Encode(keyOut, pemBlockForKey(priv))
+	keyOut.Close()
+
+}

+ 339 - 0
tests/guerrilla_test.go

@@ -0,0 +1,339 @@
+// integration / smoke tests
+// =========================
+// Tests are in a different package so we can test as a consumer of the guerrilla package.
+// The following are integration / smoke tests that exercise the overall server.
+// (Please put unit tests in a different file.)
+// How it works:
+// The server's log output is redirected to logBuffer, which the tests then search for expected behaviour.
+// The package sets up logBuffer & the redirection in its init() function.
+// (Self-signed certs are also generated on each run.)
+// The server's responses on a connection are also checked for expected behaviour.
+// to run:
+// $ go test
+
+package test
+
+import (
+	"encoding/json"
+	log "github.com/Sirupsen/logrus"
+	"testing"
+
+	"github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla/backends"
+	"time"
+
+	"bufio"
+
+	"bytes"
+	"crypto/tls"
+	//	"crypto/x509"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"net"
+	"strings"
+)
+
+type TestConfig struct {
+	guerrilla.AppConfig
+	BackendName   string                 `json:"backend_name"`
+	BackendConfig map[string]interface{} `json:"backend_config"`
+}
+
+var (
+	// hold the output of logs
+	logBuffer bytes.Buffer
+	// logs redirected to this writer
+	logOut *bufio.Writer
+	// read the logs
+	logIn *bufio.Reader
+	// app config loaded here
+	config *TestConfig
+
+	app guerrilla.Guerrilla
+
+	initErr error
+)
+
+func init() {
+	logOut = bufio.NewWriter(&logBuffer)
+	logIn = bufio.NewReader(&logBuffer)
+	log.SetLevel(log.DebugLevel)
+	log.SetOutput(logOut)
+	config = &TestConfig{}
+	if err := json.Unmarshal([]byte(configJson), config); err != nil {
+		initErr = errors.New("Could not unmarshal config, " + err.Error())
+	} else {
+		setupCerts(config)
+		backend := getDummyBackend(config.BackendConfig)
+		app = guerrilla.New(&config.AppConfig, &backend)
+	}
+
+}
+
+// a configuration file with a dummy backend
+var configJson = `
+{
+    "pid_file" : "/var/run/go-guerrilla.pid",
+    "allowed_hosts": ["spam4.me","grr.la"],
+    "backend_name" : "dummy",
+    "backend_config" :
+        {
+            "log_received_mails" : true
+        },
+    "servers" : [
+        {
+            "is_enabled" : true,
+            "host_name":"mail.guerrillamail.com",
+            "max_size": 100017,
+            "private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
+            "public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+            "timeout":160,
+            "listen_interface":"127.0.0.1:2526",
+            "start_tls_on":true,
+            "tls_always_on":false,
+            "max_clients": 2,
+            "log_file":"/dev/stdout"
+        },
+
+        {
+            "is_enabled" : true,
+            "host_name":"mail.guerrillamail.com",
+            "max_size":1000001,
+            "private_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.key",
+            "public_key_file":"/vagrant/projects/htdocs/guerrilla/config/ssl/guerrillamail.com.crt",
+            "timeout":180,
+            "listen_interface":"127.0.0.1:4654",
+            "start_tls_on":false,
+            "tls_always_on":true,
+            "max_clients":1,
+            "log_file":"/dev/stdout"
+        }
+    ]
+}
+`
+
+func getDummyBackend(backendConfig map[string]interface{}) guerrilla.Backend {
+	var backend guerrilla.Backend
+	b := &backends.DummyBackend{}
+	b.Initialize(backendConfig)
+	backend = guerrilla.Backend(b)
+	return backend
+}
+
+func setupCerts(c *TestConfig) {
+	for i := range c.Servers {
+		generateCert(c.Servers[i].Hostname, "", 365*24*time.Hour, false, 2048, "P256")
+		c.Servers[i].PrivateKeyFile = c.Servers[i].Hostname + ".key.pem"
+		c.Servers[i].PublicKeyFile = c.Servers[i].Hostname + ".cert.pem"
+	}
+}
+
+// Testing start and stop of server
+func TestStart(t *testing.T) {
+	if initErr != nil {
+		t.Error(initErr)
+		t.FailNow()
+	}
+	if startErrors := app.Start(); startErrors != nil {
+		for _, err := range startErrors {
+			t.Error(err)
+		}
+		t.FailNow()
+	}
+	time.Sleep(time.Second)
+	app.Shutdown()
+	logOut.Flush()
+	if read, err := ioutil.ReadAll(logIn); err == nil {
+		logOutput := string(read)
+		if i := strings.Index(logOutput, "Listening on TCP 127.0.0.1:4654"); i < 0 {
+			t.Error("Server did not listen on 127.0.0.1:4654")
+		}
+		if i := strings.Index(logOutput, "Listening on TCP 127.0.0.1:2526"); i < 0 {
+			t.Error("Server did not listen on 127.0.0.1:2526")
+		}
+		if i := strings.Index(logOutput, "[127.0.0.1:4654] Waiting for a new client"); i < 0 {
+			t.Error("Server did not wait on 127.0.0.1:4654")
+		}
+		if i := strings.Index(logOutput, "[127.0.0.1:2526] Waiting for a new client"); i < 0 {
+			t.Error("Server did not wait on 127.0.0.1:2526")
+		}
+		if i := strings.Index(logOutput, "Server [127.0.0.1:4654] has stopped accepting new clients"); i < 0 {
+			t.Error("Server did not stop on 127.0.0.1:4654")
+		}
+		if i := strings.Index(logOutput, "Server [127.0.0.1:2526] has stopped accepting new clients"); i < 0 {
+			t.Error("Server did not stop on 127.0.0.1:2526")
+		}
+		if i := strings.Index(logOutput, "shutdown completed for [127.0.0.1:4654]"); i < 0 {
+			t.Error("Server did not complete shutdown on 127.0.0.1:4654")
+		}
+		if i := strings.Index(logOutput, "shutdown completed for [127.0.0.1:2526]"); i < 0 {
+			t.Error("Server did not complete shutdown on 127.0.0.1:2526")
+		}
+		if i := strings.Index(logOutput, "shutting down pool [127.0.0.1:4654]"); i < 0 {
+			t.Error("Server did not shutdown pool on 127.0.0.1:4654")
+		}
+		if i := strings.Index(logOutput, "shutting down pool [127.0.0.1:2526]"); i < 0 {
+			t.Error("Server did not shutdown pool on 127.0.0.1:2526")
+		}
+		if i := strings.Index(logOutput, "Backend shutdown completed"); i < 0 {
+			t.Error("Backend didn't shut down")
+		}
+
+	}
+	logBuffer.Reset()
+	logIn.Reset(&logBuffer)
+
+}
+
+// Simple smoke-test to see if the server can listen & issues a greeting on connect
+func TestGreeting(t *testing.T) {
+	//log.SetOutput(os.Stdout)
+	if initErr != nil {
+		t.Error(initErr)
+		t.FailNow()
+	}
+	if startErrors := app.Start(); startErrors == nil {
+
+		// 1. plaintext connection
+		conn, err := net.Dial("tcp", config.Servers[0].ListenInterface)
+		if err != nil {
+			// handle error
+			t.Error("Cannot dial server", config.Servers[0].ListenInterface)
+		}
+		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		greeting, err := bufio.NewReader(conn).ReadString('\n')
+		//fmt.Println(greeting)
+		if err != nil {
+			t.Error(err)
+			t.FailNow()
+		} else {
+			expected := "220 mail.guerrillamail.com SMTP Guerrilla"
+			if strings.Index(greeting, expected) != 0 {
+				t.Error("Server[1] did not have the expected greeting prefix", expected)
+			}
+		}
+		conn.Close()
+
+		// 2. tls connection
+		//	roots, err := x509.SystemCertPool()
+		conn, err = tls.Dial("tcp", config.Servers[1].ListenInterface, &tls.Config{
+
+			InsecureSkipVerify: true,
+			ServerName:         "127.0.0.1",
+		})
+		if err != nil {
+			// handle error
+			t.Error(err, "Cannot dial server (TLS)", config.Servers[1].ListenInterface)
+			t.FailNow()
+		}
+		conn.SetReadDeadline(time.Now().Add(time.Duration(time.Millisecond * 500)))
+		greeting, err = bufio.NewReader(conn).ReadString('\n')
+		//fmt.Println(greeting)
+		if err != nil {
+			t.Error(err)
+			t.FailNow()
+		} else {
+			expected := "220 mail.guerrillamail.com SMTP Guerrilla"
+			if strings.Index(greeting, expected) != 0 {
+				t.Error("Server[2] (TLS) did not have the expected greeting prefix", expected)
+			}
+		}
+		conn.Close()
+
+	} else {
+		// report the startup errors from the failed Start() call
+		for _, err := range startErrors {
+			t.Error(err)
+		}
+		t.FailNow()
+	}
+	app.Shutdown()
+	logOut.Flush()
+	if read, err := ioutil.ReadAll(logIn); err == nil {
+		logOutput := string(read)
+		//fmt.Println(logOutput)
+		if i := strings.Index(logOutput, "Handle client [127.0.0.1:"); i < 0 {
+			t.Error("Server did not handle any clients")
+		}
+	}
+	logBuffer.Reset()
+	logIn.Reset(&logBuffer)
+
+}
+
+// start up a server, connect a client, greet, then shutdown, then client sends a command
+// expecting: 421 Server is shutting down. Please try again later. Sayonara!
+// server should close connection after that
+func TestShutDown(t *testing.T) {
+
+	if initErr != nil {
+		t.Error(initErr)
+		t.FailNow()
+	}
+	if startErrors := app.Start(); startErrors == nil {
+		conn, err := net.Dial("tcp", config.Servers[0].ListenInterface)
+		if err != nil {
+			// handle error
+			t.Error("Cannot dial server", config.Servers[0].ListenInterface)
+		}
+		bufin := bufio.NewReader(conn)
+
+		// should be ample time to complete the test
+		conn.SetDeadline(time.Now().Add(time.Duration(time.Second * 20)))
+		// read greeting, ignore it
+		_, err = bufin.ReadString('\n')
+
+		if err != nil {
+			t.Error(err)
+			t.FailNow()
+		} else {
+			// client goes into command state
+			n, err := fmt.Fprintln(conn, "HELO localtester\r")
+			if err != nil {
+				log.WithError(err).Infof("n was %d", n)
+			}
+			_, err = bufin.ReadString('\n')
+
+			// do a shutdown while the client is connected & in client state
+			go app.Shutdown()
+
+			// issue a command while shutting down
+			n, err = fmt.Fprintln(conn, "HELP\r")
+			if err != nil {
+				log.WithError(err).Infof("n was %d", n)
+			}
+			response, err := bufin.ReadString('\n')
+			//fmt.Println(response)
+			expected := "421 Server is shutting down. Please try again later. Sayonara!"
+			if strings.Index(response, expected) != 0 {
+				t.Error("Server did not shut down with", expected)
+			}
+			time.Sleep(time.Millisecond * 250) // let the server close the connection
+
+		}
+
+		conn.Close()
+
+	} else {
+		// report the startup errors from the failed Start() call
+		for _, err := range startErrors {
+			t.Error(err)
+		}
+		app.Shutdown()
+		t.FailNow()
+	}
+	logOut.Flush()
+	if read, err := ioutil.ReadAll(logIn); err == nil {
+		logOutput := string(read)
+		//	fmt.Println(logOutput)
+		if i := strings.Index(logOutput, "Handle client [127.0.0.1:"); i < 0 {
+			t.Error("Server did not handle any clients")
+		}
+	}
+	logBuffer.Reset()
+	logIn.Reset(&logBuffer)
+
+}