Преглед изворни кода

Huge refactor: backends, package layout, ...

* Packaging of the code is now based on this approach:
  https://medium.com/@benbjohnson/standard-package-layout-7cdbc8391fc1
* Now we can have different backends to store/deliver/process messages
* dummy backend
* logrus for logging (I hadn't a quick way in my mind to keep the server
  log file feature, so I've temporary dropped that feature)
* Makefile, to inject git metadata into the binary (version, build time, ...)
* cobra as the CLI parser
  * serve subcommand
	* version subcommand
Reza Mohammadi пре 9 година
родитељ
комит
fc89404ace
20 измењених фајлова са 1052 додато и 821 уклоњено
  1. 1 1
      .gitignore
  2. 28 0
      Makefile
  3. 87 0
      backend.go
  4. 23 0
      backends/backend.go
  5. 37 0
      backends/dummy.go
  6. 208 0
      backends/guerrilla_db_redis.go
  7. 13 0
      cmd/guerrillad/main.go
  8. 31 0
      cmd/guerrillad/root.go
  9. 106 0
      cmd/guerrillad/serve.go
  10. 27 0
      cmd/guerrillad/version.go
  11. 19 89
      config.go
  12. 37 0
      config/config.go
  13. 6 14
      goguerrilla.conf.sample
  14. 0 140
      goguerrilla.go
  15. 0 159
      save_mail.go
  16. 83 0
      server/goguerrilla.go
  17. 273 0
      server/smtpd.go
  18. 0 396
      smtpd.go
  19. 46 22
      util/util.go
  20. 27 0
      version.go

+ 1 - 1
.gitignore

@@ -1,3 +1,3 @@
 .idea
 goguerrilla.conf
-go-guerrilla
+/guerrillad

+ 28 - 0
Makefile

@@ -0,0 +1,28 @@
+GIT ?= git
+GO_VARS ?=
+GO ?= go
+COMMIT := $(shell $(GIT) rev-parse HEAD)
+VERSION ?= $(shell $(GIT) describe --tags ${COMMIT} 2> /dev/null || echo "$(COMMIT)")
+BUILD_TIME := $(shell LANG=en_US date +"%F_%T_%z")
+ROOT := github.com/flashmob/go-guerrilla
+LD_FLAGS := -X $(ROOT).Version=$(VERSION) -X $(ROOT).Commit=$(COMMIT) -X $(ROOT).BuildTime=$(BUILD_TIME)
+
+.PHONY: help clean dependencies test
+help:
+	@echo "Please use \`make <ROOT>' where <ROOT> is one of"
+	@echo "  dependencies to go install the dependencies"
+	@echo "  guerrillad   to build the main binary for current platform"
+	@echo "  test         to run unittests"
+
+clean:
+	rm -f guerrillad
+
+dependencies:
+	$(GO_VARS) $(GO) list -f='{{ join .Deps "\n" }}' $(ROOT)/cmd/guerrillad | grep -v $(ROOT) | tr '\n' ' ' | $(GO_VARS) xargs $(GO) get -u -v
+	$(GO_VARS) $(GO) list -f='{{ join .Deps "\n" }}' $(ROOT)/cmd/guerrillad | grep -v $(ROOT) | tr '\n' ' ' | $(GO_VARS) xargs $(GO) install -v
+
+guerrillad: *.go */*.go */*/*.go
+	$(GO_VARS) $(GO) build -o="guerrillad" -ldflags="$(LD_FLAGS)" $(ROOT)/cmd/guerrillad
+
+test: *.go */*.go */*/*.go
+	$(GO_VARS) $(GO) test -v ./...

+ 87 - 0
backend.go

@@ -0,0 +1,87 @@
+package guerrilla
+
+import (
+	"bufio"
+	"errors"
+	"io"
+	"net"
+)
+
+// Backend accepts the recieved messages, and store/deliver/process them
+type Backend interface {
+	Initialize(BackendConfig) error
+	Process(client *Client, user, host string) string
+}
+
+const CommandMaxLength = 1024
+
+// TODO: cleanup
+type Client struct {
+	State       int
+	Helo        string
+	MailFrom    string
+	RcptTo      string
+	Response    string
+	Address     string
+	Data        string
+	Subject     string
+	Hash        string
+	Time        int64
+	TLS         bool
+	Conn        net.Conn
+	Bufin       *SMTPBufferedReader
+	Bufout      *bufio.Writer
+	KillTime    int64
+	Errors      int
+	ClientID    int64
+	SavedNotify chan int
+}
+
+var InputLimitExceeded = errors.New("Line too long") // 500 Line too long.
+
+// we need to adjust the limit, so we embed io.LimitedReader
+type adjustableLimitedReader struct {
+	R *io.LimitedReader
+}
+
+// bolt this on so we can adjust the limit
+func (alr *adjustableLimitedReader) setLimit(n int64) {
+	alr.R.N = n
+}
+
+// this just delegates to the underlying reader in order to satisfy the Reader interface
+// Since the vanilla limited reader returns io.EOF when the limit is reached, we need a more specific
+// error so that we can distinguish when a limit is reached
+func (alr *adjustableLimitedReader) Read(p []byte) (n int, err error) {
+	n, err = alr.R.Read(p)
+	if err == io.EOF && alr.R.N <= 0 {
+		// return our custom error since std lib returns EOF
+		err = InputLimitExceeded
+	}
+	return
+}
+
+// allocate a new adjustableLimitedReader
+func newAdjustableLimitedReader(r io.Reader, n int64) *adjustableLimitedReader {
+	lr := &io.LimitedReader{R: r, N: n}
+	return &adjustableLimitedReader{lr}
+}
+
+// This is a bufio.Reader what will use our adjustable limit reader
+// We 'extend' buffio to have the limited reader feature
+type SMTPBufferedReader struct {
+	*bufio.Reader
+	alr *adjustableLimitedReader
+}
+
+// delegate to the adjustable limited reader
+func (sbr *SMTPBufferedReader) SetLimit(n int64) {
+	sbr.alr.setLimit(n)
+}
+
+// allocate a new smtpBufferedReader
+func NewSMTPBufferedReader(rd io.Reader) *SMTPBufferedReader {
+	alr := newAdjustableLimitedReader(rd, CommandMaxLength)
+	s := &SMTPBufferedReader{bufio.NewReader(alr), alr}
+	return s
+}

+ 23 - 0
backends/backend.go

@@ -0,0 +1,23 @@
+package backends
+
+import (
+	"fmt"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+)
+
+var backends = map[string]guerrilla.Backend{}
+
+// New retrive a backend specified by the backendName, and initialize it using
+// backendConfig
+func New(backendName string, backendConfig guerrilla.BackendConfig) (guerrilla.Backend, error) {
+	backend, found := backends[backendName]
+	if !found {
+		return nil, fmt.Errorf("backend %q not found", backendName)
+	}
+	err := backend.Initialize(backendConfig)
+	if err != nil {
+		return nil, fmt.Errorf("error while initializing the backend: %s", err)
+	}
+	return backend, nil
+}

+ 37 - 0
backends/dummy.go

@@ -0,0 +1,37 @@
+package backends
+
+import (
+	"fmt"
+
+	log "github.com/Sirupsen/logrus"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+)
+
+func init() {
+	backends["dummy"] = &DummyBackend{}
+}
+
+type DummyBackend struct {
+	config dummyConfig
+}
+
+type dummyConfig struct {
+	LogReceivedMails bool `json:"log_received_mails"`
+}
+
+func (b *DummyBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
+	var converted bool
+	b.config.LogReceivedMails, converted = backendConfig["log_received_mails"].(bool)
+	if !converted {
+		return fmt.Errorf("failed to load backend config (%v)", backendConfig)
+	}
+	return nil
+}
+
+func (b *DummyBackend) Process(client *guerrilla.Client, user, host string) string {
+	if b.config.LogReceivedMails {
+		log.Infof("Mail from: %s@%s", user, host)
+	}
+	return fmt.Sprintf("250 OK : queued as %s", client.Hash)
+}

+ 208 - 0
backends/guerrilla_db_redis.go

@@ -0,0 +1,208 @@
+package backends
+
+import (
+	"fmt"
+	"time"
+
+	log "github.com/Sirupsen/logrus"
+
+	"github.com/garyburd/redigo/redis"
+	"github.com/ziutek/mymysql/autorc"
+	_ "github.com/ziutek/mymysql/godrv"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla/util"
+)
+
+func init() {
+	backends["guerrilla-db-redis"] = &GuerrillaDBAndRedisBackend{}
+}
+
+type GuerrillaDBAndRedisBackend struct {
+	config guerrillaDBAndRedisConfig
+}
+
+type guerrillaDBAndRedisConfig struct {
+	NumberOfWorkers    int    `json:"save_workers_size"`
+	MysqlTable         string `json:"mail_table"`
+	MysqlDB            string `json:"mysql_db"`
+	MysqlHost          string `json:"mysql_host"`
+	MysqlPass          string `json:"mysql_pass"`
+	MysqlUser          string `json:"mysql_user"`
+	RedisExpireSeconds int    `json:"redis_expire_seconds"`
+	RedisInterface     string `json:"redis_interface"`
+	PrimaryHost        string `json:"primary_mail_host"`
+}
+
+func (g *GuerrillaDBAndRedisBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
+	// TODO: load config
+
+	if err := g.testDbConnections(); err != nil {
+		return err
+	}
+
+	SaveMailChan = make(chan *savePayload, g.config.NumberOfWorkers)
+
+	// start some savemail workers
+	for i := 0; i < g.config.NumberOfWorkers; i++ {
+		go g.saveMail()
+	}
+
+	return nil
+}
+
+func (g *GuerrillaDBAndRedisBackend) Process(client *guerrilla.Client, user, host string) string {
+	// to do: timeout when adding to SaveMailChan
+	// place on the channel so that one of the save mail workers can pick it up
+	SaveMailChan <- &savePayload{client: client, user: user, host: host}
+	// wait for the save to complete
+	// or timeout
+	select {
+	case status := <-client.SavedNotify:
+		if status == 1 {
+			return fmt.Sprintf("250 OK : queued as %s", client.Hash)
+		}
+		return "554 Error: transaction failed, blame it on the weather"
+	case <-time.After(time.Second * 30):
+		log.Debug("timeout")
+		return "554 Error: transaction timeout"
+	}
+}
+
+type savePayload struct {
+	client *guerrilla.Client
+	user   string
+	host   string
+}
+
+var SaveMailChan chan *savePayload
+
+type redisClient struct {
+	count int
+	conn  redis.Conn
+	time  int
+}
+
+func (g *GuerrillaDBAndRedisBackend) saveMail() {
+	var to, recipient, body string
+	var err error
+
+	var redisErr error
+	var length int
+	redisClient := &redisClient{}
+	db := autorc.New(
+		"tcp",
+		"",
+		g.config.MysqlHost,
+		g.config.MysqlUser,
+		g.config.MysqlPass,
+		g.config.MysqlDB)
+	db.Register("set names utf8")
+	sql := "INSERT INTO " + g.config.MysqlTable + " "
+	sql += "(`date`, `to`, `from`, `subject`, `body`, `charset`, `mail`, `spam_score`, `hash`, `content_type`, `recipient`, `has_attach`, `ip_addr`, `return_path`, `is_tls`)"
+	sql += " values (NOW(), ?, ?, ?, ? , 'UTF-8' , ?, 0, ?, '', ?, 0, ?, ?, ?)"
+	ins, sqlErr := db.Prepare(sql)
+	if sqlErr != nil {
+		log.WithError(sqlErr).Fatalf("failed while db.Prepare(INSERT...)")
+	}
+	sql = "UPDATE gm2_setting SET `setting_value` = `setting_value`+1 WHERE `setting_name`='received_emails' LIMIT 1"
+	incr, sqlErr := db.Prepare(sql)
+	if sqlErr != nil {
+		log.WithError(sqlErr).Fatalf("failed while db.Prepare(UPDATE...)")
+	}
+
+	//  receives values from the channel repeatedly until it is closed.
+	for {
+		payload := <-SaveMailChan
+		recipient = payload.user + "@" + payload.host
+		to = payload.user + "@" + g.config.PrimaryHost
+		length = len(payload.client.Data)
+		ts := fmt.Sprintf("%d", time.Now().UnixNano())
+		payload.client.Subject = util.MimeHeaderDecode(payload.client.Subject)
+		payload.client.Hash = util.MD5Hex(
+			&to,
+			&payload.client.MailFrom,
+			&payload.client.Subject,
+			&ts)
+		// Add extra headers
+		var addHead string
+		addHead += "Delivered-To: " + to + "\r\n"
+		addHead += "Received: from " + payload.client.Helo + " (" + payload.client.Helo + "  [" + payload.client.Address + "])\r\n"
+		addHead += "	by " + payload.host + " with SMTP id " + payload.client.Hash + "@" + payload.host + ";\r\n"
+		addHead += "	" + time.Now().Format(time.RFC1123Z) + "\r\n"
+		// compress to save space
+		payload.client.Data = util.Compress(&addHead, &payload.client.Data)
+		body = "gzencode"
+		redisErr = redisClient.redisConnection(g.config.RedisInterface)
+		if redisErr == nil {
+			_, doErr := redisClient.conn.Do("SETEX", payload.client.Hash, g.config.RedisExpireSeconds, payload.client.Data)
+			if doErr == nil {
+				payload.client.Data = ""
+				body = "redis"
+			}
+		} else {
+			log.WithError(redisErr).Warn("Error while SETEX on redis")
+		}
+		// bind data to cursor
+		ins.Bind(
+			to,
+			payload.client.MailFrom,
+			payload.client.Subject,
+			body,
+			payload.client.Data,
+			payload.client.Hash,
+			recipient,
+			payload.client.Address,
+			payload.client.MailFrom,
+			payload.client.TLS,
+		)
+		// save, discard result
+		_, _, err = ins.Exec()
+		if err != nil {
+			log.WithError(err).Warn("Database error while inster")
+			payload.client.SavedNotify <- -1
+		} else {
+			log.Debugf("Email saved %s (len=%d)", payload.client.Hash, length)
+			_, _, err = incr.Exec()
+			if err != nil {
+				log.WithError(err).Warn("Database error while incr count")
+			}
+			payload.client.SavedNotify <- 1
+		}
+	}
+}
+
+func (c *redisClient) redisConnection(redisInterface string) (err error) {
+	if c.count == 0 {
+		c.conn, err = redis.Dial("tcp", redisInterface)
+		if err != nil {
+			// handle error
+			return err
+		}
+	}
+	return nil
+}
+
+// test database connection settings
+func (g *GuerrillaDBAndRedisBackend) testDbConnections() (err error) {
+	db := autorc.New(
+		"tcp",
+		"",
+		g.config.MysqlHost,
+		g.config.MysqlUser,
+		g.config.MysqlPass,
+		g.config.MysqlDB)
+
+	if mysqlErr := db.Raw.Connect(); mysqlErr != nil {
+		err = fmt.Errorf("MySql cannot connect, check your settings: %s", mysqlErr)
+	} else {
+		db.Raw.Close()
+	}
+
+	redisClient := &redisClient{}
+	if redisErr := redisClient.redisConnection(g.config.RedisInterface); redisErr != nil {
+		err = fmt.Errorf("Redis cannot connect, check your settings: %s", redisErr)
+	}
+
+	return
+}

+ 13 - 0
cmd/guerrillad/main.go

@@ -0,0 +1,13 @@
+package main
+
+import (
+	"fmt"
+	"os"
+)
+
+func main() {
+	if err := rootCmd.Execute(); err != nil {
+		fmt.Println(err)
+		os.Exit(-1)
+	}
+}

+ 31 - 0
cmd/guerrillad/root.go

@@ -0,0 +1,31 @@
+package main
+
+import (
+	log "github.com/Sirupsen/logrus"
+	"github.com/spf13/cobra"
+)
+
+var rootCmd = &cobra.Command{
+	Use:   "guerrillad",
+	Short: "small SMTP server",
+	Long: `It's a small SMTP server written in Go, for the purpose of receiving large volume of email.
+Written for GuerrillaMail.com which processes tens of thousands of emails every hour.`,
+	Run: nil,
+}
+
+var (
+	verbose bool
+)
+
+func init() {
+	cobra.OnInitialize()
+	rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false,
+		"print out more debug information")
+	rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
+		if verbose {
+			log.SetLevel(log.DebugLevel)
+		} else {
+			log.SetLevel(log.InfoLevel)
+		}
+	}
+}

+ 106 - 0
cmd/guerrillad/serve.go

@@ -0,0 +1,106 @@
+package main
+
+import (
+	"os"
+	"os/signal"
+	"syscall"
+
+	log "github.com/Sirupsen/logrus"
+	"github.com/spf13/cobra"
+
+	"fmt"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla/backends"
+	"github.com/flashmob/go-guerrilla/config"
+	"github.com/flashmob/go-guerrilla/server"
+)
+
+var (
+	iface      string
+	configFile string
+	pidFile    string
+
+	serveCmd = &cobra.Command{
+		Use:   "serve",
+		Short: "start the small SMTP server",
+		Run:   serve,
+	}
+
+	mainConfig    = guerrilla.Config{}
+	signalChannel = make(chan os.Signal, 1) // for trapping SIG_HUB
+)
+
+func init() {
+	serveCmd.PersistentFlags().StringVarP(&iface, "if", "", "",
+		"Interface and port to listen on, eg. 127.0.0.1:2525 ")
+	serveCmd.PersistentFlags().StringVarP(&configFile, "config", "c",
+		"goguerrilla.conf", "Path to the configuration file")
+	serveCmd.PersistentFlags().StringVarP(&pidFile, "pidFile", "p",
+		"/var/run/go-guerrilla.pid", "Path to the pid file")
+
+	rootCmd.AddCommand(serveCmd)
+}
+
+func sigHandler() {
+	// handle SIGHUP for reloading the configuration while running
+	signal.Notify(signalChannel, syscall.SIGHUP)
+
+	for sig := range signalChannel {
+		if sig == syscall.SIGHUP {
+			err := config.ReadConfig(configFile, iface, verbose, &mainConfig)
+			if err != nil {
+				log.WithError(err).Error("Error while ReadConfig (reload)")
+			} else {
+				log.Infof("Configuration is reloaded at %s", guerrilla.ConfigLoadTime)
+			}
+			// TODO: reinitialize
+		} else {
+			os.Exit(0)
+		}
+	}
+}
+
+func serve(cmd *cobra.Command, args []string) {
+	logVersion()
+
+	err := config.ReadConfig(configFile, iface, verbose, &mainConfig)
+	if err != nil {
+		log.WithError(err).Fatal("Error while ReadConfig")
+	}
+
+	// write out our PID
+	if len(pidFile) > 0 {
+		if f, err := os.Create(pidFile); err == nil {
+			defer f.Close()
+			if _, err := f.WriteString(fmt.Sprintf("%d", os.Getpid())); err == nil {
+				f.Sync()
+			} else {
+				log.WithError(err).Fatalf("Error while writing pidFile (%s)", pidFile)
+			}
+		} else {
+			log.WithError(err).Fatalf("Error while creating pidFile (%s)", pidFile)
+		}
+	}
+
+	backend, err := backends.New(mainConfig.BackendName, mainConfig.BackendConfig)
+	if err != nil {
+		log.WithError(err).Fatalf("Error while loading the backend %q",
+			mainConfig.BackendName)
+	}
+
+	// run our servers
+	for _, serverConfig := range mainConfig.Servers {
+		if serverConfig.IsEnabled {
+			log.Infof("Starting server on %s", serverConfig.ListenInterface)
+			go func(sConfig guerrilla.ServerConfig) {
+				err := server.RunServer(sConfig, backend, mainConfig.AllowedHosts)
+				if err != nil {
+					log.WithError(err).Fatalf("Error while starting server on %s", serverConfig.ListenInterface)
+				}
+			}(serverConfig)
+		}
+	}
+
+	sigHandler()
+}

+ 27 - 0
cmd/guerrillad/version.go

@@ -0,0 +1,27 @@
+package main
+
+import (
+	log "github.com/Sirupsen/logrus"
+	"github.com/spf13/cobra"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+)
+
+var versionCmd = &cobra.Command{
+	Use:   "version",
+	Short: "Print the version info",
+	Long:  `Every software has a version. This is Guerrilla's`,
+	Run: func(cmd *cobra.Command, args []string) {
+		logVersion()
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(versionCmd)
+}
+
+func logVersion() {
+	log.Infof("guerrillad %s", guerrilla.Version)
+	log.Debugf("Build Time: %s", guerrilla.BuildTime)
+	log.Debugf("Commit:     %s", guerrilla.Commit)
+}

+ 19 - 89
config.go

@@ -1,95 +1,25 @@
-package main
+package guerrilla
 
-import (
-	"encoding/json"
-	"flag"
-	"fmt"
-	"io/ioutil"
-	"log"
-	"os"
-	"strings"
-)
+type BackendConfig map[string]interface{}
 
-type GlobalConfig struct {
-	Allowed_hosts        string         `json:"allowed_hosts"`
-	Primary_host         string         `json:"primary_mail_host"`
-	Verbose              bool           `json:"verbose"`
-	Mysql_table          string         `json:"mail_table"`
-	Mysql_db             string         `json:"mysql_db"`
-	Mysql_host           string         `json:"mysql_host"`
-	Mysql_pass           string         `json:"mysql_pass"`
-	Mysql_user           string         `json:"mysql_user"`
-	Servers              []ServerConfig `json:"servers"`
-	Pid_file             string         `json:"pid_file,omitempty"`
-	Save_workers_size    int            `json:"save_workers_size"`
-	Redis_expire_seconds int            `json:"redis_expire_seconds"`
-	Redis_interface      string         `json:"redis_interface"`
+// Config is the holder of the configuration of the app
+type Config struct {
+	BackendName   string         `json:"backend_name"`
+	BackendConfig BackendConfig  `json:"backend_config,omitempty"`
+	Servers       []ServerConfig `json:"servers"`
+	AllowedHosts  string         `json:"allowed_hosts"`
 }
 
+// ServerConfig is the holder of the configuration of a server
 type ServerConfig struct {
-	Is_enabled       bool   `json:"is_enabled"`
-	Host_name        string `json:"host_name"`
-	Max_size         int    `json:"max_size"`
-	Private_key_file string `json:"private_key_file"`
-	Public_key_file  string `json:"public_key_file"`
-	Timeout          int    `json:"timeout"`
-	Listen_interface string `json:"listen_interface"`
-	Start_tls_on     bool   `json:"start_tls_on,omitempty"`
-	Tls_always_on    bool   `json:"tls_always_on,omitempty"`
-	Max_clients      int    `json:"max_clients"`
-	Log_file         string `json:"log_file"`
-}
-
-var mainConfig GlobalConfig
-var flagVerbose, flagIface, flagConfigFile string
-
-// config is read at startup, or when a SIG_HUP is caught
-func readConfig() {
-	log.SetOutput(os.Stdout)
-	// parse command line arguments
-	if !flag.Parsed() {
-		flag.StringVar(&flagConfigFile, "config", "goguerrilla.conf", "Path to the configuration file")
-		flag.StringVar(&flagVerbose, "v", "n", "Verbose, [y | n] ")
-		flag.StringVar(&flagIface, "if", "", "Interface and port to listen on, eg. 127.0.0.1:2525 ")
-		flag.Parse()
-	}
-	// load in the config.
-	b, err := ioutil.ReadFile(flagConfigFile)
-	if err != nil {
-		log.Fatalln("Could not read config file", err)
-	}
-
-	mainConfig = GlobalConfig{}
-	err = json.Unmarshal(b, &mainConfig)
-	//fmt.Println(theConfig)
-	//fmt.Println(fmt.Sprintf("allowed hosts: %s", theConfig.Allowed_hosts))
-	//log.Fatalln("Could not parse config file:", theConfig)
-	if err != nil {
-		fmt.Println("Could not parse config file:", err)
-		log.Fatalln("Could not parse config file:", err)
-	}
-
-	// copy command line flag over so it takes precedence
-	if len(flagVerbose) > 0 && strings.ToUpper(flagVerbose) == "Y" {
-		mainConfig.Verbose = true
-	}
-
-	if len(flagIface) > 0 {
-		mainConfig.Servers[0].Listen_interface = flagIface
-	}
-	// map the allow hosts for easy lookup
-	if len(mainConfig.Allowed_hosts) > 0 {
-		if arr := strings.Split(mainConfig.Allowed_hosts, ","); len(arr) > 0 {
-			for i := 0; i < len(arr); i++ {
-				allowedHosts[arr[i]] = true
-			}
-		}
-	} else {
-		log.Fatalln("Config error, GM_ALLOWED_HOSTS must be s string.")
-	}
-	if mainConfig.Pid_file == "" {
-		mainConfig.Pid_file = "/var/run/go-guerrilla.pid"
-	}
-
-	return
+	IsEnabled       bool   `json:"is_enabled"`
+	Hostname        string `json:"host_name"`
+	MaxSize         int    `json:"max_size"`
+	PrivateKeyFile  string `json:"private_key_file"`
+	PublicKeyFile   string `json:"public_key_file"`
+	Timeout         int    `json:"timeout"`
+	ListenInterface string `json:"listen_interface"`
+	StartTLS        bool   `json:"start_tls_on,omitempty"`
+	TLSAlwaysOn     bool   `json:"tls_always_on,omitempty"`
+	MaxClients      int    `json:"max_clients"`
 }

+ 37 - 0
config/config.go

@@ -0,0 +1,37 @@
+package config
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"time"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+)
+
+// ReadConfig which should be called at startup, or when a SIG_HUP is caught
+func ReadConfig(configFile, iface string, verbose bool, mainConfig *guerrilla.Config) error {
+	// load in the config.
+	b, err := ioutil.ReadFile(configFile)
+	if err != nil {
+		return fmt.Errorf("could not read config file: %s", err)
+	}
+
+	err = json.Unmarshal(b, &mainConfig)
+	if err != nil {
+		return fmt.Errorf("could not parse config file: %s", err)
+	}
+
+	if len(mainConfig.AllowedHosts) == 0 {
+		return errors.New("empty AllowedHosts is not allowed")
+	}
+
+	// TODO: deprecate
+	if len(iface) > 0 && len(mainConfig.Servers) > 0 {
+		mainConfig.Servers[0].ListenInterface = iface
+	}
+
+	guerrilla.ConfigLoadTime = time.Now()
+	return nil
+}

+ 6 - 14
goguerrilla.conf.sample

@@ -1,16 +1,10 @@
 {
     "allowed_hosts": "guerrillamail.com,guerrillamailblock.com,sharklasers.com,guerrillamail.net,guerrillamail.org",
     "primary_mail_host":"sharklasers.com",
-    "verbose":false,
-    "mysql_db":"gmail_mail",
-    "mysql_host":"127.0.0.1:3306",
-    "mysql_pass":"ok",
-    "mysql_user":"gmail_mail",
-    "mail_table":"new_mail",
-    "redis_interface" : "127.0.0.1:6379",
-	"redis_expire_seconds" : 3600,
-	"save_workers_size" : 3,
-	"pid_file" : "/var/run/go-guerrilla.pid",
+    "backend_name": "dummy",
+    "backend_config": {
+        "log_received_mails": true
+    },
     "servers" : [
         {
             "is_enabled" : true,
@@ -22,8 +16,7 @@
             "listen_interface":"127.0.0.1:25",
             "start_tls_on":true,
             "tls_always_on":false,
-            "max_clients": 1000,
-            "log_file":"/dev/stdout"
+            "max_clients": 1000
         },
         {
             "is_enabled" : true,
@@ -35,8 +28,7 @@
             "listen_interface":"127.0.0.1:465",
             "start_tls_on":false,
             "tls_always_on":true,
-            "max_clients":500,
-            "log_file":"/dev/stdout"
+            "max_clients":500
         }
     ]
 }

+ 0 - 140
goguerrilla.go

@@ -1,140 +0,0 @@
-/**
-Go-Guerrilla SMTPd
-
-Version: 1.5
-Author: Flashmob, GuerrillaMail.com
-Contact: [email protected]
-License: MIT
-Repository: https://github.com/flashmob/Go-Guerrilla-SMTPd
-Site: http://www.guerrillamail.com/
-
-See README for more details
-
-
-*/
-
-package main
-
-import (
-	"bufio"
-	"crypto/rand"
-	"crypto/tls"
-	"fmt"
-	"net"
-	"os"
-	"os/signal"
-	"runtime"
-	"strconv"
-	"syscall"
-	"time"
-)
-
-var allowedHosts = make(map[string]bool, 15)
-
-
-var signalChannel = make(chan os.Signal, 1) // for trapping SIG_HUB
-
-func sigHandler() {
-	for sig := range signalChannel {
-		if sig == syscall.SIGHUP {
-			readConfig()
-			fmt.Print("Reloading Configuration!\n")
-		} else {
-			os.Exit(0)
-		}
-
-	}
-}
-
-func initialise() {
-
-	// database writing workers
-	SaveMailChan = make(chan *savePayload, mainConfig.Save_workers_size)
-
-	// write out our PID
-	if f, err := os.Create(mainConfig.Pid_file); err == nil {
-		defer f.Close()
-		if _, err := f.WriteString(strconv.Itoa(os.Getpid())); err == nil {
-			f.Sync()
-		}
-	}
-	// handle SIGHUP for reloading the configuration while running
-	signal.Notify(signalChannel, syscall.SIGHUP)
-
-	return
-}
-
-func runServer(sConfig ServerConfig) (err error) {
-	server := SmtpdServer{Config: sConfig, sem: make(chan int, sConfig.Max_clients)}
-
-	// setup logging
-	server.openLog()
-
-	// configure ssl
-	if (sConfig.Tls_always_on || sConfig.Start_tls_on) {
-		cert, err := tls.LoadX509KeyPair(sConfig.Public_key_file, sConfig.Private_key_file)
-		if err != nil {
-			server.logln(2, fmt.Sprintf("There was a problem with loading the certificate: %s", err))
-		}
-		server.tlsConfig = &tls.Config{
-			Certificates: []tls.Certificate{cert},
-			ClientAuth:   tls.VerifyClientCertIfGiven,
-			ServerName:   sConfig.Host_name,
-		}
-		server.tlsConfig.Rand = rand.Reader
-	}
-
-
-	// configure timeout
-	server.timeout = time.Duration(sConfig.Timeout)
-
-	// Start listening for SMTP connections
-	listener, err := net.Listen("tcp", sConfig.Listen_interface)
-	if err != nil {
-		server.logln(2, fmt.Sprintf("Cannot listen on port, %v", err))
-		return err
-	} else {
-		server.logln(1, fmt.Sprintf("Listening on tcp %s", sConfig.Listen_interface))
-	}
-	var clientId int64
-	clientId = 1
-	for {
-		conn, err := listener.Accept()
-		if err != nil {
-			server.logln(1, fmt.Sprintf("Accept error: %s", err))
-			continue
-		}
-		server.logln(0, fmt.Sprintf(" There are now "+strconv.Itoa(runtime.NumGoroutine())+" serving goroutines"))
-		server.sem <- 1 // Wait for active queue to drain.
-		go server.handleClient(&Client{
-			conn:        conn,
-			address:     conn.RemoteAddr().String(),
-			time:        time.Now().Unix(),
-			bufin:       newSmtpBufferedReader(conn),
-			bufout:      bufio.NewWriter(conn),
-			clientId:    clientId,
-			savedNotify: make(chan int),
-		})
-		clientId++
-	}
-}
-
-func main() {
-	readConfig()
-	initialise()
-	if err := testDbConnections(); err != nil {
-		fmt.Println(err)
-		os.Exit(1);
-	}
-	// start some savemail workers
-	for i := 0; i < mainConfig.Save_workers_size; i++ {
-		go saveMail()
-	}
-	// run our servers
-	for serverId := 0; serverId < len(mainConfig.Servers); serverId++ {
-		if mainConfig.Servers[serverId].Is_enabled {
-			go runServer(mainConfig.Servers[serverId])
-		}
-	}
-	sigHandler()
-}

+ 0 - 159
save_mail.go

@@ -1,159 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"github.com/garyburd/redigo/redis"
-	"github.com/ziutek/mymysql/autorc"
-	_ "github.com/ziutek/mymysql/godrv"
-	"log"
-	"strconv"
-	"time"
-	"errors"
-)
-
-type savePayload struct {
-	client *Client
-	server *SmtpdServer
-}
-
-var SaveMailChan chan *savePayload // workers for saving mail
-
-type redisClient struct {
-	count int
-	conn  redis.Conn
-	time  int
-}
-
-func saveMail() {
-	var to, recipient, body string
-	var err error
-
-	var redis_err error
-	var length int
-	redisClient := &redisClient{}
-	db := autorc.New(
-		"tcp",
-		"",
-		mainConfig.Mysql_host,
-		mainConfig.Mysql_user,
-		mainConfig.Mysql_pass,
-		mainConfig.Mysql_db)
-	db.Register("set names utf8")
-	sql := "INSERT INTO " + mainConfig.Mysql_table + " "
-	sql += "(`date`, `to`, `from`, `subject`, `body`, `charset`, `mail`, `spam_score`, `hash`, `content_type`, `recipient`, `has_attach`, `ip_addr`, `return_path`, `is_tls`)"
-	sql += " values (NOW(), ?, ?, ?, ? , 'UTF-8' , ?, 0, ?, '', ?, 0, ?, ?, ?)"
-	ins, sql_err := db.Prepare(sql)
-	if sql_err != nil {
-		log.Fatalf(fmt.Sprintf("Sql statement incorrect: %s\n", sql_err))
-	}
-	sql = "UPDATE gm2_setting SET `setting_value` = `setting_value`+1 WHERE `setting_name`='received_emails' LIMIT 1"
-	incr, sql_err := db.Prepare(sql)
-	if sql_err != nil {
-		log.Fatalf(fmt.Sprintf("Sql statement incorrect: %s\n", sql_err))
-	}
-
-	//  receives values from the channel repeatedly until it is closed.
-	for {
-		payload := <-SaveMailChan
-		if user, host, addr_err := validateEmailData(payload.client); addr_err != nil {
-			payload.server.logln(1, fmt.Sprintf("mail_from didnt validate: %v", addr_err)+" client.mail_from:"+payload.client.mail_from)
-			// notify client that a save completed, -1 = error
-			payload.client.savedNotify <- -1
-			continue
-		} else {
-			recipient = user + "@" + host
-			to = user + "@" + mainConfig.Primary_host
-		}
-		length = len(payload.client.data)
-		ts := strconv.FormatInt(time.Now().UnixNano(), 10);
-		payload.client.subject = mimeHeaderDecode(payload.client.subject)
-		payload.client.hash = md5hex(
-			&to,
-			&payload.client.mail_from,
-			&payload.client.subject,
-			&ts)
-		// Add extra headers
-		add_head := ""
-		add_head += "Delivered-To: " + to + "\r\n"
-		add_head += "Received: from " + payload.client.helo + " (" + payload.client.helo + "  [" + payload.client.address + "])\r\n"
-		add_head += "	by " + payload.server.Config.Host_name + " with SMTP id " + payload.client.hash + "@" +
-			payload.server.Config.Host_name + ";\r\n"
-		add_head += "	" + time.Now().Format(time.RFC1123Z) + "\r\n"
-		// compress to save space
-		payload.client.data = compress(&add_head, &payload.client.data)
-		body = "gzencode"
-		redis_err = redisClient.redisConnection()
-		if redis_err == nil {
-			_, do_err := redisClient.conn.Do("SETEX", payload.client.hash, mainConfig.Redis_expire_seconds, payload.client.data)
-			if do_err == nil {
-				payload.client.data = ""
-				body = "redis"
-			}
-		} else {
-			payload.server.logln(1, fmt.Sprintf("redis: %v", redis_err))
-		}
-		// bind data to cursor
-		ins.Bind(
-			to,
-			payload.client.mail_from,
-			payload.client.subject,
-			body,
-			payload.client.data,
-			payload.client.hash,
-			recipient,
-			payload.client.address,
-			payload.client.mail_from,
-			payload.client.tls_on,
-		)
-		// save, discard result
-		_, _, err = ins.Exec()
-		if err != nil {
-			payload.server.logln(1, fmt.Sprintf("Database error, %v ", err))
-			payload.client.savedNotify <- -1
-		} else {
-			payload.server.logln(0, "Email saved "+payload.client.hash+" len:"+strconv.Itoa(length))
-			_, _, err = incr.Exec()
-			if err != nil {
-				payload.server.logln(1, fmt.Sprintf("Failed to incr count: %v", err))
-			}
-			payload.client.savedNotify <- 1
-		}
-	}
-}
-
-func (c *redisClient) redisConnection() (err error) {
-
-	if c.count == 0 {
-		c.conn, err = redis.Dial("tcp", mainConfig.Redis_interface)
-		if err != nil {
-			// handle error
-			return err
-		}
-	}
-	return nil
-}
-
-// test database connection settings
-func testDbConnections() (err error) {
-
-	db := autorc.New(
-		"tcp",
-		"",
-		mainConfig.Mysql_host,
-		mainConfig.Mysql_user,
-		mainConfig.Mysql_pass,
-		mainConfig.Mysql_db)
-
-	if mysql_err := db.Raw.Connect(); mysql_err != nil {
-		err = errors.New("MySql cannot connect, check your settings. " + mysql_err.Error() )
-	} else {
-		db.Raw.Close();
-	}
-
-	redisClient := &redisClient{}
-	if redis_err := redisClient.redisConnection(); redis_err != nil {
-		err = errors.New("Redis cannot connect, check your settings. " + redis_err.Error())
-	}
-
-	return
-}

