Parcourir la source

Merge branch 'repackaging' of https://github.com/remohammadi/go-guerrilla into remohammadi-repackaging

flashmob il y a 9 ans
Parent
commit
38d37571d6
8 fichiers modifiés avec 161 ajouts et 87 suppressions
  1. 11 3
      backends/dummy.go
  2. 72 14
      backends/guerrilla_db_redis.go
  3. 1 1
      cmd/guerrillad/serve.go
  4. 15 0
      config.go
  5. 13 2
      models.go
  6. 4 4
      server/goguerrilla.go
  7. 34 20
      server/smtpd.go
  8. 11 43
      util/util.go

+ 11 - 3
backends/dummy.go

@@ -20,7 +20,7 @@ type dummyConfig struct {
 	LogReceivedMails bool `json:"log_received_mails"`
 }
 
-func (b *DummyBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
+func (b *DummyBackend) loadConfig(backendConfig guerrilla.BackendConfig) error {
 	var converted bool
 	b.config.LogReceivedMails, converted = backendConfig["log_received_mails"].(bool)
 	if !converted {
@@ -29,9 +29,17 @@ func (b *DummyBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
 	return nil
 }
 
-func (b *DummyBackend) Process(client *guerrilla.Client, user, host string) string {
+func (b *DummyBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
+	return b.loadConfig(backendConfig)
+}
+
+func (b *DummyBackend) Finalize() error {
+	return nil
+}
+
+func (b *DummyBackend) Process(client *guerrilla.Client, from *guerrilla.EmailParts, to []*guerrilla.EmailParts) string {
 	if b.config.LogReceivedMails {
-		log.Infof("Mail from: %s@%s", user, host)
+		log.Infof("Mail from: %s / to: %v", from, to)
 	}
 	return fmt.Sprintf("250 OK : queued as %s", client.Hash)
 }

+ 72 - 14
backends/guerrilla_db_redis.go

@@ -2,6 +2,7 @@ package backends
 
 import (
 	"fmt"
+	"sync"
 	"time"
 
 	log "github.com/Sirupsen/logrus"
@@ -19,7 +20,9 @@ func init() {
 }
 
 type GuerrillaDBAndRedisBackend struct {
-	config guerrillaDBAndRedisConfig
+	config       guerrillaDBAndRedisConfig
+	saveMailChan chan *savePayload
+	wg           *sync.WaitGroup
 }
 
 type guerrillaDBAndRedisConfig struct {
@@ -34,16 +37,58 @@ type guerrillaDBAndRedisConfig struct {
 	PrimaryHost        string `json:"primary_mail_host"`
 }
 
+func convertError(name string) error {
+	return fmt.Errorf("failed to load backend config field (%s)", name)
+}
+
+func (g *GuerrillaDBAndRedisBackend) loadConfig(backendConfig guerrilla.BackendConfig) error {
+	var converted bool
+
+	if g.config.NumberOfWorkers, converted = backendConfig["save_workers_size"].(int); !converted {
+		return convertError("save_workers_size")
+	}
+	if g.config.MysqlTable, converted = backendConfig["mail_table"].(string); !converted {
+		return convertError("mail_table")
+	}
+	if g.config.MysqlDB, converted = backendConfig["mysql_db"].(string); !converted {
+		return convertError("mysql_db")
+	}
+	if g.config.MysqlHost, converted = backendConfig["mysql_host"].(string); !converted {
+		return convertError("mysql_host")
+	}
+	if g.config.MysqlPass, converted = backendConfig["mysql_pass"].(string); !converted {
+		return convertError("mysql_pass")
+	}
+	if g.config.MysqlUser, converted = backendConfig["mysql_user"].(string); !converted {
+		return convertError("mysql_user")
+	}
+	if g.config.RedisExpireSeconds, converted = backendConfig["redis_expire_seconds"].(int); !converted {
+		return convertError("redis_expire_seconds")
+	}
+	if g.config.RedisInterface, converted = backendConfig["redis_interface"].(string); !converted {
+		return convertError("redis_interface")
+	}
+	if g.config.PrimaryHost, converted = backendConfig["primary_mail_host"].(string); !converted {
+		return convertError("primary_mail_host")
+	}
+
+	return nil
+}
+
 func (g *GuerrillaDBAndRedisBackend) Initialize(backendConfig guerrilla.BackendConfig) error {
-	// TODO: load config
+	err := g.loadConfig(backendConfig)
+	if err != nil {
+		return err
+	}
 
 	if err := g.testDbConnections(); err != nil {
 		return err
 	}
 
-	SaveMailChan = make(chan *savePayload, g.config.NumberOfWorkers)
+	g.saveMailChan = make(chan *savePayload, g.config.NumberOfWorkers)
 
 	// start some savemail workers
+	g.wg.Add(g.config.NumberOfWorkers)
 	for i := 0; i < g.config.NumberOfWorkers; i++ {
 		go g.saveMail()
 	}
@@ -51,10 +96,21 @@ func (g *GuerrillaDBAndRedisBackend) Initialize(backendConfig guerrilla.BackendC
 	return nil
 }
 
-func (g *GuerrillaDBAndRedisBackend) Process(client *guerrilla.Client, user, host string) string {
+func (g *GuerrillaDBAndRedisBackend) Finalize() error {
+	close(g.saveMailChan)
+	g.wg.Wait()
+	return nil
+}
+
+func (g *GuerrillaDBAndRedisBackend) Process(client *guerrilla.Client, from *guerrilla.EmailParts, to []*guerrilla.EmailParts) string {
+	if len(to) == 0 {
+		return "554 Error: no recipient"
+	}
+
 	// to do: timeout when adding to SaveMailChan
 	// place on the channel so that one of the save mail workers can pick it up
-	SaveMailChan <- &savePayload{client: client, user: user, host: host}
+	// TODO: support multiple recipients
+	g.saveMailChan <- &savePayload{client: client, from: from, recipient: to[0]}
 	// wait for the save to complete
 	// or timeout
 	select {
@@ -70,13 +126,11 @@ func (g *GuerrillaDBAndRedisBackend) Process(client *guerrilla.Client, user, hos
 }
 
 type savePayload struct {
-	client *guerrilla.Client
-	user   string
-	host   string
+	client    *guerrilla.Client
+	from      *guerrilla.EmailParts
+	recipient *guerrilla.EmailParts
 }
 
-var SaveMailChan chan *savePayload
-
 type redisClient struct {
 	count int
 	conn  redis.Conn
@@ -113,9 +167,13 @@ func (g *GuerrillaDBAndRedisBackend) saveMail() {
 
 	//  receives values from the channel repeatedly until it is closed.
 	for {
-		payload := <-SaveMailChan
-		recipient = payload.user + "@" + payload.host
-		to = payload.user + "@" + g.config.PrimaryHost
+		payload := <-g.saveMailChan
+		if payload == nil {
+			log.Debug("No more payload")
+			g.wg.Done()
+			return
+		}
+		to = payload.recipient.User + "@" + g.config.PrimaryHost
 		length = len(payload.client.Data)
 		ts := fmt.Sprintf("%d", time.Now().UnixNano())
 		payload.client.Subject = util.MimeHeaderDecode(payload.client.Subject)
@@ -128,7 +186,7 @@ func (g *GuerrillaDBAndRedisBackend) saveMail() {
 		var addHead string
 		addHead += "Delivered-To: " + to + "\r\n"
 		addHead += "Received: from " + payload.client.Helo + " (" + payload.client.Helo + "  [" + payload.client.Address + "])\r\n"
-		addHead += "	by " + payload.host + " with SMTP id " + payload.client.Hash + "@" + payload.host + ";\r\n"
+		addHead += "	by " + payload.recipient.Host + " with SMTP id " + payload.client.Hash + "@" + payload.recipient.Host + ";\r\n"
 		addHead += "	" + time.Now().Format(time.RFC1123Z) + "\r\n"
 		// compress to save space
 		payload.client.Data = util.Compress(&addHead, &payload.client.Data)

+ 1 - 1
cmd/guerrillad/serve.go

@@ -94,7 +94,7 @@ func serve(cmd *cobra.Command, args []string) {
 		if serverConfig.IsEnabled {
 			log.Infof("Starting server on %s", serverConfig.ListenInterface)
 			go func(sConfig guerrilla.ServerConfig) {
-				err := server.RunServer(sConfig, backend, mainConfig.AllowedHosts)
+				err := server.RunServer(mainConfig, sConfig, backend)
 				if err != nil {
 					log.WithError(err).Fatalf("Error while starting server on %s", serverConfig.ListenInterface)
 				}

+ 15 - 0
config.go

@@ -1,5 +1,7 @@
 package guerrilla
 
+import "strings"
+
 type BackendConfig map[string]interface{}
 
 // Config is the holder of the configuration of the app
@@ -8,6 +10,19 @@ type Config struct {
 	BackendConfig BackendConfig  `json:"backend_config,omitempty"`
 	Servers       []ServerConfig `json:"servers"`
 	AllowedHosts  string         `json:"allowed_hosts"`
+
+	_allowedHosts map[string]bool
+}
+
+func (c *Config) IsAllowed(host string) bool {
+	if c._allowedHosts == nil {
+		arr := strings.Split(c.AllowedHosts, ",")
+		c._allowedHosts = make(map[string]bool, len(arr))
+		for _, h := range arr {
+			c._allowedHosts[strings.ToLower(h)] = true
+		}
+	}
+	return c._allowedHosts[strings.ToLower(host)]
 }
 
 // ServerConfig is the holder of the configuration of a server

+ 13 - 2
backend.go → models.go

@@ -3,14 +3,25 @@ package guerrilla
 import (
 	"bufio"
 	"errors"
+	"fmt"
 	"io"
 	"net"
 )
 
-// Backend accepts the relieved messages, and store/deliver/process them
+type EmailParts struct {
+	User string
+	Host string
+}
+
+func (ep *EmailParts) String() string {
+	return fmt.Sprintf("%s@%s", ep.User, ep.Host)
+}
+
+// Backend accepts the recieved messages, and store/deliver/process them
 type Backend interface {
 	Initialize(BackendConfig) error
-	Process(client *Client, user, host string) string
+	Process(client *Client, from *EmailParts, to []*EmailParts) string
+	Finalize() error
 }
 
 const CommandMaxLength = 1024

+ 4 - 4
server/goguerrilla.go

@@ -27,11 +27,11 @@ import (
 	guerrilla "github.com/flashmob/go-guerrilla"
 )
 
-func RunServer(sConfig guerrilla.ServerConfig, backend guerrilla.Backend, allowedHostsStr string) (err error) {
+func RunServer(mainConfig guerrilla.Config, sConfig guerrilla.ServerConfig, backend guerrilla.Backend) (err error) {
 	server := SmtpdServer{
-		Config:          sConfig,
-		sem:             make(chan int, sConfig.MaxClients),
-		allowedHostsStr: allowedHostsStr,
+		mainConfig: mainConfig,
+		config:     sConfig,
+		sem:        make(chan int, sConfig.MaxClients),
 	}
 
 	// configure ssl

+ 34 - 20
server/smtpd.go

@@ -16,12 +16,12 @@ import (
 )
 
 type SmtpdServer struct {
-	tlsConfig       *tls.Config
-	maxSize         int // max email DATA size
-	timeout         time.Duration
-	sem             chan int // currently active client list
-	Config          guerrilla.ServerConfig
-	allowedHostsStr string
+	mainConfig guerrilla.Config
+	config     guerrilla.ServerConfig
+	tlsConfig  *tls.Config
+	maxSize    int // max email DATA size
+	timeout    time.Duration
+	sem        chan int // currently active client list
 }
 
 // Upgrades the connection to TLS
@@ -46,16 +46,16 @@ func (server *SmtpdServer) upgradeToTls(client *guerrilla.Client) bool {
 func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerrilla.Backend) {
 	defer server.closeClient(client)
 	advertiseTLS := "250-STARTTLS\r\n"
-	if server.Config.TLSAlwaysOn {
+	if server.config.TLSAlwaysOn {
 		if server.upgradeToTls(client) {
 			advertiseTLS = ""
 		}
 	}
 	greeting := fmt.Sprintf("220 %s SMTP guerrillad(%s) #%d (%d) %s",
-		server.Config.Hostname, guerrilla.Version, client.ClientID,
+		server.config.Hostname, guerrilla.Version, client.ClientID,
 		len(server.sem), time.Now().Format(time.RFC1123Z))
 
-	if !server.Config.StartTLS {
+	if !server.config.StartTLS {
 		// STARTTLS turned off
 		advertiseTLS = ""
 	}
@@ -93,7 +93,7 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 				if len(input) > 5 {
 					client.Helo = input[5:]
 				}
-				responseAdd(client, "250 "+server.Config.Hostname+" Hello ")
+				responseAdd(client, "250 "+server.config.Hostname+" Hello ")
 			case strings.Index(cmd, "EHLO") == 0:
 				if len(input) > 5 {
 					client.Helo = input[5:]
@@ -103,8 +103,8 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 250-SIZE %d\r
 250-PIPELINING \r
 %s250 HELP`,
-					server.Config.Hostname, client.Helo, client.Address,
-					server.Config.MaxSize, advertiseTLS))
+					server.config.Hostname, client.Helo, client.Address,
+					server.config.MaxSize, advertiseTLS))
 			case strings.Index(cmd, "HELP") == 0:
 				responseAdd(client, "250 Help! I need somebody...")
 			case strings.Index(cmd, "MAIL FROM:") == 0:
@@ -135,7 +135,7 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 				client.State = 2
 			case (strings.Index(cmd, "STARTTLS") == 0) &&
 				!client.TLS &&
-				server.Config.StartTLS:
+				server.config.StartTLS:
 				responseAdd(client, "220 Ready to start TLS")
 				// go to start TLS state
 				client.State = 3
@@ -152,14 +152,28 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 			}
 		case 2:
 			var err error
-			client.Bufin.SetLimit(int64(server.Config.MaxSize) + 1024000) // This is a hard limit.
+			client.Bufin.SetLimit(int64(server.config.MaxSize) + 1024000) // This is a hard limit.
 			client.Data, err = server.readSmtp(client)
 			if err == nil {
-				if user, host, mailErr := util.ValidateEmailData(client, server.allowedHostsStr); mailErr == nil {
-					resp := backend.Process(client, user, host)
-					responseAdd(client, resp)
+				from, mailErr := util.ExtractEmail(client.MailFrom)
+				if mailErr != nil {
+					responseAdd(client, fmt.Sprintf("550 Error: invalid from: ", mailErr.Error()))
 				} else {
-					responseAdd(client, "550 Error: "+mailErr.Error())
+					// TODO: support multiple RcptTo
+					to, mailErr := util.ExtractEmail(client.RcptTo)
+					if mailErr != nil {
+						responseAdd(client, fmt.Sprintf("550 Error: invalid from: ", mailErr.Error()))
+					} else {
+						client.MailFrom = from.String()
+						client.RcptTo = to.String()
+						if !server.mainConfig.IsAllowed(to.Host) {
+							responseAdd(client, "550 Error: not allowed")
+						} else {
+							toArray := []*guerrilla.EmailParts{to}
+							resp := backend.Process(client, from, toArray)
+							responseAdd(client, resp)
+						}
+					}
 				}
 
 			} else {
@@ -226,8 +240,8 @@ func (server *SmtpdServer) readSmtp(client *guerrilla.Client) (input string, err
 		reply, err = client.Bufin.ReadString('\n')
 		if reply != "" {
 			input = input + reply
-			if len(input) > server.Config.MaxSize {
-				err = fmt.Errorf("Maximum DATA size exceeded (%d)", server.Config.MaxSize)
+			if len(input) > server.config.MaxSize {
+				err = fmt.Errorf("Maximum DATA size exceeded (%d)", server.config.MaxSize)
 				return input, err
 			}
 			if client.State == 2 {

+ 11 - 43
util/util.go

@@ -12,62 +12,29 @@ import (
 	"regexp"
 	"strings"
 
-	"gopkg.in/iconv.v1"
-
 	"github.com/sloonz/go-qprintable"
+	"gopkg.in/iconv.v1"
 
 	guerrilla "github.com/flashmob/go-guerrilla"
 )
 
-var allowedHosts map[string]bool
-
-// map the allow hosts for easy lookup
-func prepareAllowedHosts(allowedHostsStr string) {
-	allowedHosts = make(map[string]bool, 15)
-	if arr := strings.Split(allowedHostsStr, ","); len(arr) > 0 {
-		for i := 0; i < len(arr); i++ {
-			allowedHosts[arr[i]] = true
-		}
-	}
-}
-
-// TODO: cleanup
-func ValidateEmailData(client *guerrilla.Client, allowedHostsStr string) (user string, host string, addr_err error) {
-	if allowedHosts == nil {
-		prepareAllowedHosts(allowedHostsStr)
-	}
-
-	if user, host, addr_err = extractEmail(client.MailFrom); addr_err != nil {
-		return user, host, addr_err
-	}
-	client.MailFrom = user + "@" + host
-	if user, host, addr_err = extractEmail(client.RcptTo); addr_err != nil {
-		return user, host, addr_err
-	}
-	client.RcptTo = user + "@" + host
-	// check if on allowed hosts
-	if allowed := allowedHosts[strings.ToLower(host)]; !allowed {
-		return user, host, errors.New("invalid host:" + host)
-	}
-	return user, host, addr_err
-}
-
 var extractEmailRegex, _ = regexp.Compile(`<(.+?)@(.+?)>`) // go home regex, you're drunk!
 
-func extractEmail(str string) (name string, host string, err error) {
+func ExtractEmail(str string) (email *guerrilla.EmailParts, err error) {
+	email = &guerrilla.EmailParts{}
 	if matched := extractEmailRegex.FindStringSubmatch(str); len(matched) > 2 {
-		host = validHost(matched[2])
-		name = matched[1]
+		email.User = matched[1]
+		email.Host = validHost(matched[2])
 	} else {
 		if res := strings.Split(str, "@"); len(res) > 1 {
-			name = res[0]
-			host = validHost(res[1])
+			email.User = res[0]
+			email.Host = validHost(res[1])
 		}
 	}
-	if host == "" || name == "" {
-		err = errors.New("Invalid address, [" + name + "@" + host + "] address:" + str)
+	if email.User == "" || email.Host == "" {
+		err = errors.New("Invalid address, [" + email.User + "@" + email.Host + "] address:" + str)
 	}
-	return name, host, err
+	return
 }
 
 var mimeRegex, _ = regexp.Compile(`=\?(.+?)\?([QBqp])\?(.+?)\?=`)
@@ -130,6 +97,7 @@ func MailTransportDecode(str string, encodingType string, charset string) string
 
 	if charset != "UTF-8" {
 		charset = fixCharset(charset)
+		// TODO: remove dependency to os-dependent iconv library
 		if cd, err := iconv.Open("UTF-8", charset); err == nil {
 			defer func() {
 				cd.Close()