Browse Source

Move AllowedHosts stuff to config

Reza Mohammadi 9 years ago
parent
commit
fc1a98d29e
5 changed files with 58 additions and 52 deletions
  1. 1 1
      cmd/guerrillad/serve.go
  2. 15 0
      config.go
  3. 4 4
      server/goguerrilla.go
  4. 25 19
      server/smtpd.go
  5. 13 28
      util/util.go

+ 1 - 1
cmd/guerrillad/serve.go

@@ -94,7 +94,7 @@ func serve(cmd *cobra.Command, args []string) {
 		if serverConfig.IsEnabled {
 		if serverConfig.IsEnabled {
 			log.Infof("Starting server on %s", serverConfig.ListenInterface)
 			log.Infof("Starting server on %s", serverConfig.ListenInterface)
 			go func(sConfig guerrilla.ServerConfig) {
 			go func(sConfig guerrilla.ServerConfig) {
-				err := server.RunServer(sConfig, backend, mainConfig.AllowedHosts)
+				err := server.RunServer(mainConfig, sConfig, backend)
 				if err != nil {
 				if err != nil {
 					log.WithError(err).Fatalf("Error while starting server on %s", serverConfig.ListenInterface)
 					log.WithError(err).Fatalf("Error while starting server on %s", serverConfig.ListenInterface)
 				}
 				}

+ 15 - 0
config.go

@@ -1,5 +1,7 @@
 package guerrilla
 package guerrilla
 
 
+import "strings"
+
 type BackendConfig map[string]interface{}
 type BackendConfig map[string]interface{}
 
 
 // Config is the holder of the configuration of the app
 // Config is the holder of the configuration of the app
@@ -8,6 +10,19 @@ type Config struct {
 	BackendConfig BackendConfig  `json:"backend_config,omitempty"`
 	BackendConfig BackendConfig  `json:"backend_config,omitempty"`
 	Servers       []ServerConfig `json:"servers"`
 	Servers       []ServerConfig `json:"servers"`
 	AllowedHosts  string         `json:"allowed_hosts"`
 	AllowedHosts  string         `json:"allowed_hosts"`
+
+	_allowedHosts map[string]bool
+}
+
+func (c *Config) IsAllowed(host string) bool {
+	if c._allowedHosts == nil {
+		arr := strings.Split(c.AllowedHosts, ",")
+		c._allowedHosts = make(map[string]bool, len(arr))
+		for _, h := range arr {
+			c._allowedHosts[strings.ToLower(h)] = true
+		}
+	}
+	return c._allowedHosts[strings.ToLower(host)]
 }
 }
 
 
 // ServerConfig is the holder of the configuration of a server
 // ServerConfig is the holder of the configuration of a server

+ 4 - 4
server/goguerrilla.go

@@ -27,11 +27,11 @@ import (
 	guerrilla "github.com/flashmob/go-guerrilla"
 	guerrilla "github.com/flashmob/go-guerrilla"
 )
 )
 
 
-func RunServer(sConfig guerrilla.ServerConfig, backend guerrilla.Backend, allowedHostsStr string) (err error) {
+func RunServer(mainConfig guerrilla.Config, sConfig guerrilla.ServerConfig, backend guerrilla.Backend) (err error) {
 	server := SmtpdServer{
 	server := SmtpdServer{
-		Config:          sConfig,
-		sem:             make(chan int, sConfig.MaxClients),
-		allowedHostsStr: allowedHostsStr,
+		mainConfig: mainConfig,
+		config:     sConfig,
+		sem:        make(chan int, sConfig.MaxClients),
 	}
 	}
 
 
 	// configure ssl
 	// configure ssl

+ 25 - 19
server/smtpd.go

@@ -16,12 +16,12 @@ import (
 )
 )
 
 
 type SmtpdServer struct {
 type SmtpdServer struct {
-	tlsConfig       *tls.Config
-	maxSize         int // max email DATA size
-	timeout         time.Duration
-	sem             chan int // currently active client list
-	Config          guerrilla.ServerConfig
-	allowedHostsStr string
+	mainConfig guerrilla.Config
+	config     guerrilla.ServerConfig
+	tlsConfig  *tls.Config
+	maxSize    int // max email DATA size
+	timeout    time.Duration
+	sem        chan int // currently active client list
 }
 }
 
 
 // Upgrades the connection to TLS
 // Upgrades the connection to TLS
@@ -46,16 +46,16 @@ func (server *SmtpdServer) upgradeToTls(client *guerrilla.Client) bool {
 func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerrilla.Backend) {
 func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerrilla.Backend) {
 	defer server.closeClient(client)
 	defer server.closeClient(client)
 	advertiseTLS := "250-STARTTLS\r\n"
 	advertiseTLS := "250-STARTTLS\r\n"
-	if server.Config.TLSAlwaysOn {
+	if server.config.TLSAlwaysOn {
 		if server.upgradeToTls(client) {
 		if server.upgradeToTls(client) {
 			advertiseTLS = ""
 			advertiseTLS = ""
 		}
 		}
 	}
 	}
 	greeting := fmt.Sprintf("220 %s SMTP guerrillad(%s) #%d (%d) %s",
 	greeting := fmt.Sprintf("220 %s SMTP guerrillad(%s) #%d (%d) %s",
-		server.Config.Hostname, guerrilla.Version, client.ClientID,
+		server.config.Hostname, guerrilla.Version, client.ClientID,
 		len(server.sem), time.Now().Format(time.RFC1123Z))
 		len(server.sem), time.Now().Format(time.RFC1123Z))
 
 
-	if !server.Config.StartTLS {
+	if !server.config.StartTLS {
 		// STARTTLS turned off
 		// STARTTLS turned off
 		advertiseTLS = ""
 		advertiseTLS = ""
 	}
 	}
@@ -93,7 +93,7 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 				if len(input) > 5 {
 				if len(input) > 5 {
 					client.Helo = input[5:]
 					client.Helo = input[5:]
 				}
 				}
-				responseAdd(client, "250 "+server.Config.Hostname+" Hello ")
+				responseAdd(client, "250 "+server.config.Hostname+" Hello ")
 			case strings.Index(cmd, "EHLO") == 0:
 			case strings.Index(cmd, "EHLO") == 0:
 				if len(input) > 5 {
 				if len(input) > 5 {
 					client.Helo = input[5:]
 					client.Helo = input[5:]
@@ -103,8 +103,8 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 250-SIZE %d\r
 250-SIZE %d\r
 250-PIPELINING \r
 250-PIPELINING \r
 %s250 HELP`,
 %s250 HELP`,
-					server.Config.Hostname, client.Helo, client.Address,
-					server.Config.MaxSize, advertiseTLS))
+					server.config.Hostname, client.Helo, client.Address,
+					server.config.MaxSize, advertiseTLS))
 			case strings.Index(cmd, "HELP") == 0:
 			case strings.Index(cmd, "HELP") == 0:
 				responseAdd(client, "250 Help! I need somebody...")
 				responseAdd(client, "250 Help! I need somebody...")
 			case strings.Index(cmd, "MAIL FROM:") == 0:
 			case strings.Index(cmd, "MAIL FROM:") == 0:
@@ -135,7 +135,7 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 				client.State = 2
 				client.State = 2
 			case (strings.Index(cmd, "STARTTLS") == 0) &&
 			case (strings.Index(cmd, "STARTTLS") == 0) &&
 				!client.TLS &&
 				!client.TLS &&
-				server.Config.StartTLS:
+				server.config.StartTLS:
 				responseAdd(client, "220 Ready to start TLS")
 				responseAdd(client, "220 Ready to start TLS")
 				// go to start TLS state
 				// go to start TLS state
 				client.State = 3
 				client.State = 3
@@ -152,12 +152,18 @@ func (server *SmtpdServer) handleClient(client *guerrilla.Client, backend guerri
 			}
 			}
 		case 2:
 		case 2:
 			var err error
 			var err error
-			client.Bufin.SetLimit(int64(server.Config.MaxSize) + 1024000) // This is a hard limit.
+			client.Bufin.SetLimit(int64(server.config.MaxSize) + 1024000) // This is a hard limit.
 			client.Data, err = server.readSmtp(client)
 			client.Data, err = server.readSmtp(client)
 			if err == nil {
 			if err == nil {
-				if user, host, mailErr := util.ValidateEmailData(client, server.allowedHostsStr); mailErr == nil {
-					resp := backend.Process(client, user, host)
-					responseAdd(client, resp)
+				if from, to, mailErr := util.ValidateEmailData(client.MailFrom, client.RcptTo); mailErr == nil {
+					client.MailFrom = fmt.Sprintf("%s@%s", from.User, from.Host)
+					client.RcptTo = fmt.Sprintf("%s@%s", to.User, to.Host)
+					if !server.mainConfig.IsAllowed(to.Host) {
+						responseAdd(client, "550 Error: not allowed")
+					} else {
+						resp := backend.Process(client, to.User, to.Host)
+						responseAdd(client, resp)
+					}
 				} else {
 				} else {
 					responseAdd(client, "550 Error: "+mailErr.Error())
 					responseAdd(client, "550 Error: "+mailErr.Error())
 				}
 				}
@@ -226,8 +232,8 @@ func (server *SmtpdServer) readSmtp(client *guerrilla.Client) (input string, err
 		reply, err = client.Bufin.ReadString('\n')
 		reply, err = client.Bufin.ReadString('\n')
 		if reply != "" {
 		if reply != "" {
 			input = input + reply
 			input = input + reply
-			if len(input) > server.Config.MaxSize {
-				err = fmt.Errorf("Maximum DATA size exceeded (%d)", server.Config.MaxSize)
+			if len(input) > server.config.MaxSize {
+				err = fmt.Errorf("Maximum DATA size exceeded (%d)", server.config.MaxSize)
 				return input, err
 				return input, err
 			}
 			}
 			if client.State == 2 {
 			if client.State == 2 {

+ 13 - 28
util/util.go

@@ -15,41 +15,26 @@ import (
 	"gopkg.in/iconv.v1"
 	"gopkg.in/iconv.v1"
 
 
 	"github.com/sloonz/go-qprintable"
 	"github.com/sloonz/go-qprintable"
-
-	guerrilla "github.com/flashmob/go-guerrilla"
 )
 )
 
 
-var allowedHosts map[string]bool
-
-// map the allow hosts for easy lookup
-func prepareAllowedHosts(allowedHostsStr string) {
-	allowedHosts = make(map[string]bool, 15)
-	if arr := strings.Split(allowedHostsStr, ","); len(arr) > 0 {
-		for i := 0; i < len(arr); i++ {
-			allowedHosts[arr[i]] = true
-		}
-	}
+type EmailParts struct {
+	User string
+	Host string
 }
 }
 
 
-// TODO: cleanup
-func ValidateEmailData(client *guerrilla.Client, allowedHostsStr string) (user string, host string, addr_err error) {
-	if allowedHosts == nil {
-		prepareAllowedHosts(allowedHostsStr)
-	}
+func ValidateEmailData(mailFrom, rcptTo string) (*EmailParts, *EmailParts, error) {
+	var user, host string
+	var addrErr error
 
 
-	if user, host, addr_err = extractEmail(client.MailFrom); addr_err != nil {
-		return user, host, addr_err
-	}
-	client.MailFrom = user + "@" + host
-	if user, host, addr_err = extractEmail(client.RcptTo); addr_err != nil {
-		return user, host, addr_err
+	if user, host, addrErr = extractEmail(mailFrom); addrErr != nil {
+		return nil, nil, addrErr
 	}
 	}
-	client.RcptTo = user + "@" + host
-	// check if on allowed hosts
-	if allowed := allowedHosts[strings.ToLower(host)]; !allowed {
-		return user, host, errors.New("invalid host:" + host)
+	from := &EmailParts{User: user, Host: host}
+	if user, host, addrErr = extractEmail(rcptTo); addrErr != nil {
+		return nil, nil, addrErr
 	}
 	}
-	return user, host, addr_err
+	to := &EmailParts{User: user, Host: host}
+	return from, to, nil
 }
 }
 
 
 var extractEmailRegex, _ = regexp.Compile(`<(.+?)@(.+?)>`) // go home regex, you're drunk!
 var extractEmailRegex, _ = regexp.Compile(`<(.+?)@(.+?)>`) // go home regex, you're drunk!