main.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. package main
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "fmt"
  6. "net"
  7. "net/textproto"
  8. "os"
  9. "os/exec"
  10. "os/signal"
  11. "strings"
  12. "syscall"
  13. "github.com/chrj/smtpd"
  14. "github.com/google/uuid"
  15. "github.com/sirupsen/logrus"
  16. )
  17. func connectionChecker(peer smtpd.Peer) error {
  18. // This can't panic because we only have TCP listeners
  19. peerIP := peer.Addr.(*net.TCPAddr).IP
  20. if len(allowedNets) == 0 {
  21. // Special case: empty string means allow everything
  22. return nil
  23. }
  24. for _, allowedNet := range allowedNets {
  25. if allowedNet.Contains(peerIP) {
  26. return nil
  27. }
  28. }
  29. log.WithFields(logrus.Fields{
  30. "ip": peerIP,
  31. }).Warn("Connection refused from address outside of allowed_nets")
  32. return smtpd.Error{Code: 421, Message: "Denied"}
  33. }
  34. func addrAllowed(addr string, allowedAddrs []string) bool {
  35. if allowedAddrs == nil {
  36. // If absent, all addresses are allowed
  37. return true
  38. }
  39. addr = strings.ToLower(addr)
  40. // Extract optional domain part
  41. domain := ""
  42. if idx := strings.LastIndex(addr, "@"); idx != -1 {
  43. domain = strings.ToLower(addr[idx+1:])
  44. }
  45. // Test each address from allowedUsers file
  46. for _, allowedAddr := range allowedAddrs {
  47. allowedAddr = strings.ToLower(allowedAddr)
  48. // Three cases for allowedAddr format:
  49. if idx := strings.Index(allowedAddr, "@"); idx == -1 {
  50. // 1. local address (no @) -- must match exactly
  51. if allowedAddr == addr {
  52. return true
  53. }
  54. } else {
  55. if idx != 0 {
  56. // 2. email address ([email protected]) -- must match exactly
  57. if allowedAddr == addr {
  58. return true
  59. }
  60. } else {
  61. // 3. domain (@domain.com) -- must match addr domain
  62. allowedDomain := allowedAddr[idx+1:]
  63. if allowedDomain == domain {
  64. return true
  65. }
  66. }
  67. }
  68. }
  69. return false
  70. }
  71. func senderChecker(peer smtpd.Peer, addr string) error {
  72. // check sender address from auth file if user is authenticated
  73. if localAuthRequired() && peer.Username != "" {
  74. user, err := AuthFetch(peer.Username)
  75. if err != nil {
  76. // Shouldn't happen: authChecker already validated username+password
  77. log.WithFields(logrus.Fields{
  78. "peer": peer.Addr,
  79. "username": peer.Username,
  80. }).WithError(err).Warn("could not fetch auth user")
  81. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  82. }
  83. if !addrAllowed(addr, user.allowedAddresses) {
  84. log.WithFields(logrus.Fields{
  85. "peer": peer.Addr,
  86. "username": peer.Username,
  87. "sender_address": addr,
  88. }).Warn("sender address not allowed for authenticated user")
  89. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  90. }
  91. }
  92. if allowedSender == nil {
  93. // Any sender is permitted
  94. return nil
  95. }
  96. if allowedSender.MatchString(addr) {
  97. // Permitted by regex
  98. return nil
  99. }
  100. log.WithFields(logrus.Fields{
  101. "sender_address": addr,
  102. "peer": peer.Addr,
  103. }).Warn("sender address not allowed by allowed_sender pattern")
  104. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  105. }
  106. func recipientChecker(peer smtpd.Peer, addr string) error {
  107. if allowedRecipients == nil {
  108. // Any recipient is permitted
  109. return nil
  110. }
  111. if allowedRecipients.MatchString(addr) {
  112. // Permitted by regex
  113. return nil
  114. }
  115. log.WithFields(logrus.Fields{
  116. "peer": peer.Addr,
  117. "recipient_address": addr,
  118. }).Warn("recipient address not allowed by allowed_recipients pattern")
  119. return smtpd.Error{Code: 451, Message: "Bad recipient address"}
  120. }
  121. func authChecker(peer smtpd.Peer, username string, password string) error {
  122. err := AuthCheckPassword(username, password)
  123. if err != nil {
  124. log.WithFields(logrus.Fields{
  125. "peer": peer.Addr,
  126. "username": username,
  127. }).WithError(err).Warn("auth error")
  128. return smtpd.Error{Code: 535, Message: "Authentication credentials invalid"}
  129. }
  130. return nil
  131. }
  132. func mailHandler(peer smtpd.Peer, env smtpd.Envelope) error {
  133. peerIP := ""
  134. if addr, ok := peer.Addr.(*net.TCPAddr); ok {
  135. peerIP = addr.IP.String()
  136. }
  137. logger := log.WithFields(logrus.Fields{
  138. "from": env.Sender,
  139. "to": env.Recipients,
  140. "peer": peerIP,
  141. "uuid": generateUUID(),
  142. })
  143. if *remotesStr == "" && *command == "" {
  144. logger.Warning("no remote_host or command set; discarding mail")
  145. return nil
  146. }
  147. env.AddReceivedLine(peer)
  148. if *command != "" {
  149. cmdLogger := logger.WithField("command", *command)
  150. var stdout bytes.Buffer
  151. var stderr bytes.Buffer
  152. environ := os.Environ()
  153. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_FROM", env.Sender))
  154. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_TO", env.Recipients))
  155. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_PEER", peerIP))
  156. cmd := exec.Cmd{
  157. Env: environ,
  158. Path: *command,
  159. }
  160. cmd.Stdin = bytes.NewReader(env.Data)
  161. cmd.Stdout = &stdout
  162. cmd.Stderr = &stderr
  163. err := cmd.Run()
  164. if err != nil {
  165. cmdLogger.WithError(err).Error(stderr.String())
  166. return smtpd.Error{Code: 554, Message: "External command failed"}
  167. }
  168. cmdLogger.Info("pipe command successful: " + stdout.String())
  169. }
  170. for _, remote := range remotes {
  171. logger = logger.WithField("host", remote.Addr)
  172. logger.Info("delivering mail from peer using smarthost")
  173. err := SendMail(
  174. remote,
  175. env.Sender,
  176. env.Recipients,
  177. env.Data,
  178. )
  179. if err != nil {
  180. var smtpError smtpd.Error
  181. switch err := err.(type) {
  182. case *textproto.Error:
  183. smtpError = smtpd.Error{Code: err.Code, Message: err.Msg}
  184. logger.WithFields(logrus.Fields{
  185. "err_code": err.Code,
  186. "err_msg": err.Msg,
  187. }).Error("delivery failed")
  188. default:
  189. smtpError = smtpd.Error{Code: 554, Message: "Forwarding failed"}
  190. logger.WithError(err).
  191. Error("delivery failed")
  192. }
  193. return smtpError
  194. }
  195. logger.Debug("delivery successful")
  196. }
  197. return nil
  198. }
  199. func generateUUID() string {
  200. uniqueID, err := uuid.NewRandom()
  201. if err != nil {
  202. log.WithError(err).
  203. Error("could not generate UUIDv4")
  204. return ""
  205. }
  206. return uniqueID.String()
  207. }
  208. func getTLSConfig() *tls.Config {
  209. // Ciphersuites as defined in stock Go but without 3DES and RC4
  210. // https://golang.org/src/crypto/tls/cipher_suites.go
  211. var tlsCipherSuites = []uint16{
  212. tls.TLS_AES_128_GCM_SHA256,
  213. tls.TLS_AES_256_GCM_SHA384,
  214. tls.TLS_CHACHA20_POLY1305_SHA256,
  215. tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  216. tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  217. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  218. tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  219. tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  220. tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  221. tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // does not provide PFS
  222. tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // does not provide PFS
  223. }
  224. if *localCert == "" || *localKey == "" {
  225. log.WithFields(logrus.Fields{
  226. "cert_file": *localCert,
  227. "key_file": *localKey,
  228. }).Fatal("TLS certificate/key file not defined in config")
  229. }
  230. cert, err := tls.LoadX509KeyPair(*localCert, *localKey)
  231. if err != nil {
  232. log.WithField("error", err).
  233. Fatal("cannot load X509 keypair")
  234. }
  235. return &tls.Config{
  236. PreferServerCipherSuites: true,
  237. MinVersion: tls.VersionTLS12,
  238. CipherSuites: tlsCipherSuites,
  239. Certificates: []tls.Certificate{cert},
  240. }
  241. }
  242. func main() {
  243. ConfigLoad()
  244. log.WithField("version", appVersion).
  245. Debug("starting smtprelay")
  246. // Load allowed users file
  247. if localAuthRequired() {
  248. err := AuthLoadFile(*allowedUsers)
  249. if err != nil {
  250. log.WithField("file", *allowedUsers).
  251. WithError(err).
  252. Fatal("cannot load allowed users file")
  253. }
  254. }
  255. var servers []*smtpd.Server
  256. // Create a server for each desired listen address
  257. for _, listen := range listenAddrs {
  258. logger := log.WithField("address", listen.address)
  259. server := &smtpd.Server{
  260. Hostname: *hostName,
  261. WelcomeMessage: *welcomeMsg,
  262. ReadTimeout: readTimeout,
  263. WriteTimeout: writeTimeout,
  264. DataTimeout: dataTimeout,
  265. MaxConnections: *maxConnections,
  266. MaxMessageSize: *maxMessageSize,
  267. MaxRecipients: *maxRecipients,
  268. ConnectionChecker: connectionChecker,
  269. SenderChecker: senderChecker,
  270. RecipientChecker: recipientChecker,
  271. Handler: mailHandler,
  272. }
  273. if localAuthRequired() {
  274. server.Authenticator = authChecker
  275. }
  276. var lsnr net.Listener
  277. var err error
  278. switch listen.protocol {
  279. case "":
  280. logger.Info("listening on address")
  281. lsnr, err = net.Listen("tcp", listen.address)
  282. case "starttls":
  283. server.TLSConfig = getTLSConfig()
  284. server.ForceTLS = *localForceTLS
  285. logger.Info("listening on address (STARTTLS)")
  286. lsnr, err = net.Listen("tcp", listen.address)
  287. case "tls":
  288. server.TLSConfig = getTLSConfig()
  289. logger.Info("listening on address (TLS)")
  290. lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
  291. default:
  292. logger.WithField("protocol", listen.protocol).
  293. Fatal("unknown protocol in listen address")
  294. }
  295. if err != nil {
  296. logger.WithError(err).Fatal("error starting listener")
  297. }
  298. servers = append(servers, server)
  299. go func() {
  300. server.Serve(lsnr)
  301. }()
  302. }
  303. handleSignals()
  304. // First close the listeners
  305. for _, server := range servers {
  306. logger := log.WithField("address", server.Address())
  307. logger.Debug("Shutting down server")
  308. err := server.Shutdown(false)
  309. if err != nil {
  310. logger.WithError(err).
  311. Warning("Shutdown failed")
  312. }
  313. }
  314. // Then wait for the clients to exit
  315. for _, server := range servers {
  316. logger := log.WithField("address", server.Address())
  317. logger.Debug("Waiting for server")
  318. err := server.Wait()
  319. if err != nil {
  320. logger.WithError(err).
  321. Warning("Wait failed")
  322. }
  323. }
  324. log.Debug("done")
  325. }
  326. func handleSignals() {
  327. // Wait for SIGINT, SIGQUIT, or SIGTERM
  328. sigs := make(chan os.Signal, 1)
  329. signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)
  330. sig := <-sigs
  331. log.WithField("signal", sig).
  332. Info("shutting down in response to received signal")
  333. }