+ 83 - 0
server/goguerrilla.go

@@ -0,0 +1,83 @@
+/**
+Go-Guerrilla SMTPd
+
+Version: 1.5
+Author: Flashmob, GuerrillaMail.com
+Contact: [email protected]
+License: MIT
+Repository: https://github.com/flashmob/Go-Guerrilla-SMTPd
+Site: http://www.guerrillamail.com/
+
+See README for more details
+*/
+
+package server
+
+import (
+	"bufio"
+	"crypto/rand"
+	"crypto/tls"
+	"fmt"
+	"net"
+	"runtime"
+	"time"
+
+	log "github.com/Sirupsen/logrus"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+)
+
+func RunServer(sConfig guerrilla.ServerConfig, backend guerrilla.Backend, allowedHostsStr string) (err error) {
+	server := SmtpdServer{
+		Config:          sConfig,
+		sem:             make(chan int, sConfig.MaxClients),
+		allowedHostsStr: allowedHostsStr,
+	}
+
+	// configure ssl
+	if sConfig.TLSAlwaysOn || sConfig.StartTLS {
+		cert, err := tls.LoadX509KeyPair(sConfig.PublicKeyFile, sConfig.PrivateKeyFile)
+		if err != nil {
+			return fmt.Errorf("error while loading the certificate: %s", err)
+		}
+		server.tlsConfig = &tls.Config{
+			Certificates: []tls.Certificate{cert},
+			ClientAuth:   tls.VerifyClientCertIfGiven,
+			ServerName:   sConfig.Hostname,
+		}
+		server.tlsConfig.Rand = rand.Reader
+	}
+
+	// configure timeout
+	server.timeout = time.Duration(sConfig.Timeout)
+
+	// Start listening for SMTP connections
+	listener, err := net.Listen("tcp", sConfig.ListenInterface)
+	if err != nil {
+		return fmt.Errorf("cannot listen on port, %v", err)
+	}
+
+	log.Infof("Listening on tcp %s", sConfig.ListenInterface)
+
+	var clientID int64
+	clientID = 1
+	for {
+		conn, err := listener.Accept()
+		if err != nil {
+			log.WithError(err).Infof("Accept error")
+			continue
+		}
+		log.Debugf("Number of serving goroutines: %d", runtime.NumGoroutine())
+		server.sem <- 1 // Wait for active queue to drain.
+		go server.handleClient(&guerrilla.Client{
+			Conn:        conn,
+			Address:     conn.RemoteAddr().String(),
+			Time:        time.Now().Unix(),
+			Bufin:       guerrilla.NewSMTPBufferedReader(conn),
+			Bufout:      bufio.NewWriter(conn),
+			ClientID:    clientID,
+			SavedNotify: make(chan int),
+		}, backend)
+		clientID++
+	}
+}

