main.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
package main

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/smtp"
	"net/textproto"
	"os"
	"regexp"
	"strings"
	"time"

	"github.com/chrj/smtpd"
)
  15. func connectionChecker(peer smtpd.Peer) error {
  16. var peerIP net.IP
  17. if addr, ok := peer.Addr.(*net.TCPAddr); ok {
  18. peerIP = net.ParseIP(addr.IP.String())
  19. } else {
  20. return smtpd.Error{Code: 421, Message: "Denied"}
  21. }
  22. nets := strings.Split(*allowedNets, " ")
  23. for i := range nets {
  24. _, allowedNet, _ := net.ParseCIDR(nets[i])
  25. if allowedNet.Contains(peerIP) {
  26. return nil
  27. }
  28. }
  29. log.Printf("Connection from peer=[%s] denied: Not in allowed_nets\n", peerIP)
  30. return smtpd.Error{Code: 421, Message: "Denied"}
  31. }
  32. func addrAllowed(addr string, allowedAddrs []string) bool {
  33. if allowedAddrs == nil {
  34. // If absent, all addresses are allowed
  35. return true
  36. }
  37. addr = strings.ToLower(addr)
  38. // Extract optional domain part
  39. domain := ""
  40. if idx := strings.LastIndex(addr, "@"); idx != -1 {
  41. domain = strings.ToLower(addr[idx+1:])
  42. }
  43. // Test each address from allowedUsers file
  44. for _, allowedAddr := range allowedAddrs {
  45. allowedAddr = strings.ToLower(allowedAddr)
  46. // Three cases for allowedAddr format:
  47. if idx := strings.Index(allowedAddr, "@"); idx == -1 {
  48. // 1. local address (no @) -- must match exactly
  49. if allowedAddr == addr {
  50. return true
  51. }
  52. } else {
  53. if idx != 0 {
  54. // 2. email address ([email protected]) -- must match exactly
  55. if allowedAddr == addr {
  56. return true
  57. }
  58. } else {
  59. // 3. domain (@domain.com) -- must match addr domain
  60. allowedDomain := allowedAddr[idx+1:]
  61. if allowedDomain == domain {
  62. return true
  63. }
  64. }
  65. }
  66. }
  67. return false
  68. }
  69. func senderChecker(peer smtpd.Peer, addr string) error {
  70. // check sender address from auth file if user is authenticated
  71. if *allowedUsers != "" && peer.Username != "" {
  72. user, err := AuthFetch(peer.Username)
  73. if err != nil {
  74. // Shouldn't happen: authChecker already validated username+password
  75. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  76. }
  77. if !addrAllowed(addr, user.allowedAddresses) {
  78. log.Printf("Mail from=<%s> not allowed for authenticated user %s (%v)\n",
  79. addr, peer.Username, peer.Addr)
  80. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  81. }
  82. }
  83. if *allowedSender == "" {
  84. return nil
  85. }
  86. re, err := regexp.Compile(*allowedSender)
  87. if err != nil {
  88. log.Printf("allowed_sender invalid: %v\n", err)
  89. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  90. }
  91. if re.MatchString(addr) {
  92. return nil
  93. }
  94. log.Printf("Mail from=<%s> not allowed by allowed_sender pattern for peer %v\n",
  95. addr, peer.Addr)
  96. return smtpd.Error{Code: 451, Message: "Bad sender address"}
  97. }
  98. func recipientChecker(peer smtpd.Peer, addr string) error {
  99. if *allowedRecipients == "" {
  100. return nil
  101. }
  102. re, err := regexp.Compile(*allowedRecipients)
  103. if err != nil {
  104. log.Printf("allowed_recipients invalid: %v\n", err)
  105. return smtpd.Error{Code: 451, Message: "Bad recipient address"}
  106. }
  107. if re.MatchString(addr) {
  108. return nil
  109. }
  110. log.Printf("Mail to=<%s> not allowed by allowed_recipients pattern for peer %v\n",
  111. addr, peer.Addr)
  112. return smtpd.Error{Code: 451, Message: "Bad recipient address"}
  113. }
  114. func authChecker(peer smtpd.Peer, username string, password string) error {
  115. err := AuthCheckPassword(username, password)
  116. if err != nil {
  117. log.Printf("Auth error for peer %v: %v\n", peer.Addr, err)
  118. return smtpd.Error{Code: 535, Message: "Authentication credentials invalid"}
  119. }
  120. return nil
  121. }
  122. func mailHandler(peer smtpd.Peer, env smtpd.Envelope) error {
  123. peerIP := ""
  124. if addr, ok := peer.Addr.(*net.TCPAddr); ok {
  125. peerIP = addr.IP.String()
  126. }
  127. log.Printf("new mail from=<%s> to=%s peer=[%s]\n", env.Sender,
  128. env.Recipients, peerIP)
  129. var auth smtp.Auth
  130. host, _, _ := net.SplitHostPort(*remoteHost)
  131. if *remoteUser != "" && *remotePass != "" {
  132. switch *remoteAuth {
  133. case "plain":
  134. auth = smtp.PlainAuth("", *remoteUser, *remotePass, host)
  135. case "login":
  136. auth = LoginAuth(*remoteUser, *remotePass)
  137. default:
  138. return smtpd.Error{Code: 530, Message: "Authentication method not supported"}
  139. }
  140. }
  141. env.AddReceivedLine(peer)
  142. log.Printf("delivering using smarthost %s\n", *remoteHost)
  143. var sender string
  144. if *remoteSender == "" {
  145. sender = env.Sender
  146. } else {
  147. sender = *remoteSender
  148. }
  149. err := SendMail(
  150. *remoteHost,
  151. auth,
  152. sender,
  153. env.Recipients,
  154. env.Data,
  155. )
  156. if err != nil {
  157. log.Printf("delivery failed: %v\n", err)
  158. return smtpd.Error{Code: 554, Message: "Forwarding failed"}
  159. }
  160. log.Printf("%s delivery successful\n", env.Recipients)
  161. return nil
  162. }
  163. func getTLSConfig() *tls.Config {
  164. // Ciphersuites as defined in stock Go but without 3DES and RC4
  165. // https://golang.org/src/crypto/tls/cipher_suites.go
  166. var tlsCipherSuites = []uint16{
  167. tls.TLS_AES_128_GCM_SHA256,
  168. tls.TLS_AES_256_GCM_SHA384,
  169. tls.TLS_CHACHA20_POLY1305_SHA256,
  170. tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  171. tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  172. tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  173. tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  174. tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  175. tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  176. tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
  177. tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
  178. tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
  179. tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
  180. tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
  181. tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
  182. tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // does not provide PFS
  183. tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // does not provide PFS
  184. tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
  185. tls.TLS_RSA_WITH_AES_128_CBC_SHA,
  186. tls.TLS_RSA_WITH_AES_256_CBC_SHA,
  187. }
  188. if *localCert == "" || *localKey == "" {
  189. log.Fatal("TLS certificate/key not defined in config")
  190. }
  191. cert, err := tls.LoadX509KeyPair(*localCert, *localKey)
  192. if err != nil {
  193. log.Fatal(err)
  194. }
  195. return &tls.Config{
  196. PreferServerCipherSuites: true,
  197. MinVersion: tls.VersionTLS11,
  198. CipherSuites: tlsCipherSuites,
  199. Certificates: []tls.Certificate{cert},
  200. }
  201. }
  202. func main() {
  203. ConfigLoad()
  204. if *versionInfo {
  205. fmt.Printf("smtprelay/%s\n", VERSION)
  206. os.Exit(0)
  207. }
  208. if *logFile != "" {
  209. f, err := os.OpenFile(*logFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600)
  210. if err != nil {
  211. log.Fatalf("Error opening logfile: %v", err)
  212. }
  213. defer f.Close()
  214. log.SetOutput(io.MultiWriter(os.Stdout, f))
  215. }
  216. // Create a server for each desired listen address
  217. for _, listenAddr := range strings.Split(*listen, " ") {
  218. server := &smtpd.Server{
  219. Hostname: *hostName,
  220. WelcomeMessage: *welcomeMsg,
  221. ConnectionChecker: connectionChecker,
  222. SenderChecker: senderChecker,
  223. RecipientChecker: recipientChecker,
  224. Handler: mailHandler,
  225. }
  226. if *allowedUsers != "" {
  227. err := AuthLoadFile(*allowedUsers)
  228. if err != nil {
  229. log.Fatalf("Authentication file: %s\n", err)
  230. }
  231. server.Authenticator = authChecker
  232. }
  233. var lsnr net.Listener
  234. var err error
  235. if strings.Index(listenAddr, "://") == -1 {
  236. log.Printf("Listen on %s ...\n", listenAddr)
  237. lsnr, err = net.Listen("tcp", listenAddr)
  238. } else if strings.HasPrefix(listenAddr, "starttls://") {
  239. listenAddr = strings.TrimPrefix(listenAddr, "starttls://")
  240. server.TLSConfig = getTLSConfig()
  241. server.ForceTLS = *localForceTLS
  242. log.Printf("Listen on %s (STARTTLS) ...\n", listenAddr)
  243. lsnr, err = net.Listen("tcp", listenAddr)
  244. } else if strings.HasPrefix(listenAddr, "tls://") {
  245. listenAddr = strings.TrimPrefix(listenAddr, "tls://")
  246. server.TLSConfig = getTLSConfig()
  247. log.Printf("Listen on %s (TLS) ...\n", listenAddr)
  248. lsnr, err = tls.Listen("tcp", listenAddr, server.TLSConfig)
  249. } else {
  250. log.Fatal("Unknown protocol in listen address ", listenAddr)
  251. }
  252. if err != nil {
  253. log.Fatal(err)
  254. }
  255. defer lsnr.Close()
  256. go server.Serve(lsnr)
  257. }
  258. for true {
  259. time.Sleep(time.Minute)
  260. }
  261. }