Forráskód Böngészése

Expand allowedUsers email field to support comma-separated and domains (#9)

* Expand allowedUsers email field to support comma-separated and domains

Closes #8

* Refactor AuthFetch() to return AuthUser struct

Also, this breaks out a parseLine() function which can be easily tested.

* Ignore empty addrs after splitting commas

This ignores a trailing comma

* Add tests for auth parseLine()

* Update documentation in smtprelay.ini

* Fix bug where addrAllowed() was incorrectly case-sensitive

* Update allowedUsers allowed domain format to require leading @

This disambiguates a local user ('john.smith') from a domain ('example.com')
Jonathon Reinhart 4 éve
szülő
commit
0e8986ca79
5 módosított fájl, 278 hozzáadás és 23 törlés
  1. 42 20
      auth.go
  2. 89 0
      auth_test.go
  3. 45 2
      main.go
  4. 94 0
      main_test.go
  5. 8 1
      smtprelay.ini

+ 42 - 20
auth.go

@@ -13,6 +13,12 @@ var (
 	filename string
 )
 
+type AuthUser struct {
+	username string
+	passwordHash string
+	allowedAddresses []string
+}
+
 func AuthLoadFile(file string) error {
 	f, err := os.Open(file)
 	if err != nil {
@@ -28,50 +34,66 @@ func AuthReady() bool {
 	return (filename != "")
 }
 
-// Returns bcrypt-hash, email
-// email can be empty in which case it is not checked
-func AuthFetch(username string) (string, string, error) {
+// Split a string and ignore empty results
+// https://stackoverflow.com/a/46798310/119527
+func splitstr(s string, sep rune) []string {
+	return strings.FieldsFunc(s, func(c rune) bool { return c == sep })
+}
+
+func parseLine(line string) *AuthUser {
+	parts := strings.Fields(line)
+
+	if len(parts) < 2 || len(parts) > 3 {
+		return nil
+	}
+
+	user := AuthUser{
+		username: parts[0],
+		passwordHash: parts[1],
+		allowedAddresses: nil,
+	}
+
+	if len(parts) >= 3 {
+		user.allowedAddresses = splitstr(parts[2], ',')
+	}
+
+	return &user
+}
+
+func AuthFetch(username string) (*AuthUser, error) {
 	if !AuthReady() {
-		return "", "", errors.New("Authentication file not specified. Call LoadFile() first")
+		return nil, errors.New("Authentication file not specified. Call LoadFile() first")
 	}
 
 	file, err := os.Open(filename)
 	if err != nil {
-		return "", "", err
+		return nil, err
 	}
 	defer file.Close()
 
 	scanner := bufio.NewScanner(file)
 	for scanner.Scan() {
-		parts := strings.Fields(scanner.Text())
-
-		if len(parts) < 2 || len(parts) > 3 {
+		user := parseLine(scanner.Text())
+		if user == nil {
 			continue
 		}
 
-		if strings.ToLower(username) != strings.ToLower(parts[0]) {
+		if strings.ToLower(username) != strings.ToLower(user.username) {
 			continue
 		}
 
-		hash := parts[1]
-		email := ""
-
-		if len(parts) >= 3 {
-			email = parts[2]
-		}
-
-		return hash, email, nil
+		return user, nil
 	}
 
-	return "", "", errors.New("User not found")
+	return nil, errors.New("User not found")
 }
 
 func AuthCheckPassword(username string, secret string) error {
-	hash, _, err := AuthFetch(username)
+	user, err := AuthFetch(username)
 	if err != nil {
 		return err
 	}
-	if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) == nil {
+	if bcrypt.CompareHashAndPassword([]byte(user.passwordHash), []byte(secret)) == nil {
 		return nil
 	}
 	return errors.New("Password invalid")

+ 89 - 0
auth_test.go

@@ -0,0 +1,89 @@
+package main
+
+import (
+	"testing"
+)
+
+func stringsEqual(a, b []string) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i, _ := range a {
+		if a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+func TestParseLine(t *testing.T) {
+	var tests = []struct {
+		name string
+		expectFail bool
+		line string
+		username string
+		addrs []string
+	}{
+		{
+			name: "Empty line",
+			expectFail: true,
+			line: "",
+		},
+		{
+			name: "Too few fields",
+			expectFail: true,
+			line: "joe",
+		},
+		{
+			name: "Too many fields",
+			expectFail: true,
+			line: "joe xxx [email protected] whatsthis",
+		},
+		{
+			name: "Normal case",
+			line: "joe xxx [email protected]",
+			username: "joe",
+			addrs: []string{"[email protected]"},
+		},
+		{
+			name: "No allowed addrs given",
+			line: "joe xxx",
+			username: "joe",
+			addrs: []string{},
+		},
+		{
+			name: "Trailing comma",
+			line: "joe xxx [email protected],",
+			username: "joe",
+			addrs: []string{"[email protected]"},
+		},
+		{
+			name: "Multiple allowed addrs",
+			line: "joe xxx [email protected],@foo.example.com",
+			username: "joe",
+			addrs: []string{"[email protected]", "@foo.example.com"},
+		},
+	}
+
+	for i, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			user := parseLine(test.line)
+			if user == nil {
+				if !test.expectFail {
+					t.Errorf("parseLine() returned nil unexpectedly")
+				}
+				return
+			}
+
+			if user.username != test.username {
+				t.Errorf("Testcase %d: Incorrect username: expected %v, got %v",
+						 i, test.username, user.username)
+			}
+
+			if !stringsEqual(user.allowedAddresses, test.addrs) {
+				t.Errorf("Testcase %d: Incorrect addresses: expected %v, got %v",
+						 i, test.addrs, user.allowedAddresses)
+			}
+		})
+	}
+}

+ 45 - 2
main.go

@@ -36,15 +36,58 @@ func connectionChecker(peer smtpd.Peer) error {
 	return smtpd.Error{Code: 421, Message: "Denied"}
 }
 
+func addrAllowed(addr string, allowedAddrs []string) bool {
+	if allowedAddrs == nil {
+		// If absent, all addresses are allowed
+		return true
+	}
+
+	addr = strings.ToLower(addr)
+
+	// Extract optional domain part
+	domain := ""
+	if idx := strings.LastIndex(addr, "@"); idx != -1 {
+		domain = strings.ToLower(addr[idx+1:])
+	}
+
+	// Test each address from allowedUsers file
+	for _, allowedAddr := range allowedAddrs {
+		allowedAddr = strings.ToLower(allowedAddr)
+
+		// Three cases for allowedAddr format:
+		if idx := strings.Index(allowedAddr, "@"); idx == -1 {
+			// 1. local address (no @) -- must match exactly
+			if allowedAddr == addr {
+				return true
+			}
+		} else {
+			if idx != 0 {
+				// 2. email address ([email protected]) -- must match exactly
+				if allowedAddr == addr {
+					return true
+				}
+			} else {
+				// 3. domain (@domain.com) -- must match addr domain
+				allowedDomain := allowedAddr[idx+1:]
+				if allowedDomain == domain {
+					return true
+				}
+			}
+		}
+	}
+
+	return false
+}
+
 func senderChecker(peer smtpd.Peer, addr string) error {
 	// check sender address from auth file if user is authenticated
 	if *allowedUsers != "" && peer.Username != "" {
-		_, email, err := AuthFetch(peer.Username)
+		user, err := AuthFetch(peer.Username)
 		if err != nil {
 			return smtpd.Error{Code: 451, Message: "Bad sender address"}
 		}
 
-		if email != "" && strings.ToLower(addr) != strings.ToLower(email) {
+		if !addrAllowed(addr, user.allowedAddresses) {
 			return smtpd.Error{Code: 451, Message: "Bad sender address"}
 		}
 	}

+ 94 - 0
main_test.go

@@ -0,0 +1,94 @@
+package main
+
+import (
+	"testing"
+)
+
+func TestAddrAllowedNoDomain(t *testing.T) {
+	allowedAddrs := []string{"[email protected]"}
+	if addrAllowed("bob.com", allowedAddrs) {
+		t.FailNow()
+	}
+}
+
+func TestAddrAllowedSingle(t *testing.T) {
+	allowedAddrs := []string{"[email protected]"}
+
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+}
+
+func TestAddrAllowedDifferentCase(t *testing.T) {
+	allowedAddrs := []string{"[email protected]"}
+    testAddrs := []string{
+        "[email protected]",
+        "[email protected]",
+        "[email protected]",
+        "[email protected]",
+    }
+    for _, addr := range testAddrs {
+        if !addrAllowed(addr, allowedAddrs) {
+            t.Errorf("Address %v not allowed, but should be", addr)
+        }
+    }
+}
+
+func TestAddrAllowedLocal(t *testing.T) {
+	allowedAddrs := []string{"joe"}
+
+	if !addrAllowed("joe", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("bob", allowedAddrs) {
+		t.FailNow()
+	}
+}
+
+func TestAddrAllowedMulti(t *testing.T) {
+	allowedAddrs := []string{"[email protected]", "[email protected]"}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+}
+
+func TestAddrAllowedSingleDomain(t *testing.T) {
+	allowedAddrs := []string{"@abc.com"}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+}
+
+func TestAddrAllowedMixed(t *testing.T) {
+	allowedAddrs := []string{"app", "[email protected]", "@appsrv.example.com"}
+	if !addrAllowed("app", allowedAddrs) {
+		t.FailNow()
+	}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if !addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+	if addrAllowed("[email protected]", allowedAddrs) {
+		t.FailNow()
+	}
+}

+ 8 - 1
smtprelay.ini

@@ -37,7 +37,14 @@
 
 ; File which contains username and password used for
 ; authentication before they can send mail.
-; File format: username bcrypt-hash [email]
+; File format: username bcrypt-hash [email[,email[,...]]]
+;   username: The SMTP auth username
+;   bcrypt-hash: The bcrypt hash of the pasword (generate with "./hasher password")
+;   email: Comma-separated list of allowed "from" addresses:
+;          - If omitted, user can send from any address
+;          - If @domain.com is given, user can send from any address @domain.com
+;          - Otherwise, email address must match exactly (case-insensitive)
+;          E.g. "[email protected],@appsrv.example.com"
 ;allowed_users =
 
 ; Relay all mails to this SMTP server