+ 273 - 0
server/smtpd.go

@@ -0,0 +1,273 @@
+package server
+
+import (
+	"bufio"
+	"crypto/tls"
+	"fmt"
+	"io"
+	"net"
+	"strings"
+	"time"
+
+	log "github.com/Sirupsen/logrus"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
+	"github.com/flashmob/go-guerrilla/util"
+)
+
+type SmtpdServer struct {
+	tlsConfig       *tls.Config
+	maxSize         int // max email DATA size
+	timeout         time.Duration
+	sem             chan int // currently active client list
+	Config          guerrilla.ServerConfig
+	allowedHostsStr string
+}
+
+// Upgrades the connection to TLS
+// Sets up buffers with the upgraded connection
+func (server *SmtpdServer) upgradeToTls(client *guerrilla.Client) bool {
+	var tlsConn *tls.Conn
+	tlsConn = tls.Server(client.Conn, server.tlsConfig)
+	err := tlsConn.Handshake()
+	if err == nil {
+		client.Conn = net.Conn(tlsConn)
+		client.Bufin = guerrilla.NewSMTPBufferedReader(client.Conn)
+		client.Bufout = bufio.NewWriter(client.Conn)
+		client.TLS = true
+
+		return true
+	}
+
+	log.WithError(err).Warn("Failed to TLS handshake")
+	return false
+}
+
+func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerrilla.Backend) {
+	defer server.closeClient(client)
+	advertiseTLS := "250-STARTTLS\r\n"
+	if server.Config.TLSAlwaysOn {
+		if server.upgradeToTls(client) {
+			advertiseTLS = ""
+		}
+	}
+	greeting := fmt.Sprintf("220 %s SMTP guerrillad(%s) #%d (%d) %s",
+		server.Config.Hostname, guerrilla.Version, client.ClientID,
+		len(server.sem), time.Now().Format(time.RFC1123Z))
+
+	if !server.Config.StartTLS {
+		// STARTTLS turned off
+		advertiseTLS = ""
+	}
+	for i := 0; i < 100; i++ {
+		switch client.State {
+		case 0:
+			responseAdd(client, greeting)
+			client.State = 1
+		case 1:
+			client.Bufin.SetLimit(guerrilla.CommandMaxLength)
+			input, err := server.readSmtp(client)
+			if err != nil {
+				if err == io.EOF {
+					log.WithError(err).Debugf("Client closed the connection already: %s", client.Address)
+					return
+				} else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+					log.WithError(err).Debugf("Timeout: %s", client.Address)
+					return
+				} else if err == guerrilla.InputLimitExceeded {
+					responseAdd(client, "500 Line too long.")
+					// kill it so that another one can connect
+					killClient(client)
+				}
+				log.WithError(err).Warnf("Read error: %s", client.Address)
+				break
+			}
+			input = strings.Trim(input, " \n\r")
+			bound := len(input)
+			if bound > 16 {
+				bound = 16
+			}
+			cmd := strings.ToUpper(input[0:bound])
+			switch {
+			case strings.Index(cmd, "HELO") == 0:
+				if len(input) > 5 {
+					client.Helo = input[5:]
+				}
+				responseAdd(client, "250 "+server.Config.Hostname+" Hello ")
+			case strings.Index(cmd, "EHLO") == 0:
+				if len(input) > 5 {
+					client.Helo = input[5:]
+				}
+				responseAdd(client, fmt.Sprintf(
+					`250-%s Hello %s[%s]\r
+250-SIZE %d\r
+250-PIPELINING \r
+%s250 HELP`,
+					server.Config.Hostname, client.Helo, client.Address,
+					server.Config.MaxSize, advertiseTLS))
+			case strings.Index(cmd, "HELP") == 0:
+				responseAdd(client, "250 Help! I need somebody...")
+			case strings.Index(cmd, "MAIL FROM:") == 0:
+				if len(input) > 10 {
+					client.MailFrom = input[10:]
+				}
+				responseAdd(client, "250 Ok")
+			case strings.Index(cmd, "XCLIENT") == 0:
+				// Nginx sends this
+				// XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]
+				client.Address = input[13:]
+				client.Address = client.Address[0:strings.Index(client.Address, " ")]
+				fmt.Println("client address:[" + client.Address + "]")
+				responseAdd(client, "250 OK")
+			case strings.Index(cmd, "RCPT TO:") == 0:
+				if len(input) > 8 {
+					client.RcptTo = input[8:]
+				}
+				responseAdd(client, "250 Accepted")
+			case strings.Index(cmd, "NOOP") == 0:
+				responseAdd(client, "250 OK")
+			case strings.Index(cmd, "RSET") == 0:
+				client.MailFrom = ""
+				client.RcptTo = ""
+				responseAdd(client, "250 OK")
+			case strings.Index(cmd, "DATA") == 0:
+				responseAdd(client, "354 Enter message, ending with \".\" on a line by itself")
+				client.State = 2
+			case (strings.Index(cmd, "STARTTLS") == 0) &&
+				!client.TLS &&
+				server.Config.StartTLS:
+				responseAdd(client, "220 Ready to start TLS")
+				// go to start TLS state
+				client.State = 3
+			case strings.Index(cmd, "QUIT") == 0:
+				responseAdd(client, "221 Bye")
+				killClient(client)
+			default:
+				responseAdd(client, "500 unrecognized command: "+cmd)
+				client.Errors++
+				if client.Errors > 3 {
+					responseAdd(client, "500 Too many unrecognized commands")
+					killClient(client)
+				}
+			}
+		case 2:
+			var err error
+			client.Bufin.SetLimit(int64(server.Config.MaxSize) + 1024000) // This is a hard limit.
+			client.Data, err = server.readSmtp(client)
+			if err == nil {
+				if user, host, mailErr := util.ValidateEmailData(client, server.allowedHostsStr); mailErr == nil {
+					resp := backend.Process(client, user, host)
+					responseAdd(client, resp)
+				} else {
+					responseAdd(client, "550 Error: "+mailErr.Error())
+				}
+
+			} else {
+				if err == guerrilla.InputLimitExceeded {
+					// hard limit reached, end to make room for other clients
+					responseAdd(client, "550 Error: DATA limit exceeded by more than a megabyte!")
+					killClient(client)
+				} else {
+					responseAdd(client, "550 Error: "+err.Error())
+				}
+
+				log.WithError(err).Warn("DATA read error")
+			}
+			client.State = 1
+		case 3:
+			// upgrade to TLS
+			if server.upgradeToTls(client) {
+				advertiseTLS = ""
+				client.State = 1
+			}
+		}
+		// Send a response back to the client
+		err := server.responseWrite(client)
+		if err != nil {
+			if err == io.EOF {
+				// client closed the connection already
+				return
+			}
+			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+				// too slow, timeout
+				return
+			}
+		}
+		if client.KillTime > 1 {
+			return
+		}
+	}
+
+}
+
+// add a response on the response buffer
+func responseAdd(client *guerrilla.Client, line string) {
+	client.Response = line + "\r\n"
+}
+func (server *SmtpdServer) closeClient(client *guerrilla.Client) {
+	client.Conn.Close()
+	<-server.sem // Done; enable next client to run.
+}
+func killClient(client *guerrilla.Client) {
+	client.KillTime = time.Now().Unix()
+}
+
+// Reads from the smtpBufferedReader, can be in command state or data state.
+func (server *SmtpdServer) readSmtp(client *guerrilla.Client) (input string, err error) {
+	var reply string
+	// Command state terminator by default
+	suffix := "\r\n"
+	if client.State == 2 {
+		// DATA state ends with a dot on a line by itself
+		suffix = "\r\n.\r\n"
+	}
+	for err == nil {
+		client.Conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
+		reply, err = client.Bufin.ReadString('\n')
+		if reply != "" {
+			input = input + reply
+			if len(input) > server.Config.MaxSize {
+				err = fmt.Errorf("Maximum DATA size exceeded (%d)", server.Config.MaxSize)
+				return input, err
+			}
+			if client.State == 2 {
+				// Extract the subject while we are at it.
+				scanSubject(client, reply)
+			}
+		}
+		if err != nil {
+			break
+		}
+		if strings.HasSuffix(input, suffix) {
+			break
+		}
+	}
+	return input, err
+}
+
+// Scan the data part for a Subject line. Can be a multi-line
+func scanSubject(client *guerrilla.Client, reply string) {
+	if client.Subject == "" && (len(reply) > 8) {
+		test := strings.ToUpper(reply[0:9])
+		if i := strings.Index(test, "SUBJECT: "); i == 0 {
+			// first line with \r\n
+			client.Subject = reply[9:]
+		}
+	} else if strings.HasSuffix(client.Subject, "\r\n") {
+		// chop off the \r\n
+		client.Subject = client.Subject[0 : len(client.Subject)-2]
+		if (strings.HasPrefix(reply, " ")) || (strings.HasPrefix(reply, "\t")) {
+			// subject is multi-line
+			client.Subject = client.Subject + reply[1:]
+		}
+	}
+}
+
+func (server *SmtpdServer) responseWrite(client *guerrilla.Client) (err error) {
+	var size int
+	client.Conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
+	size, err = client.Bufout.WriteString(client.Response)
+	client.Bufout.Flush()
+	client.Response = client.Response[size:]
+	return err
+}

