ソースを参照

Move parsing of "allowed_nets" out to ConfigLoad()

This has several benefits:
- Configuration errors are caught at startup rather than upon a connection
- connectionChecker() has less work to do for each connection
Jonathon Reinhart 4 年 前
コミット
ef3f9c8ea0
2 ファイル変更20 行追加5 行削除
  1. 19 1
      config.go
  2. 1 4
      main.go

+ 19 - 1
config.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"flag"
+	"net"
 
 	"github.com/vharitonsky/iniflags"
 )
@@ -21,7 +22,8 @@ var (
 	localCert         = flag.String("local_cert", "", "SSL certificate for STARTTLS/TLS")
 	localKey          = flag.String("local_key", "", "SSL private key for STARTTLS/TLS")
 	localForceTLS     = flag.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)")
-	allowedNets       = flag.String("allowed_nets", "127.0.0.1/8 ::1/128", "Networks allowed to send mails")
+	allowedNetsStr    = flag.String("allowed_nets", "127.0.0.1/8 ::1/128", "Networks allowed to send mails")
+	allowedNets       = []*net.IPNet{}
 	allowedSender     = flag.String("allowed_sender", "", "Regular expression for valid FROM EMail addresses")
 	allowedRecipients = flag.String("allowed_recipients", "", "Regular expression for valid TO EMail addresses")
 	allowedUsers      = flag.String("allowed_users", "", "Path to file with valid users/passwords")
@@ -33,6 +35,20 @@ var (
 	versionInfo       = flag.Bool("version", false, "Show version information")
 )
 
+
+func setupAllowedNetworks() {
+	for _, netstr := range splitstr(*allowedNetsStr, ' ') {
+		_, allowedNet, err := net.ParseCIDR(netstr)
+		if err != nil {
+			log.WithField("netstr", netstr).
+				WithError(err).
+				Fatal("Invalid CIDR notation in allowed_nets")
+		}
+
+		allowedNets = append(allowedNets, allowedNet)
+	}
+}
+
 func ConfigLoad() {
 	iniflags.Parse()
 
@@ -42,4 +58,6 @@ func ConfigLoad() {
 	if (*remoteHost == "") {
 		log.Warn("remote_host not set; mail will not be forwarded!")
 	}
+
+	setupAllowedNetworks()
 }

+ 1 - 4
main.go

@@ -20,11 +20,8 @@ func connectionChecker(peer smtpd.Peer) error {
 	// This can't panic because we only have TCP listeners
 	peerIP := peer.Addr.(*net.TCPAddr).IP
 
-	nets := strings.Split(*allowedNets, " ")
-
-	for i := range nets {
-		_, allowedNet, _ := net.ParseCIDR(nets[i])
 
+	for _, allowedNet := range allowedNets {
 		if allowedNet.Contains(peerIP) {
 			return nil
 		}