main.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package main
import (
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/textproto"
	"os"
	"os/exec"
	"os/signal"
	"strings"
	"syscall"

	"github.com/chrj/smtpd"
	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
)
  17. func connectionChecker(peer smtpd.Peer) error {
  18. // This can't panic because we only have TCP listeners
  19. peerIP := peer.Addr.(*net.TCPAddr).IP
  20. if len(allowedNets) == 0 {
  21. // Special case: empty string means allow everything
  22. return nil
  23. }
  24. for _, allowedNet := range allowedNets {
  25. if allowedNet.Contains(peerIP) {
  26. return nil
  27. }
  28. }
  29. log.WithFields(logrus.Fields{
  30. "ip": peerIP,
  31. }).Warn("Connection refused from address outside of allowed_nets")
  32. return smtpd.Error{Code: 421, Message: "Denied"}
  33. }
  34. func addrAllowed(addr string, allowedAddrs []string) bool {
  35. if allowedAddrs == nil {
  36. // If absent, all addresses are allowed
  37. return true
  38. }
  39. addr = strings.ToLower(addr)
  40. // Extract optional domain part
  41. domain := ""
  42. if idx := strings.LastIndex(addr, "@"); idx != -1 {
  43. domain = strings.ToLower(addr[idx+1:])
  44. }
  45. // Test each address from allowedUsers file
  46. for _, allowedAddr := range allowedAddrs {
  47. allowedAddr = strings.ToLower(allowedAddr)
  48. // Three cases for allowedAddr format:
  49. if idx := strings.Index(allowedAddr, "@"); idx == -1 {
  50. // 1. local address (no @) -- must match exactly
  51. if allowedAddr == addr {
  52. return true
  53. }
  54. } else {
  55. if idx != 0 {
  56. // 2. email address ([email protected]) -- must match exactly
  57. if allowedAddr == addr {
  58. return true
  59. }
  60. } else {
  61. // 3. domain (@domain.com) -- must match addr domain
  62. allowedDomain := allowedAddr[idx+1:]
  63. if allowedDomain == domain {
  64. return true
  65. }
  66. }
  67. }
  68. }
  69. return false
  70. }
  71. func senderChecker(peer smtpd.Peer, addr string) error {
  72. // check sender address from auth file if user is authenticated
  73. if localAuthRequired() && peer.Username != "" {
  74. user, err := AuthFetch(peer.Username)
  75. if err != nil {
  76. // Shouldn't happen: authChecker already validated username+password
  77. log.WithFields(logrus.Fields{
  78. "peer": peer.Addr,
  79. "username": peer.Username,
  80. }).WithError(err).Warn("could not fetch auth user")
  81. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  82. }
  83. if !addrAllowed(addr, user.allowedAddresses) {
  84. log.WithFields(logrus.Fields{
  85. "peer": peer.Addr,
  86. "username": peer.Username,
  87. "sender_address": addr,
  88. }).Warn("sender address not allowed for authenticated user")
  89. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  90. }
  91. }
  92. if allowedSender == nil {
  93. // Any sender is permitted
  94. return nil
  95. }
  96. if allowedSender.MatchString(addr) {
  97. // Permitted by regex
  98. return nil
  99. }
  100. log.WithFields(logrus.Fields{
  101. "sender_address": addr,
  102. "peer": peer.Addr,
  103. }).Warn("sender address not allowed by allowed_sender pattern")
  104. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  105. }
  106. func recipientChecker(peer smtpd.Peer, addr string) error {
  107. if allowedRecipients == nil {
  108. // Any recipient is permitted
  109. return nil
  110. }
  111. if allowedRecipients.MatchString(addr) {
  112. // Permitted by regex
  113. return nil
  114. }
  115. log.WithFields(logrus.Fields{
  116. "peer": peer.Addr,
  117. "recipient_address": addr,
  118. }).Warn("recipient address not allowed by allowed_recipients pattern")
  119. return smtpd.Error{Code: 451, Message: "Bad recipient address"}
  120. }
  121. func authChecker(peer smtpd.Peer, username string, password string) error {
  122. err := AuthCheckPassword(username, password)
  123. if err != nil {
  124. log.WithFields(logrus.Fields{
  125. "peer": peer.Addr,
  126. "username": username,
  127. }).WithError(err).Warn("auth error")
  128. return smtpd.Error{Code: 535, Message: "Authentication credentials invalid"}
  129. }
  130. return nil
  131. }
  132. func mailHandler(peer smtpd.Peer, env smtpd.Envelope) error {
  133. peerIP := ""
  134. if addr, ok := peer.Addr.(*net.TCPAddr); ok {
  135. peerIP = addr.IP.String()
  136. }
  137. logger := log.WithFields(logrus.Fields{
  138. "from": env.Sender,
  139. "to": env.Recipients,
  140. "peer": peerIP,
  141. "uuid": generateUUID(),
  142. })
  143. var envRemotes []*Remote
  144. if *strictSender {
  145. for _, remote := range remotes {
  146. if remote.Sender == env.Sender {
  147. envRemotes = append(envRemotes, remote)
  148. }
  149. }
  150. } else {
  151. envRemotes = remotes
  152. }
  153. if len(envRemotes) == 0 && *command == "" {
  154. logger.Warning("no remote_host or command set; discarding mail")
  155. return smtpd.Error{Code: 554, Message: "There are no appropriate remote_host or command"}
  156. }
  157. env.AddReceivedLine(peer)
  158. if *command != "" {
  159. cmdLogger := logger.WithField("command", *command)
  160. var stdout bytes.Buffer
  161. var stderr bytes.Buffer
  162. environ := os.Environ()
  163. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_FROM", env.Sender))
  164. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_TO", env.Recipients))
  165. environ = append(environ, fmt.Sprintf("%s=%s", "SMTPRELAY_PEER", peerIP))
  166. cmd := exec.Cmd{
  167. Env: environ,
  168. Path: *command,
  169. }
  170. cmd.Stdin = bytes.NewReader(env.Data)
  171. cmd.Stdout = &stdout
  172. cmd.Stderr = &stderr
  173. err := cmd.Run()
  174. if err != nil {
  175. cmdLogger.WithError(err).Error(stderr.String())
  176. return smtpd.Error{Code: 554, Message: "External command failed"}
  177. }
  178. cmdLogger.Info("pipe command successful: " + stdout.String())
  179. }
  180. for _, remote := range envRemotes {
  181. logger = logger.WithField("host", remote.Addr)
  182. logger.Info("delivering mail from peer using smarthost")
  183. err := SendMail(
  184. remote,
  185. env.Sender,
  186. env.Recipients,
  187. env.Data,
  188. )
  189. if err != nil {
  190. var smtpError smtpd.Error
  191. switch err := err.(type) {
  192. case *textproto.Error:
  193. smtpError = smtpd.Error{Code: err.Code, Message: err.Msg}
  194. logger.WithFields(logrus.Fields{
  195. "err_code": err.Code,
  196. "err_msg": err.Msg,
  197. }).Error("delivery failed")
  198. default:
  199. smtpError = smtpd.Error{Code: 554, Message: "Forwarding failed"}
  200. logger.WithError(err).
  201. Error("delivery failed")
  202. }
  203. return smtpError
  204. }
  205. logger.Debug("delivery successful")
  206. }
  207. return nil
  208. }
  209. func generateUUID() string {
  210. uniqueID, err := uuid.NewRandom()
  211. if err != nil {
  212. log.WithError(err).
  213. Error("could not generate UUIDv4")
  214. return ""
  215. }
  216. return uniqueID.String()
  217. }
  218. func getTLSConfig() *tls.Config {
  219. // Ciphersuites as defined in stock Go but without 3DES and RC4
  220. // https://golang.org/src/crypto/tls/cipher_suites.go
  221. var tlsCipherSuites = []uint16{
  222. tls.TLS_AES_128_GCM_SHA256,
  223. tls.TLS_AES_256_GCM_SHA384,
  224. tls.TLS_CHACHA20_POLY1305_SHA256,
  225. tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  226. tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  227. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  228. tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  229. tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  230. tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  231. tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // does not provide PFS
  232. tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // does not provide PFS
  233. }
  234. if *localCert == "" || *localKey == "" {
  235. log.WithFields(logrus.Fields{
  236. "cert_file": *localCert,
  237. "key_file": *localKey,
  238. }).Fatal("TLS certificate/key file not defined in config")
  239. }
  240. cert, err := tls.LoadX509KeyPair(*localCert, *localKey)
  241. if err != nil {
  242. log.WithField("error", err).
  243. Fatal("cannot load X509 keypair")
  244. }
  245. return &tls.Config{
  246. PreferServerCipherSuites: true,
  247. MinVersion: tls.VersionTLS12,
  248. CipherSuites: tlsCipherSuites,
  249. Certificates: []tls.Certificate{cert},
  250. }
  251. }
  252. func main() {
  253. ConfigLoad()
  254. log.WithField("version", appVersion).
  255. Debug("starting smtprelay")
  256. // Load allowed users file
  257. if localAuthRequired() {
  258. err := AuthLoadFile(*allowedUsers)
  259. if err != nil {
  260. log.WithField("file", *allowedUsers).
  261. WithError(err).
  262. Fatal("cannot load allowed users file")
  263. }
  264. }
  265. var servers []*smtpd.Server
  266. // Create a server for each desired listen address
  267. for _, listen := range listenAddrs {
  268. logger := log.WithField("address", listen.address)
  269. server := &smtpd.Server{
  270. Hostname: *hostName,
  271. WelcomeMessage: *welcomeMsg,
  272. ReadTimeout: readTimeout,
  273. WriteTimeout: writeTimeout,
  274. DataTimeout: dataTimeout,
  275. MaxConnections: *maxConnections,
  276. MaxMessageSize: *maxMessageSize,
  277. MaxRecipients: *maxRecipients,
  278. ConnectionChecker: connectionChecker,
  279. SenderChecker: senderChecker,
  280. RecipientChecker: recipientChecker,
  281. Handler: mailHandler,
  282. }
  283. if localAuthRequired() {
  284. server.Authenticator = authChecker
  285. }
  286. var lsnr net.Listener
  287. var err error
  288. switch listen.protocol {
  289. case "":
  290. logger.Info("listening on address")
  291. lsnr, err = net.Listen("tcp", listen.address)
  292. case "starttls":
  293. server.TLSConfig = getTLSConfig()
  294. server.ForceTLS = *localForceTLS
  295. logger.Info("listening on address (STARTTLS)")
  296. lsnr, err = net.Listen("tcp", listen.address)
  297. case "tls":
  298. server.TLSConfig = getTLSConfig()
  299. logger.Info("listening on address (TLS)")
  300. lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
  301. default:
  302. logger.WithField("protocol", listen.protocol).
  303. Fatal("unknown protocol in listen address")
  304. }
  305. if err != nil {
  306. logger.WithError(err).Fatal("error starting listener")
  307. }
  308. servers = append(servers, server)
  309. go func() {
  310. server.Serve(lsnr)
  311. }()
  312. }
  313. handleSignals()
  314. // First close the listeners
  315. for _, server := range servers {
  316. logger := log.WithField("address", server.Address())
  317. logger.Debug("Shutting down server")
  318. err := server.Shutdown(false)
  319. if err != nil {
  320. logger.WithError(err).
  321. Warning("Shutdown failed")
  322. }
  323. }
  324. // Then wait for the clients to exit
  325. for _, server := range servers {
  326. logger := log.WithField("address", server.Address())
  327. logger.Debug("Waiting for server")
  328. err := server.Wait()
  329. if err != nil {
  330. logger.WithError(err).
  331. Warning("Wait failed")
  332. }
  333. }
  334. log.Debug("done")
  335. }
  336. func handleSignals() {
  337. // Wait for SIGINT, SIGQUIT, or SIGTERM
  338. sigs := make(chan os.Signal, 1)
  339. signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)
  340. sig := <-sigs
  341. log.WithField("signal", sig).
  342. Info("shutting down in response to received signal")
  343. }