ソースを参照

Refactor parsing of -listen string out into separate config function

This makes the "for each listen address" loop in main() look even cleaner.
Jonathon Reinhart 4 年 前
コミット
ca1ccd85e3
3 ファイル変更91 行追加23 行削除
  1. 30 2
      config.go
  2. 44 0
      config_test.go
  3. 17 21
      main.go

+ 30 - 2
config.go

@@ -3,8 +3,9 @@ package main
 import (
 	"flag"
 	"net"
-	"regexp"
 	"net/smtp"
+	"regexp"
+	"strings"
 
 	"github.com/vharitonsky/iniflags"
 	"github.com/sirupsen/logrus"
@@ -21,7 +22,8 @@ var (
 	logLevel          = flag.String("log_level", "info", "Minimum log level to output")
 	hostName          = flag.String("hostname", "localhost.localdomain", "Server hostname")
 	welcomeMsg        = flag.String("welcome_msg", "", "Welcome message for SMTP session")
-	listen            = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP")
+	listenStr         = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP")
+	listenAddrs       = []protoAddr{}
 	localCert         = flag.String("local_cert", "", "SSL certificate for STARTTLS/TLS")
 	localKey          = flag.String("local_key", "", "SSL private key for STARTTLS/TLS")
 	localForceTLS     = flag.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)")
@@ -130,6 +132,31 @@ func setupRemoteAuth() {
 	}
 }
 
+type protoAddr struct {
+	protocol string
+	address  string
+}
+
+func splitProto(s string) protoAddr {
+	idx := strings.Index(s, "://")
+	if idx == -1 {
+		return protoAddr {
+			address:  s,
+		}
+	}
+	return protoAddr {
+		protocol: s[0 : idx],
+		address:  s[idx+3 : len(s)],
+	}
+}
+
+func setupListeners() {
+	for _, listenAddr := range strings.Split(*listenStr, " ") {
+		pa := splitProto(listenAddr)
+		listenAddrs = append(listenAddrs, pa)
+	}
+}
+
 func ConfigLoad() {
 	iniflags.Parse()
 
@@ -143,4 +170,5 @@ func ConfigLoad() {
 	setupAllowedNetworks()
 	setupAllowedPatterns()
 	setupRemoteAuth()
+	setupListeners()
 }

+ 44 - 0
config_test.go

@@ -0,0 +1,44 @@
+package main
+
+import (
+	"testing"
+)
+
+func TestSplitProto(t *testing.T) {
+	var tests = []struct {
+		input      string
+		proto      string
+		addr       string
+	}{
+		{
+			input:      "localhost",
+			proto:      "",
+			addr:       "localhost",
+		},
+		{
+			input:      "tls://my.local.domain",
+			proto:      "tls",
+			addr:       "my.local.domain",
+		},
+		{
+			input:      "starttls://my.local.domain",
+			proto:      "starttls",
+			addr:       "my.local.domain",
+		},
+	}
+
+	for i, test := range tests {
+		testName := test.input
+		t.Run(testName, func(t *testing.T) {
+			pa := splitProto(test.input)
+			if pa.protocol != test.proto {
+				t.Errorf("Testcase %d: Incorrect proto: expected %v, got %v",
+					i, test.proto, pa.protocol)
+			}
+			if pa.address != test.addr {
+				t.Errorf("Testcase %d: Incorrect addr: expected %v, got %v",
+					i, test.addr, pa.address)
+			}
+		})
+	}
+}

+ 17 - 21
main.go

@@ -288,7 +288,9 @@ func main() {
 	var servers []*smtpd.Server
 
 	// Create a server for each desired listen address
-	for _, listenAddr := range strings.Split(*listen, " ") {
+	for _, listen := range listenAddrs {
+		logger := log.WithField("address", listen.address)
+
 		server := &smtpd.Server{
 			Hostname:          *hostName,
 			WelcomeMessage:    *welcomeMsg,
@@ -305,37 +307,31 @@ func main() {
 		var lsnr net.Listener
 		var err error
 
-		if strings.Index(listenAddr, "://") == -1 {
-			log.WithField("address", listenAddr).
-				Info("listening on address")
-
-			lsnr, err = net.Listen("tcp", listenAddr)
-		} else if strings.HasPrefix(listenAddr, "starttls://") {
-			listenAddr = strings.TrimPrefix(listenAddr, "starttls://")
+		switch listen.protocol {
+		case "":
+			logger.Info("listening on address")
+			lsnr, err = net.Listen("tcp", listen.address)
 
+		case "starttls":
 			server.TLSConfig = getTLSConfig()
 			server.ForceTLS = *localForceTLS
 
-			log.WithField("address", listenAddr).
-				Info("listening on address (STARTTLS)")
-			lsnr, err = net.Listen("tcp", listenAddr)
-		} else if strings.HasPrefix(listenAddr, "tls://") {
-			listenAddr = strings.TrimPrefix(listenAddr, "tls://")
+			logger.Info("listening on address (STARTTLS)")
+			lsnr, err = net.Listen("tcp", listen.address)
 
+		case "tls":
 			server.TLSConfig = getTLSConfig()
 
-			log.WithField("address", listenAddr).
-				Info("listening on address (TLS)")
-			lsnr, err = tls.Listen("tcp", listenAddr, server.TLSConfig)
-		} else {
-			log.WithField("address", listenAddr).
+			logger.Info("listening on address (TLS)")
+			lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
+
+		default:
+			logger.WithField("protocol", listen.protocol).
 				Fatal("unknown protocol in listen address")
 		}
 
 		if err != nil {
-			log.WithFields(logrus.Fields{
-				"address": listenAddr,
-			}).WithError(err).Fatal("error starting listener")
+			logger.WithError(err).Fatal("error starting listener")
 		}
 		servers = append(servers, server)