+ 0 - 396
smtpd.go

@@ -1,396 +0,0 @@
-package main
-
-import (
-	"bufio"
-	"bytes"
-	"crypto/tls"
-	"errors"
-	"fmt"
-	"io"
-	"log"
-	"net"
-	"os"
-	"strconv"
-	"strings"
-	"time"
-)
-
-const commandMaxLength = 1024
-
-type Client struct {
-	state       int
-	helo        string
-	mail_from   string
-	rcpt_to     string
-	response    string
-	address     string
-	data        string
-	subject     string
-	hash        string
-	time        int64
-	tls_on      bool
-	conn        net.Conn
-	bufin       *smtpBufferedReader
-	bufout      *bufio.Writer
-	kill_time   int64
-	errors      int
-	clientId    int64
-	savedNotify chan int
-}
-
-type SmtpdServer struct {
-	tlsConfig    *tls.Config
-	max_size     int // max email DATA size
-	timeout      time.Duration
-	allowedHosts map[string]bool
-	sem          chan int // currently active client list
-	Config       ServerConfig
-	logger       *log.Logger
-}
-
-func (server *SmtpdServer) logln(level int, s string) {
-
-	if mainConfig.Verbose {
-		fmt.Println(s)
-	}
-	// fatal errors
-	if level == 2 {
-		server.logger.Fatalf(s)
-	}
-	// warnings
-	if level == 1 && len(server.Config.Log_file) > 0 {
-		server.logger.Println(s)
-	}
-
-}
-
-func (server *SmtpdServer) openLog() {
-
-	server.logger = log.New(&bytes.Buffer{}, "", log.Lshortfile)
-	// custom log file
-	if len(server.Config.Log_file) > 0 {
-		logfile, err := os.OpenFile(
-			server.Config.Log_file,
-			os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_SYNC, 0600)
-		if err != nil {
-			server.logln(1, fmt.Sprintf("Unable to open log file [%s]: %s ", server.Config.Log_file, err))
-		}
-		server.logger.SetOutput(logfile)
-	}
-}
-
-// Upgrades the connection to TLS
-// Sets up buffers with the upgraded connection
-func (server *SmtpdServer) upgradeToTls(client *Client) bool {
-	var tlsConn *tls.Conn
-	tlsConn = tls.Server(client.conn, server.tlsConfig)
-	err := tlsConn.Handshake()
-	if err == nil {
-		client.conn = net.Conn(tlsConn)
-		client.bufin = newSmtpBufferedReader(client.conn)
-		client.bufout = bufio.NewWriter(client.conn)
-		client.tls_on = true
-
-		return true
-	} else {
-		server.logln(1, fmt.Sprintf("Could not TLS handshake:%v", err))
-		return false
-	}
-
-}
-
-func (server *SmtpdServer) handleClient(client *Client) {
-	defer server.closeClient(client)
-	advertiseTls := "250-STARTTLS\r\n"
-	if server.Config.Tls_always_on {
-		if server.upgradeToTls(client) {
-			advertiseTls = ""
-		}
-	}
-	greeting := "220 " + server.Config.Host_name +
-		" SMTP Guerrilla-SMTPd #" +
-		strconv.FormatInt(client.clientId, 10) +
-		" (" + strconv.Itoa(len(server.sem)) + ") " + time.Now().Format(time.RFC1123Z)
-
-	if !server.Config.Start_tls_on {
-		// STARTTLS turned off
-		advertiseTls = ""
-	}
-	for i := 0; i < 100; i++ {
-		switch client.state {
-		case 0:
-			responseAdd(client, greeting)
-			client.state = 1
-		case 1:
-			client.bufin.setLimit(commandMaxLength)
-			input, err := server.readSmtp(client)
-			if err != nil {
-				if err == io.EOF {
-					// client closed the connection already
-					server.logln(0, fmt.Sprintf("%s: %v", client.address, err))
-					return
-				} else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
-					// too slow, timeout
-					server.logln(0, fmt.Sprintf("%s: %v", client.address, err))
-					return
-				} else if err == INPUT_LIMIT_EXCEEDED {
-					responseAdd(client, "500 Line too long.")
-					// kill it so that another one can connect
-					killClient(client)
-				}
-				server.logln(1, fmt.Sprintf("Read error: %v", err))
-				break
-			}
-			input = strings.Trim(input, " \n\r")
-			bound := len(input)
-			if bound > 16 {
-				bound = 16
-			}
-			cmd := strings.ToUpper(input[0:bound])
-			switch {
-			case strings.Index(cmd, "HELO") == 0:
-				if len(input) > 5 {
-					client.helo = input[5:]
-				}
-				responseAdd(client, "250 "+server.Config.Host_name+" Hello ")
-			case strings.Index(cmd, "EHLO") == 0:
-				if len(input) > 5 {
-					client.helo = input[5:]
-				}
-				responseAdd(client, "250-"+server.Config.Host_name+
-					" Hello "+client.helo+"["+client.address+"]"+"\r\n"+
-					"250-SIZE "+strconv.Itoa(server.Config.Max_size)+"\r\n"+
-					"250-PIPELINING \r\n"+
-					advertiseTls+"250 HELP")
-			case strings.Index(cmd, "HELP") == 0:
-				responseAdd(client, "250 Help! I need somebody...")
-			case strings.Index(cmd, "MAIL FROM:") == 0:
-				if len(input) > 10 {
-					client.mail_from = input[10:]
-				}
-				responseAdd(client, "250 Ok")
-			case strings.Index(cmd, "XCLIENT") == 0:
-				// Nginx sends this
-				// XCLIENT ADDR=212.96.64.216 NAME=[UNAVAILABLE]
-				client.address = input[13:]
-				client.address = client.address[0:strings.Index(client.address, " ")]
-				fmt.Println("client address:[" + client.address + "]")
-				responseAdd(client, "250 OK")
-			case strings.Index(cmd, "RCPT TO:") == 0:
-				if len(input) > 8 {
-					client.rcpt_to = input[8:]
-				}
-				responseAdd(client, "250 Accepted")
-			case strings.Index(cmd, "NOOP") == 0:
-				responseAdd(client, "250 OK")
-			case strings.Index(cmd, "RSET") == 0:
-				client.mail_from = ""
-				client.rcpt_to = ""
-				responseAdd(client, "250 OK")
-			case strings.Index(cmd, "DATA") == 0:
-				responseAdd(client, "354 Enter message, ending with \".\" on a line by itself")
-				client.state = 2
-			case (strings.Index(cmd, "STARTTLS") == 0) &&
-				!client.tls_on &&
-				server.Config.Start_tls_on:
-				responseAdd(client, "220 Ready to start TLS")
-				// go to start TLS state
-				client.state = 3
-			case strings.Index(cmd, "QUIT") == 0:
-				responseAdd(client, "221 Bye")
-				killClient(client)
-			default:
-				responseAdd(client, "500 unrecognized command: "+cmd)
-				client.errors++
-				if client.errors > 3 {
-					responseAdd(client, "500 Too many unrecognized commands")
-					killClient(client)
-				}
-			}
-		case 2:
-			var err error
-			client.bufin.setLimit(int64(server.Config.Max_size) + 1024000) // This is a hard limit.
-			client.data, err = server.readSmtp(client)
-			if err == nil {
-				if _, _, mailErr := validateEmailData(client); mailErr == nil {
-					// to do: timeout when adding to SaveMailChan
-					// place on the channel so that one of the save mail workers can pick it up
-					SaveMailChan <- &savePayload{client: client, server: server}
-					// wait for the save to complete
-					// or timeout
-					select {
-					case status := <-client.savedNotify:
-						if status == 1 {
-							responseAdd(client, "250 OK : queued as "+client.hash)
-						} else {
-							responseAdd(client, "554 Error: transaction failed, blame it on the weather")
-						}
-					case <-time.After(time.Second * 30):
-						fmt.Println("timeout 1")
-						responseAdd(client, "554 Error: transaction timeout")
-					}
-
-				} else {
-					responseAdd(client, "550 Error: "+mailErr.Error())
-				}
-
-			} else {
-				if (err == INPUT_LIMIT_EXCEEDED) {
-					// hard limit reached, end to make room for other clients
-					responseAdd(client, "550 Error: DATA limit exceeded by more than a megabyte!")
-					killClient(client)
-				} else {
-					responseAdd(client, "550 Error: "+err.Error())
-				}
-
-				server.logln(1, fmt.Sprintf("DATA read error: %v", err))
-			}
-			client.state = 1
-		case 3:
-			// upgrade to TLS
-			if server.upgradeToTls(client) {
-				advertiseTls = ""
-				client.state = 1
-			}
-		}
-		// Send a response back to the client
-		err := server.responseWrite(client)
-		if err != nil {
-			if err == io.EOF {
-				// client closed the connection already
-				return
-			}
-			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
-				// too slow, timeout
-				return
-			}
-		}
-		if client.kill_time > 1 {
-			return
-		}
-	}
-
-}
-
-// add a response on the response buffer
-func responseAdd(client *Client, line string) {
-	client.response = line + "\r\n"
-}
-func (server *SmtpdServer) closeClient(client *Client) {
-	client.conn.Close()
-	<-server.sem // Done; enable next client to run.
-}
-func killClient(client *Client) {
-	client.kill_time = time.Now().Unix()
-}
-
-var INPUT_LIMIT_EXCEEDED = errors.New("Line too long") // 500 Line too long.
-
-// we need to adjust the limit, so we embed io.LimitedReader
-type adjustableLimitedReader struct {
-	R *io.LimitedReader
-}
-
-// bolt this on so we can adjust the limit
-func (alr *adjustableLimitedReader) setLimit(n int64) {
-	alr.R.N = n
-}
-
-// this just delegates to the underlying reader in order to satisfy the Reader interface
-// Since the vanilla limited reader returns io.EOF when the limit is reached, we need a more specific
-// error so that we can distinguish when a limit is reached
-func (alr *adjustableLimitedReader) Read(p []byte) (n int, err error) {
-	n, err = alr.R.Read(p)
-	if err == io.EOF && alr.R.N <= 0 {
-		// return our custom error since std lib returns EOF
-		err = INPUT_LIMIT_EXCEEDED
-	}
-	return
-}
-
-// allocate a new adjustableLimitedReader
-func newAdjustableLimitedReader(r io.Reader, n int64) *adjustableLimitedReader {
-	lr := &io.LimitedReader{R: r, N: n}
-	return &adjustableLimitedReader{lr}
-}
-
-// This is a bufio.Reader what will use our adjustable limit reader
-// We 'extend' buffio to have the limited reader feature
-type smtpBufferedReader struct {
-	*bufio.Reader
-	alr *adjustableLimitedReader
-}
-
-
-// delegate to the adjustable limited reader
-func (sbr *smtpBufferedReader) setLimit(n int64) {
-	sbr.alr.setLimit(n)
-}
-
-
-// allocate a new smtpBufferedReader
-func newSmtpBufferedReader(rd io.Reader) *smtpBufferedReader {
-	alr := newAdjustableLimitedReader(rd, commandMaxLength)
-	s := &smtpBufferedReader{bufio.NewReader(alr), alr}
-	return s
-}
-
-// Reads from the smtpBufferedReader, can be in command state or data state.
-func (server *SmtpdServer) readSmtp(client *Client) (input string, err error) {
-	var reply string
-	// Command state terminator by default
-	suffix := "\r\n"
-	if client.state == 2 {
-		// DATA state ends with a dot on a line by itself
-		suffix = "\r\n.\r\n"
-	}
-	for err == nil {
-		client.conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
-		reply, err = client.bufin.ReadString('\n')
-		if reply != "" {
-			input = input + reply
-			if len(input) > server.Config.Max_size {
-				err = errors.New("Maximum DATA size exceeded (" + strconv.Itoa(server.Config.Max_size) + ")")
-				return input, err
-			}
-			if client.state == 2 {
-				// Extract the subject while we are at it.
-				scanSubject(client, reply)
-			}
-		}
-		if err != nil {
-			break
-		}
-		if strings.HasSuffix(input, suffix) {
-			break
-		}
-	}
-	return input, err
-}
-
-// Scan the data part for a Subject line. Can be a multi-line
-func scanSubject(client *Client, reply string) {
-	if client.subject == "" && (len(reply) > 8) {
-		test := strings.ToUpper(reply[0:9])
-		if i := strings.Index(test, "SUBJECT: "); i == 0 {
-			// first line with \r\n
-			client.subject = reply[9:]
-		}
-	} else if strings.HasSuffix(client.subject, "\r\n") {
-		// chop off the \r\n
-		client.subject = client.subject[0 : len(client.subject)-2]
-		if (strings.HasPrefix(reply, " ")) || (strings.HasPrefix(reply, "\t")) {
-			// subject is multi-line
-			client.subject = client.subject + reply[1:]
-		}
-	}
-}
-
-func (server *SmtpdServer) responseWrite(client *Client) (err error) {
-	var size int
-	client.conn.SetDeadline(time.Now().Add(server.timeout * time.Second))
-	size, err = client.bufout.WriteString(client.response)
-	client.bufout.Flush()
-	client.response = client.response[size:]
-	return err
-}

