123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
- package main
- import (
- "bytes"
- "crypto/tls"
- "fmt"
- "net"
- "net/textproto"
- "os"
- "os/exec"
- "os/signal"
- "strings"
- "syscall"
- "github.com/chrj/smtpd"
- "github.com/google/uuid"
- "github.com/sirupsen/logrus"
- )
- func connectionChecker(peer smtpd.Peer) error {
- // This can't panic because we only have TCP listeners
- peerIP := peer.Addr.(*net.TCPAddr).IP
- if len(allowedNets) == 0 {
- // Special case: empty string means allow everything
- return nil
- }
- for _, allowedNet := range allowedNets {
- if allowedNet.Contains(peerIP) {
- return nil
- }
- }
- log.WithFields(logrus.Fields{
- "ip": peerIP,
- }).Warn("Connection refused from address outside of allowed_nets")
- return smtpd.Error{Code: 421, Message: "Denied"}
- }
- func addrAllowed(addr string, allowedAddrs []string) bool {
- if allowedAddrs == nil {
- // If absent, all addresses are allowed
- return true
- }
- addr = strings.ToLower(addr)
- // Extract optional domain part
- domain := ""
- if idx := strings.LastIndex(addr, "@"); idx != -1 {
- domain = strings.ToLower(addr[idx+1:])
- }
- // Test each address from allowedUsers file
- for _, allowedAddr := range allowedAddrs {
- allowedAddr = strings.ToLower(allowedAddr)
- // Three cases for allowedAddr format:
- if idx := strings.Index(allowedAddr, "@"); idx == -1 {
- // 1. local address (no @) -- must match exactly
- if allowedAddr == addr {
- return true
- }
- } else {
- if idx != 0 {
- // 2. email address ([email protected]) -- must match exactly
- if allowedAddr == addr {
- return true
- }
- } else {
- // 3. domain (@domain.com) -- must match addr domain
- allowedDomain := allowedAddr[idx+1:]
- if allowedDomain == domain {
- return true
- }
- }
- }
- }
- return false
- }
- func senderChecker(peer smtpd.Peer, addr string) error {
- // check sender address from auth file if user is authenticated
- if localAuthRequired() && peer.Username != "" {
- user, err := AuthFetch(peer.Username)
- if err != nil {
- // Shouldn't happen: authChecker already validated username+password
- log.WithFields(logrus.Fields{
- "peer": peer.Addr,
- "username": peer.Username,
- }).WithError(err).Warn("could not fetch auth user")
- return smtpd.Error{Code: 451, Message: "Bad sender address"}
- }
- if !addrAllowed(addr, user.allowedAddresses) {
- log.WithFields(logrus.Fields{
- "peer": peer.Addr,
- "username": peer.Username,
- "sender_address": addr,
- }).Warn("sender address not allowed for authenticated user")
- return smtpd.Error{Code: 451, Message: "Bad sender address"}
- }
- }
- if allowedSender == nil {
- // Any sender is permitted
- return nil
- }
- if allowedSender.MatchString(addr) {
- // Permitted by regex
- return nil
- }
- log.WithFields(logrus.Fields{
- "sender_address": addr,
- "peer": peer.Addr,
- }).Warn("sender address not allowed by allowed_sender pattern")
- return smtpd.Error{Code: 451, Message: "Bad sender address"}
- }
- func recipientChecker(peer smtpd.Peer, addr string) error {
- if allowedRecipients == nil {
- // Any recipient is permitted
- return nil
- }
- if allowedRecipients.MatchString(addr) {
- // Permitted by regex
- return nil
- }
- log.WithFields(logrus.Fields{
- "peer": peer.Addr,
- "recipient_address": addr,
- }).Warn("recipient address not allowed by allowed_recipients pattern")
- return smtpd.Error{Code: 451, Message: "Bad recipient address"}
- }
- func authChecker(peer smtpd.Peer, username string, password string) error {
- err := AuthCheckPassword(username, password)
- if err != nil {
- log.WithFields(logrus.Fields{
- "peer": peer.Addr,
- "username": username,
- }).WithError(err).Warn("auth error")
- return smtpd.Error{Code: 535, Message: "Authentication credentials invalid"}
- }
- return nil
- }
- func mailHandler(peer smtpd.Peer, env smtpd.Envelope) error {
- peerIP := ""
- if addr, ok := peer.Addr.(*net.TCPAddr); ok {
- peerIP = addr.IP.String()
- }
- logger := log.WithFields(logrus.Fields{
- "from": env.Sender,
- "to": env.Recipients,
- "peer": peerIP,
- "uuid": generateUUID(),
- })
- if *remotesStr == "" && *command == "" {
- logger.Warning("no remote_host or command set; discarding mail")
- return nil
- }
- env.AddReceivedLine(peer)
- if *command != "" {
- cmdLogger := logger.WithField("command", *command)
- var stdout bytes.Buffer
- var stderr bytes.Buffer
- environ := os.Environ()
- environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_FROM", env.Sender))
- environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_TO", env.Recipients))
- environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_PEER", peerIP))
- cmd := exec.Cmd{
- Env: environ,
- Path: *command,
- }
- cmd.Stdin = bytes.NewReader(env.Data)
- cmd.Stdout = &stdout
- cmd.Stderr = &stderr
- err := cmd.Run()
- if err != nil {
- cmdLogger.WithError(err).Error(stderr.String())
- return smtpd.Error{Code: 554, Message: "External command failed"}
- }
- cmdLogger.Info("pipe command successful: " + stdout.String())
- }
- for _, remote := range remotes {
- logger = logger.WithField("host", remote.Addr)
- logger.Info("delivering mail from peer using smarthost")
- err := SendMail(
- remote,
- env.Sender,
- env.Recipients,
- env.Data,
- )
- if err != nil {
- var smtpError smtpd.Error
- switch err := err.(type) {
- case *textproto.Error:
- smtpError = smtpd.Error{Code: err.Code, Message: err.Msg}
- logger.WithFields(logrus.Fields{
- "err_code": err.Code,
- "err_msg": err.Msg,
- }).Error("delivery failed")
- default:
- smtpError = smtpd.Error{Code: 554, Message: "Forwarding failed"}
- logger.WithError(err).
- Error("delivery failed")
- }
- return smtpError
- }
- logger.Debug("delivery successful")
- }
- return nil
- }
- func generateUUID() string {
- uniqueID, err := uuid.NewRandom()
- if err != nil {
- log.WithError(err).
- Error("could not generate UUIDv4")
- return ""
- }
- return uniqueID.String()
- }
- func getTLSConfig() *tls.Config {
- // Ciphersuites as defined in stock Go but without 3DES and RC4
- // https://golang.org/src/crypto/tls/cipher_suites.go
- var tlsCipherSuites = []uint16{
- tls.TLS_AES_128_GCM_SHA256,
- tls.TLS_AES_256_GCM_SHA384,
- tls.TLS_CHACHA20_POLY1305_SHA256,
- tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
- tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // does not provide PFS
- tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // does not provide PFS
- }
- if *localCert == "" || *localKey == "" {
- log.WithFields(logrus.Fields{
- "cert_file": *localCert,
- "key_file": *localKey,
- }).Fatal("TLS certificate/key file not defined in config")
- }
- cert, err := tls.LoadX509KeyPair(*localCert, *localKey)
- if err != nil {
- log.WithField("error", err).
- Fatal("cannot load X509 keypair")
- }
- return &tls.Config{
- PreferServerCipherSuites: true,
- MinVersion: tls.VersionTLS12,
- CipherSuites: tlsCipherSuites,
- Certificates: []tls.Certificate{cert},
- }
- }
- func main() {
- ConfigLoad()
- log.WithField("version", appVersion).
- Debug("starting smtprelay")
- // Load allowed users file
- if localAuthRequired() {
- err := AuthLoadFile(*allowedUsers)
- if err != nil {
- log.WithField("file", *allowedUsers).
- WithError(err).
- Fatal("cannot load allowed users file")
- }
- }
- var servers []*smtpd.Server
- // Create a server for each desired listen address
- for _, listen := range listenAddrs {
- logger := log.WithField("address", listen.address)
- server := &smtpd.Server{
- Hostname: *hostName,
- WelcomeMessage: *welcomeMsg,
- ReadTimeout: readTimeout,
- WriteTimeout: writeTimeout,
- DataTimeout: dataTimeout,
- MaxConnections: *maxConnections,
- MaxMessageSize: *maxMessageSize,
- MaxRecipients: *maxRecipients,
- ConnectionChecker: connectionChecker,
- SenderChecker: senderChecker,
- RecipientChecker: recipientChecker,
- Handler: mailHandler,
- }
- if localAuthRequired() {
- server.Authenticator = authChecker
- }
- var lsnr net.Listener
- var err error
- switch listen.protocol {
- case "":
- logger.Info("listening on address")
- lsnr, err = net.Listen("tcp", listen.address)
- case "starttls":
- server.TLSConfig = getTLSConfig()
- server.ForceTLS = *localForceTLS
- logger.Info("listening on address (STARTTLS)")
- lsnr, err = net.Listen("tcp", listen.address)
- case "tls":
- server.TLSConfig = getTLSConfig()
- logger.Info("listening on address (TLS)")
- lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
- default:
- logger.WithField("protocol", listen.protocol).
- Fatal("unknown protocol in listen address")
- }
- if err != nil {
- logger.WithError(err).Fatal("error starting listener")
- }
- servers = append(servers, server)
- go func() {
- server.Serve(lsnr)
- }()
- }
- handleSignals()
- // First close the listeners
- for _, server := range servers {
- logger := log.WithField("address", server.Address())
- logger.Debug("Shutting down server")
- err := server.Shutdown(false)
- if err != nil {
- logger.WithError(err).
- Warning("Shutdown failed")
- }
- }
- // Then wait for the clients to exit
- for _, server := range servers {
- logger := log.WithField("address", server.Address())
- logger.Debug("Waiting for server")
- err := server.Wait()
- if err != nil {
- logger.WithError(err).
- Warning("Wait failed")
- }
- }
- log.Debug("done")
- }
- func handleSignals() {
- // Wait for SIGINT, SIGQUIT, or SIGTERM
- sigs := make(chan os.Signal, 1)
- signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)
- sig := <-sigs
- log.WithField("signal", sig).
- Info("shutting down in response to received signal")
- }
|