main.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. package main
  import (
	"bytes"
	"crypto/tls"
	"errors"
	"net"
	"net/textproto"
	"os"
	"os/exec"
	"os/signal"
	"strings"
	"syscall"

	"github.com/chrj/smtpd"
	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
)
  16. func connectionChecker(peer smtpd.Peer) error {
  17. // This can't panic because we only have TCP listeners
  18. peerIP := peer.Addr.(*net.TCPAddr).IP
  19. if len(allowedNets) == 0 {
  20. // Special case: empty string means allow everything
  21. return nil
  22. }
  23. for _, allowedNet := range allowedNets {
  24. if allowedNet.Contains(peerIP) {
  25. return nil
  26. }
  27. }
  28. log.WithFields(logrus.Fields{
  29. "ip": peerIP,
  30. }).Warn("Connection refused from address outside of allowed_nets")
  31. return smtpd.Error{Code: 421, Message: "Denied"}
  32. }
  // addrAllowed reports whether addr may be used given the allowedAddrs list.
//
// A nil list means no restriction and every address is accepted; a non-nil
// empty list accepts nothing. Matching is case-insensitive, and each entry
// takes one of three forms:
//
//   - "user": a local address without "@"; addr must equal it exactly.
//   - "user@example.com": a full address; addr must equal it exactly.
//   - "@example.com": a domain; addr matches when the text after its last
//     "@" equals the domain.
func addrAllowed(addr string, allowedAddrs []string) bool {
	if allowedAddrs == nil {
		return true
	}

	lowerAddr := strings.ToLower(addr)

	// Domain part of addr: everything after the final "@", or "" if none.
	addrDomain := ""
	if at := strings.LastIndex(lowerAddr, "@"); at >= 0 {
		addrDomain = lowerAddr[at+1:]
	}

	for _, entry := range allowedAddrs {
		entry = strings.ToLower(entry)

		if strings.HasPrefix(entry, "@") {
			// Domain entry: compare with the domain part of addr.
			if entry[1:] == addrDomain {
				return true
			}
			continue
		}

		// Local-part or full email entry: exact match required.
		if entry == lowerAddr {
			return true
		}
	}

	return false
}
  70. func senderChecker(peer smtpd.Peer, addr string) error {
  71. // check sender address from auth file if user is authenticated
  72. if localAuthRequired() && peer.Username != "" {
  73. user, err := AuthFetch(peer.Username)
  74. if err != nil {
  75. // Shouldn't happen: authChecker already validated username+password
  76. log.WithFields(logrus.Fields{
  77. "peer": peer.Addr,
  78. "username": peer.Username,
  79. }).WithError(err).Warn("could not fetch auth user")
  80. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  81. }
  82. if !addrAllowed(addr, user.allowedAddresses) {
  83. log.WithFields(logrus.Fields{
  84. "peer": peer.Addr,
  85. "username": peer.Username,
  86. "sender_address": addr,
  87. }).Warn("sender address not allowed for authenticated user")
  88. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  89. }
  90. }
  91. if allowedSender == nil {
  92. // Any sender is permitted
  93. return nil
  94. }
  95. if allowedSender.MatchString(addr) {
  96. // Permitted by regex
  97. return nil
  98. }
  99. log.WithFields(logrus.Fields{
  100. "sender_address": addr,
  101. "peer": peer.Addr,
  102. }).Warn("sender address not allowed by allowed_sender pattern")
  103. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  104. }
  105. func recipientChecker(peer smtpd.Peer, addr string) error {
  106. if allowedRecipients == nil {
  107. // Any recipient is permitted
  108. return nil
  109. }
  110. if allowedRecipients.MatchString(addr) {
  111. // Permitted by regex
  112. return nil
  113. }
  114. log.WithFields(logrus.Fields{
  115. "peer": peer.Addr,
  116. "recipient_address": addr,
  117. }).Warn("recipient address not allowed by allowed_recipients pattern")
  118. return smtpd.Error{Code: 451, Message: "Bad recipient address"}
  119. }
  120. func authChecker(peer smtpd.Peer, username string, password string) error {
  121. err := AuthCheckPassword(username, password)
  122. if err != nil {
  123. log.WithFields(logrus.Fields{
  124. "peer": peer.Addr,
  125. "username": username,
  126. }).WithError(err).Warn("auth error")
  127. return smtpd.Error{Code: 535, Message: "Authentication credentials invalid"}
  128. }
  129. return nil
  130. }
  131. func mailHandler(peer smtpd.Peer, env smtpd.Envelope) error {
  132. peerIP := ""
  133. if addr, ok := peer.Addr.(*net.TCPAddr); ok {
  134. peerIP = addr.IP.String()
  135. }
  136. logger := log.WithFields(logrus.Fields{
  137. "from": env.Sender,
  138. "to": env.Recipients,
  139. "peer": peerIP,
  140. "uuid": generateUUID(),
  141. })
  142. if *remotesStr == "" && *command == "" {
  143. logger.Warning("no remote_host or command set; discarding mail")
  144. return nil
  145. }
  146. env.AddReceivedLine(peer)
  147. if *command != "" {
  148. cmdLogger := logger.WithField("command", *command)
  149. var stdout bytes.Buffer
  150. var stderr bytes.Buffer
  151. cmd := exec.Command(*command)
  152. cmd.Stdin = bytes.NewReader(env.Data)
  153. cmd.Stdout = &stdout
  154. cmd.Stderr = &stderr
  155. err := cmd.Run()
  156. if err != nil {
  157. cmdLogger.WithError(err).Error(stderr.String())
  158. return smtpd.Error{Code: 554, Message: "External command failed"}
  159. }
  160. cmdLogger.Info("pipe command successful: " + stdout.String())
  161. }
  162. for _, remote := range remotes {
  163. logger = logger.WithField("host", remote.Addr)
  164. logger.Info("delivering mail from peer using smarthost")
  165. err := SendMail(
  166. remote,
  167. env.Sender,
  168. env.Recipients,
  169. env.Data,
  170. )
  171. if err != nil {
  172. var smtpError smtpd.Error
  173. switch err := err.(type) {
  174. case *textproto.Error:
  175. smtpError = smtpd.Error{Code: err.Code, Message: err.Msg}
  176. logger.WithFields(logrus.Fields{
  177. "err_code": err.Code,
  178. "err_msg": err.Msg,
  179. }).Error("delivery failed")
  180. default:
  181. smtpError = smtpd.Error{Code: 554, Message: "Forwarding failed"}
  182. logger.WithError(err).
  183. Error("delivery failed")
  184. }
  185. return smtpError
  186. }
  187. logger.Debug("delivery successful")
  188. }
  189. return nil
  190. }
  191. func generateUUID() string {
  192. uniqueID, err := uuid.NewRandom()
  193. if err != nil {
  194. log.WithError(err).
  195. Error("could not generate UUIDv4")
  196. return ""
  197. }
  198. return uniqueID.String()
  199. }
  200. func getTLSConfig() *tls.Config {
  201. // Ciphersuites as defined in stock Go but without 3DES and RC4
  202. // https://golang.org/src/crypto/tls/cipher_suites.go
  203. var tlsCipherSuites = []uint16{
  204. tls.TLS_AES_128_GCM_SHA256,
  205. tls.TLS_AES_256_GCM_SHA384,
  206. tls.TLS_CHACHA20_POLY1305_SHA256,
  207. tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  208. tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  209. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  210. tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  211. tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  212. tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  213. tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // does not provide PFS
  214. tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // does not provide PFS
  215. }
  216. if *localCert == "" || *localKey == "" {
  217. log.WithFields(logrus.Fields{
  218. "cert_file": *localCert,
  219. "key_file": *localKey,
  220. }).Fatal("TLS certificate/key file not defined in config")
  221. }
  222. cert, err := tls.LoadX509KeyPair(*localCert, *localKey)
  223. if err != nil {
  224. log.WithField("error", err).
  225. Fatal("cannot load X509 keypair")
  226. }
  227. return &tls.Config{
  228. PreferServerCipherSuites: true,
  229. MinVersion: tls.VersionTLS12,
  230. CipherSuites: tlsCipherSuites,
  231. Certificates: []tls.Certificate{cert},
  232. }
  233. }
  234. func main() {
  235. ConfigLoad()
  236. log.WithField("version", appVersion).
  237. Debug("starting smtprelay")
  238. // Load allowed users file
  239. if localAuthRequired() {
  240. err := AuthLoadFile(*allowedUsers)
  241. if err != nil {
  242. log.WithField("file", *allowedUsers).
  243. WithError(err).
  244. Fatal("cannot load allowed users file")
  245. }
  246. }
  247. var servers []*smtpd.Server
  248. // Create a server for each desired listen address
  249. for _, listen := range listenAddrs {
  250. logger := log.WithField("address", listen.address)
  251. server := &smtpd.Server{
  252. Hostname: *hostName,
  253. WelcomeMessage: *welcomeMsg,
  254. ReadTimeout: readTimeout,
  255. WriteTimeout: writeTimeout,
  256. DataTimeout: dataTimeout,
  257. MaxConnections: *maxConnections,
  258. MaxMessageSize: *maxMessageSize,
  259. MaxRecipients: *maxRecipients,
  260. ConnectionChecker: connectionChecker,
  261. SenderChecker: senderChecker,
  262. RecipientChecker: recipientChecker,
  263. Handler: mailHandler,
  264. }
  265. if localAuthRequired() {
  266. server.Authenticator = authChecker
  267. }
  268. var lsnr net.Listener
  269. var err error
  270. switch listen.protocol {
  271. case "":
  272. logger.Info("listening on address")
  273. lsnr, err = net.Listen("tcp", listen.address)
  274. case "starttls":
  275. server.TLSConfig = getTLSConfig()
  276. server.ForceTLS = *localForceTLS
  277. logger.Info("listening on address (STARTTLS)")
  278. lsnr, err = net.Listen("tcp", listen.address)
  279. case "tls":
  280. server.TLSConfig = getTLSConfig()
  281. logger.Info("listening on address (TLS)")
  282. lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
  283. default:
  284. logger.WithField("protocol", listen.protocol).
  285. Fatal("unknown protocol in listen address")
  286. }
  287. if err != nil {
  288. logger.WithError(err).Fatal("error starting listener")
  289. }
  290. servers = append(servers, server)
  291. go func() {
  292. server.Serve(lsnr)
  293. }()
  294. }
  295. handleSignals()
  296. // First close the listeners
  297. for _, server := range servers {
  298. logger := log.WithField("address", server.Address())
  299. logger.Debug("Shutting down server")
  300. err := server.Shutdown(false)
  301. if err != nil {
  302. logger.WithError(err).
  303. Warning("Shutdown failed")
  304. }
  305. }
  306. // Then wait for the clients to exit
  307. for _, server := range servers {
  308. logger := log.WithField("address", server.Address())
  309. logger.Debug("Waiting for server")
  310. err := server.Wait()
  311. if err != nil {
  312. logger.WithError(err).
  313. Warning("Wait failed")
  314. }
  315. }
  316. log.Debug("done")
  317. }
  318. func handleSignals() {
  319. // Wait for SIGINT, SIGQUIT, or SIGTERM
  320. sigs := make(chan os.Signal, 1)
  321. signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)
  322. sig := <-sigs
  323. log.WithField("signal", sig).
  324. Info("shutting down in response to received signal")
  325. }