+ 46 - 22
util.go → util/util.go

@@ -1,4 +1,4 @@
-package main
+package util
 
 import (
 	"bytes"
@@ -6,24 +6,45 @@ import (
 	"crypto/md5"
 	"encoding/base64"
 	"errors"
-	"github.com/sloonz/go-qprintable"
-	"gopkg.in/iconv.v1"
+	"fmt"
+	"io"
 	"io/ioutil"
 	"regexp"
 	"strings"
-	"io"
-	"fmt"
+
+	"gopkg.in/iconv.v1"
+
+	"github.com/sloonz/go-qprintable"
+
+	guerrilla "github.com/flashmob/go-guerrilla"
 )
 
-func validateEmailData(client *Client) (user string, host string, addr_err error) {
-	if user, host, addr_err = extractEmail(client.mail_from); addr_err != nil {
+var allowedHosts map[string]bool
+
+// map the allow hosts for easy lookup
+func prepareAllowedHosts(allowedHostsStr string) {
+	allowedHosts = make(map[string]bool, 15)
+	if arr := strings.Split(allowedHostsStr, ","); len(arr) > 0 {
+		for i := 0; i < len(arr); i++ {
+			allowedHosts[arr[i]] = true
+		}
+	}
+}
+
+// TODO: cleanup
+func ValidateEmailData(client *guerrilla.Client, allowedHostsStr string) (user string, host string, addr_err error) {
+	if allowedHosts == nil {
+		prepareAllowedHosts(allowedHostsStr)
+	}
+
+	if user, host, addr_err = extractEmail(client.MailFrom); addr_err != nil {
 		return user, host, addr_err
 	}
-	client.mail_from = user + "@" + host
-	if user, host, addr_err = extractEmail(client.rcpt_to); addr_err != nil {
+	client.MailFrom = user + "@" + host
+	if user, host, addr_err = extractEmail(client.RcptTo); addr_err != nil {
 		return user, host, addr_err
 	}
-	client.rcpt_to = user + "@" + host
+	client.RcptTo = user + "@" + host
 	// check if on allowed hosts
 	if allowed := allowedHosts[strings.ToLower(host)]; !allowed {
 		return user, host, errors.New("invalid host:" + host)
@@ -48,10 +69,12 @@ func extractEmail(str string) (name string, host string, err error) {
 	}
 	return name, host, err
 }
+
 var mimeRegex, _ = regexp.Compile(`=\?(.+?)\?([QBqp])\?(.+?)\?=`)
+
 // Decode strings in Mime header format
 // eg. =?ISO-2022-JP?B?GyRCIVo9dztSOWJAOCVBJWMbKEI=?=
-func mimeHeaderDecode(str string) string {
+func MimeHeaderDecode(str string) string {
 
 	matched := mimeRegex.FindAllStringSubmatch(str, -1)
 	var charset, encoding, payload string
@@ -66,13 +89,13 @@ func mimeHeaderDecode(str string) string {
 					str = strings.Replace(
 						str,
 						matched[i][0],
-						mailTransportDecode(payload, "base64", charset),
+						MailTransportDecode(payload, "base64", charset),
 						1)
 				case "Q":
 					str = strings.Replace(
 						str,
 						matched[i][0],
-						mailTransportDecode(payload, "quoted-printable", charset),
+						MailTransportDecode(payload, "quoted-printable", charset),
 						1)
 				}
 			}
@@ -82,6 +105,7 @@ func mimeHeaderDecode(str string) string {
 }
 
 var valihostRegex, _ = regexp.Compile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
+
 func validHost(host string) string {
 	host = strings.Trim(host, " ")
 	if valihostRegex.MatchString(host) {
@@ -91,16 +115,16 @@ func validHost(host string) string {
 }
 
 // decode from 7bit to 8bit UTF-8
-// encoding_type can be "base64" or "quoted-printable"
-func mailTransportDecode(str string, encoding_type string, charset string) string {
+// encodingType can be "base64" or "quoted-printable"
+func MailTransportDecode(str string, encodingType string, charset string) string {
 	if charset == "" {
 		charset = "UTF-8"
 	} else {
 		charset = strings.ToUpper(charset)
 	}
-	if encoding_type == "base64" {
+	if encodingType == "base64" {
 		str = fromBase64(str)
-	} else if encoding_type == "quoted-printable" {
+	} else if encodingType == "quoted-printable" {
 		str = fromQuotedP(str)
 	}
 
@@ -135,8 +159,8 @@ func fromQuotedP(data string) string {
 	return string(res)
 }
 
-
 var charsetRegex, _ = regexp.Compile(`[_:.\/\\]`)
+
 func fixCharset(charset string) string {
 	fixed_charset := charsetRegex.ReplaceAllString(charset, "-")
 	// Fix charset
@@ -160,10 +184,10 @@ func fixCharset(charset string) string {
 }
 
 // returns an md5 hash as string of hex characters
-func md5hex(stringArguments ...*string) string {
+func MD5Hex(stringArguments ...*string) string {
 	h := md5.New()
 	var r *strings.Reader
-	for i:=0; i < len(stringArguments); i++ {
+	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(*stringArguments[i])
 		io.Copy(h, r)
 	}
@@ -172,11 +196,11 @@ func md5hex(stringArguments ...*string) string {
 }
 
 // concatenate & compress all strings  passed in
-func compress(stringArguments ...*string) string {
+func Compress(stringArguments ...*string) string {
 	var b bytes.Buffer
 	var r *strings.Reader
 	w, _ := zlib.NewWriterLevel(&b, zlib.BestSpeed)
-	for i:=0; i < len(stringArguments); i++ {
+	for i := 0; i < len(stringArguments); i++ {
 		r = strings.NewReader(*stringArguments[i])
 		io.Copy(w, r)
 	}

+ 27 - 0
version.go

@@ -0,0 +1,27 @@
+package guerrilla
+
+import "time"
+
+var (
+	Version   string
+	Commit    string
+	BuildTime string
+
+	StartTime      time.Time
+	ConfigLoadTime time.Time
+)
+
+func init() {
+	// If version, commit, or build time are not set, make that clear.
+	if Version == "" {
+		Version = "unknown"
+	}
+	if Commit == "" {
+		Commit = "unknown"
+	}
+	if BuildTime == "" {
+		BuildTime = "unknown"
+	}
+
+	StartTime = time.Now()
+}