server.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. package guerrilla
  2. import (
  3. "crypto/rand"
  4. "crypto/tls"
  5. "fmt"
  6. "io"
  7. "net"
  8. "strings"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/flashmob/go-guerrilla/backends"
  13. "github.com/flashmob/go-guerrilla/log"
  14. "github.com/flashmob/go-guerrilla/mail"
  15. "github.com/flashmob/go-guerrilla/response"
  16. )
  17. const (
  18. CommandVerbMaxLength = 16
  19. CommandLineMaxLength = 1024
  20. // Number of allowed unrecognized commands before we terminate the connection
  21. MaxUnrecognizedCommands = 5
  22. // The maximum total length of a reverse-path or forward-path is 256
  23. RFC2821LimitPath = 256
  24. // The maximum total length of a user name or other local-part is 64
  25. RFC2832LimitLocalPart = 64
  26. //The maximum total length of a domain name or number is 255
  27. RFC2821LimitDomain = 255
  28. // The minimum total number of recipients that must be buffered is 100
  29. RFC2821LimitRecipients = 100
  30. )
  31. const (
  32. // server has just been created
  33. ServerStateNew = iota
  34. // Server has just been stopped
  35. ServerStateStopped
  36. // Server has been started and is running
  37. ServerStateRunning
  38. // Server could not start due to an error
  39. ServerStateStartError
  40. )
  41. // Server listens for SMTP clients on the port specified in its config
  42. type server struct {
  43. configStore atomic.Value // stores guerrilla.ServerConfig
  44. tlsConfigStore atomic.Value
  45. timeout atomic.Value // stores time.Duration
  46. listenInterface string
  47. clientPool *Pool
  48. wg sync.WaitGroup // for waiting to shutdown
  49. listener net.Listener
  50. closedListener chan (bool)
  51. hosts allowedHosts // stores map[string]bool for faster lookup
  52. state int
  53. // If log changed after a config reload, newLogStore stores the value here until it's safe to change it
  54. logStore atomic.Value
  55. mainlogStore atomic.Value
  56. backendStore atomic.Value
  57. envelopePool *mail.Pool
  58. }
  59. type allowedHosts struct {
  60. table map[string]bool // host lookup table
  61. sync.Mutex // guard access to the map
  62. }
  63. // Creates and returns a new ready-to-run Server from a configuration
  64. func newServer(sc *ServerConfig, b backends.Backend, l log.Logger) (*server, error) {
  65. server := &server{
  66. clientPool: NewPool(sc.MaxClients),
  67. closedListener: make(chan (bool), 1),
  68. listenInterface: sc.ListenInterface,
  69. state: ServerStateNew,
  70. envelopePool: mail.NewPool(sc.MaxClients),
  71. }
  72. server.logStore.Store(l)
  73. server.backendStore.Store(b)
  74. logFile := sc.LogFile
  75. if logFile == "" {
  76. // none set, use the same log file as mainlog
  77. logFile = server.mainlog().GetLogDest()
  78. }
  79. // set level to same level as mainlog level
  80. mainlog, logOpenError := log.GetLogger(logFile, server.mainlog().GetLevel())
  81. server.mainlogStore.Store(mainlog)
  82. if logOpenError != nil {
  83. server.log().WithError(logOpenError).Errorf("Failed creating a logger for server [%s]", sc.ListenInterface)
  84. }
  85. server.setConfig(sc)
  86. server.setTimeout(sc.Timeout)
  87. if err := server.configureSSL(); err != nil {
  88. return server, err
  89. }
  90. return server, nil
  91. }
  92. func (s *server) configureSSL() error {
  93. sConfig := s.configStore.Load().(ServerConfig)
  94. if sConfig.TLSAlwaysOn || sConfig.StartTLSOn {
  95. cert, err := tls.LoadX509KeyPair(sConfig.PublicKeyFile, sConfig.PrivateKeyFile)
  96. if err != nil {
  97. return fmt.Errorf("error while loading the certificate: %s", err)
  98. }
  99. tlsConfig := &tls.Config{
  100. Certificates: []tls.Certificate{cert},
  101. ClientAuth: tls.VerifyClientCertIfGiven,
  102. ServerName: sConfig.Hostname,
  103. }
  104. tlsConfig.Rand = rand.Reader
  105. s.tlsConfigStore.Store(tlsConfig)
  106. }
  107. return nil
  108. }
  109. // setBackend sets the backend to use for processing email envelopes
  110. func (s *server) setBackend(b backends.Backend) {
  111. s.backendStore.Store(b)
  112. }
  113. // backend gets the backend used to process email envelopes
  114. func (s *server) backend() backends.Backend {
  115. if b, ok := s.backendStore.Load().(backends.Backend); ok {
  116. return b
  117. }
  118. return nil
  119. }
  120. // Set the timeout for the server and all clients
  121. func (server *server) setTimeout(seconds int) {
  122. duration := time.Duration(int64(seconds))
  123. server.clientPool.SetTimeout(duration)
  124. server.timeout.Store(duration)
  125. }
  126. // goroutine safe config store
  127. func (server *server) setConfig(sc *ServerConfig) {
  128. server.configStore.Store(*sc)
  129. }
  130. // goroutine safe
  131. func (server *server) isEnabled() bool {
  132. sc := server.configStore.Load().(ServerConfig)
  133. return sc.IsEnabled
  134. }
  135. // Set the allowed hosts for the server
  136. func (server *server) setAllowedHosts(allowedHosts []string) {
  137. server.hosts.Lock()
  138. defer server.hosts.Unlock()
  139. server.hosts.table = make(map[string]bool, len(allowedHosts))
  140. for _, h := range allowedHosts {
  141. server.hosts.table[strings.ToLower(h)] = true
  142. }
  143. }
  144. // Begin accepting SMTP clients. Will block unless there is an error or server.Shutdown() is called
  145. func (server *server) Start(startWG *sync.WaitGroup) error {
  146. var clientID uint64
  147. clientID = 0
  148. listener, err := net.Listen("tcp", server.listenInterface)
  149. server.listener = listener
  150. if err != nil {
  151. startWG.Done() // don't wait for me
  152. server.state = ServerStateStartError
  153. return fmt.Errorf("[%s] Cannot listen on port: %s ", server.listenInterface, err.Error())
  154. }
  155. server.log().Infof("Listening on TCP %s", server.listenInterface)
  156. server.state = ServerStateRunning
  157. startWG.Done() // start successful, don't wait for me
  158. for {
  159. server.log().Debugf("[%s] Waiting for a new client. Next Client ID: %d", server.listenInterface, clientID+1)
  160. conn, err := listener.Accept()
  161. clientID++
  162. if err != nil {
  163. if e, ok := err.(net.Error); ok && !e.Temporary() {
  164. server.log().Infof("Server [%s] has stopped accepting new clients", server.listenInterface)
  165. // the listener has been closed, wait for clients to exit
  166. server.log().Infof("shutting down pool [%s]", server.listenInterface)
  167. server.clientPool.ShutdownState()
  168. server.clientPool.ShutdownWait()
  169. server.state = ServerStateStopped
  170. server.closedListener <- true
  171. return nil
  172. }
  173. server.mainlog().WithError(err).Info("Temporary error accepting client")
  174. continue
  175. }
  176. go func(p Poolable, borrow_err error) {
  177. c := p.(*client)
  178. if borrow_err == nil {
  179. server.handleClient(c)
  180. server.envelopePool.Return(c.Envelope)
  181. server.clientPool.Return(c)
  182. } else {
  183. server.log().WithError(borrow_err).Info("couldn't borrow a new client")
  184. // we could not get a client, so close the connection.
  185. conn.Close()
  186. }
  187. // intentionally placed Borrow in args so that it's called in the
  188. // same main goroutine.
  189. }(server.clientPool.Borrow(conn, clientID, server.log(), server.envelopePool))
  190. }
  191. }
  192. func (server *server) Shutdown() {
  193. if server.listener != nil {
  194. // This will cause Start function to return, by causing an error on listener.Accept
  195. server.listener.Close()
  196. // wait for the listener to listener.Accept
  197. <-server.closedListener
  198. // At this point Start will exit and close down the pool
  199. } else {
  200. server.clientPool.ShutdownState()
  201. // listener already closed, wait for clients to exit
  202. server.clientPool.ShutdownWait()
  203. server.state = ServerStateStopped
  204. }
  205. }
  206. func (server *server) GetActiveClientsCount() int {
  207. return server.clientPool.GetActiveClientsCount()
  208. }
  209. // Verifies that the host is a valid recipient.
  210. func (server *server) allowsHost(host string) bool {
  211. server.hosts.Lock()
  212. defer server.hosts.Unlock()
  213. if _, ok := server.hosts.table[strings.ToLower(host)]; ok {
  214. return true
  215. }
  216. return false
  217. }
  218. // Reads from the client until a terminating sequence is encountered,
  219. // or until a timeout occurs.
  220. func (server *server) readCommand(client *client, maxSize int64) (string, error) {
  221. var input, reply string
  222. var err error
  223. // In command state, stop reading at line breaks
  224. suffix := "\r\n"
  225. for {
  226. client.setTimeout(server.timeout.Load().(time.Duration))
  227. reply, err = client.bufin.ReadString('\n')
  228. input = input + reply
  229. if err != nil {
  230. break
  231. }
  232. if strings.HasSuffix(input, suffix) {
  233. // discard the suffix and stop reading
  234. input = input[0 : len(input)-len(suffix)]
  235. break
  236. }
  237. }
  238. return input, err
  239. }
  240. // flushResponse a response to the client. Flushes the client.bufout buffer to the connection
  241. func (server *server) flushResponse(client *client) error {
  242. client.setTimeout(server.timeout.Load().(time.Duration))
  243. return client.bufout.Flush()
  244. }
  245. func (server *server) isShuttingDown() bool {
  246. return server.clientPool.IsShuttingDown()
  247. }
  248. // Handles an entire client SMTP exchange
  249. func (server *server) handleClient(client *client) {
  250. defer func() {
  251. server.log().WithFields(map[string]interface{}{
  252. "event": "disconnect",
  253. "id": client.ID,
  254. }).Info("Disconnect client")
  255. client.closeConn()
  256. }()
  257. sc := server.configStore.Load().(ServerConfig)
  258. server.log().WithFields(map[string]interface{}{
  259. "event": "connect",
  260. "id": client.ID,
  261. }).Infof("Handle client [%s]", client.RemoteIP)
  262. // Initial greeting
  263. greeting := fmt.Sprintf("220 %s SMTP Guerrilla(%s) #%d (%d) %s",
  264. sc.Hostname, Version, client.ID,
  265. server.clientPool.GetActiveClientsCount(), time.Now().Format(time.RFC3339))
  266. helo := fmt.Sprintf("250 %s Hello", sc.Hostname)
  267. // ehlo is a multi-line reply and need additional \r\n at the end
  268. ehlo := fmt.Sprintf("250-%s Hello\r\n", sc.Hostname)
  269. // Extended feature advertisements
  270. messageSize := fmt.Sprintf("250-SIZE %d\r\n", sc.MaxSize)
  271. pipelining := "250-PIPELINING\r\n"
  272. advertiseTLS := "250-STARTTLS\r\n"
  273. advertiseEnhancedStatusCodes := "250-ENHANCEDSTATUSCODES\r\n"
  274. // The last line doesn't need \r\n since string will be printed as a new line.
  275. // Also, Last line has no dash -
  276. help := "250 HELP"
  277. if sc.TLSAlwaysOn {
  278. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  279. if !ok {
  280. server.mainlog().Error("Failed to load *tls.Config")
  281. } else if err := client.upgradeToTLS(tlsConfig); err == nil {
  282. advertiseTLS = ""
  283. } else {
  284. server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
  285. // server requires TLS, but can't handshake
  286. client.kill()
  287. }
  288. }
  289. if !sc.StartTLSOn {
  290. // STARTTLS turned off, don't advertise it
  291. advertiseTLS = ""
  292. }
  293. for client.isAlive() {
  294. switch client.state {
  295. case ClientGreeting:
  296. client.sendResponse(greeting)
  297. client.state = ClientCmd
  298. case ClientCmd:
  299. client.bufin.setLimit(CommandLineMaxLength)
  300. input, err := server.readCommand(client, sc.MaxSize)
  301. server.log().Debugf("Client sent: %s", input)
  302. if err == io.EOF {
  303. server.log().WithError(err).Warnf("Client closed the connection: %s", client.RemoteIP)
  304. return
  305. } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  306. server.log().WithError(err).Warnf("Timeout: %s", client.RemoteIP)
  307. return
  308. } else if err == LineLimitExceeded {
  309. client.sendResponse(response.Canned.FailLineTooLong)
  310. client.kill()
  311. break
  312. } else if err != nil {
  313. server.log().WithError(err).Warnf("Read error: %s", client.RemoteIP)
  314. client.kill()
  315. break
  316. }
  317. if server.isShuttingDown() {
  318. client.state = ClientShutdown
  319. continue
  320. }
  321. input = strings.Trim(input, " \r\n")
  322. cmdLen := len(input)
  323. if cmdLen > CommandVerbMaxLength {
  324. cmdLen = CommandVerbMaxLength
  325. }
  326. cmd := strings.ToUpper(input[:cmdLen])
  327. switch {
  328. case strings.Index(cmd, "HELO") == 0:
  329. client.Helo = strings.Trim(input[4:], " ")
  330. client.resetTransaction()
  331. client.sendResponse(helo)
  332. case strings.Index(cmd, "EHLO") == 0:
  333. client.Helo = strings.Trim(input[4:], " ")
  334. client.resetTransaction()
  335. client.sendResponse(ehlo,
  336. messageSize,
  337. pipelining,
  338. advertiseTLS,
  339. advertiseEnhancedStatusCodes,
  340. help)
  341. case strings.Index(cmd, "HELP") == 0:
  342. quote := response.GetQuote()
  343. client.sendResponse("214-OK\r\n" + quote)
  344. case sc.XClientOn && strings.Index(cmd, "XCLIENT ") == 0:
  345. if toks := strings.Split(input[8:], " "); len(toks) > 0 {
  346. for i := range toks {
  347. if vals := strings.Split(toks[i], "="); len(vals) == 2 {
  348. if vals[1] == "[UNAVAILABLE]" {
  349. // skip
  350. continue
  351. }
  352. if vals[0] == "ADDR" {
  353. client.RemoteIP = vals[1]
  354. }
  355. if vals[0] == "HELO" {
  356. client.Helo = vals[1]
  357. }
  358. }
  359. }
  360. }
  361. client.sendResponse(response.Canned.SuccessMailCmd)
  362. case strings.Index(cmd, "MAIL FROM:") == 0:
  363. if client.isInTransaction() {
  364. client.sendResponse(response.Canned.FailNestedMailCmd)
  365. break
  366. }
  367. addr := input[10:]
  368. if !(strings.Index(addr, "<>") == 0) &&
  369. !(strings.Index(addr, " <>") == 0) {
  370. // Not Bounce, extract mail.
  371. if from, err := extractEmail(addr); err != nil {
  372. client.sendResponse(err)
  373. break
  374. } else {
  375. client.MailFrom = from
  376. server.log().WithFields(map[string]interface{}{
  377. "event": "mailfrom",
  378. "helo": client.Helo,
  379. "domain": from.Host,
  380. "address": getRemoteAddr(client.conn),
  381. "id": client.ID,
  382. }).Info("Mail from")
  383. }
  384. } else {
  385. // bounce has empty from address
  386. client.MailFrom = mail.Address{}
  387. }
  388. client.sendResponse(response.Canned.SuccessMailCmd)
  389. case strings.Index(cmd, "RCPT TO:") == 0:
  390. if len(client.RcptTo) > RFC2821LimitRecipients {
  391. client.sendResponse(response.Canned.ErrorTooManyRecipients)
  392. break
  393. }
  394. to, err := extractEmail(input[8:])
  395. if err != nil {
  396. client.sendResponse(err.Error())
  397. } else {
  398. if !server.allowsHost(to.Host) {
  399. client.sendResponse(response.Canned.ErrorRelayDenied, to.Host)
  400. } else {
  401. client.PushRcpt(to)
  402. rcptError := server.backend().ValidateRcpt(client.Envelope)
  403. if rcptError != nil {
  404. client.PopRcpt()
  405. client.sendResponse(response.Canned.FailRcptCmd + " " + rcptError.Error())
  406. } else {
  407. client.sendResponse(response.Canned.SuccessRcptCmd)
  408. }
  409. }
  410. }
  411. case strings.Index(cmd, "RSET") == 0:
  412. client.resetTransaction()
  413. client.sendResponse(response.Canned.SuccessResetCmd)
  414. case strings.Index(cmd, "VRFY") == 0:
  415. client.sendResponse(response.Canned.SuccessVerifyCmd)
  416. case strings.Index(cmd, "NOOP") == 0:
  417. client.sendResponse(response.Canned.SuccessNoopCmd)
  418. case strings.Index(cmd, "QUIT") == 0:
  419. client.sendResponse(response.Canned.SuccessQuitCmd)
  420. client.kill()
  421. case strings.Index(cmd, "DATA") == 0:
  422. if client.MailFrom.IsEmpty() {
  423. client.sendResponse(response.Canned.FailNoSenderDataCmd)
  424. break
  425. }
  426. if len(client.RcptTo) == 0 {
  427. client.sendResponse(response.Canned.FailNoRecipientsDataCmd)
  428. break
  429. }
  430. client.sendResponse(response.Canned.SuccessDataCmd)
  431. client.state = ClientData
  432. case sc.StartTLSOn && strings.Index(cmd, "STARTTLS") == 0:
  433. client.sendResponse(response.Canned.SuccessStartTLSCmd)
  434. client.state = ClientStartTLS
  435. default:
  436. client.errors++
  437. if client.errors >= MaxUnrecognizedCommands {
  438. client.sendResponse(response.Canned.FailMaxUnrecognizedCmd)
  439. client.kill()
  440. } else {
  441. client.sendResponse(response.Canned.FailUnrecognizedCmd)
  442. }
  443. }
  444. case ClientData:
  445. // intentionally placed the limit 1MB above so that reading does not return with an error
  446. // if the client goes a little over. Anything above will err
  447. client.bufin.setLimit(int64(sc.MaxSize) + 1024000) // This a hard limit.
  448. n, err := client.Data.ReadFrom(client.smtpReader.DotReader())
  449. if n > sc.MaxSize {
  450. err = fmt.Errorf("Maximum DATA size exceeded (%d)", sc.MaxSize)
  451. }
  452. if err != nil {
  453. if err == LineLimitExceeded {
  454. client.sendResponse(response.Canned.FailReadLimitExceededDataCmd, LineLimitExceeded.Error())
  455. client.kill()
  456. } else if err == MessageSizeExceeded {
  457. client.sendResponse(response.Canned.FailMessageSizeExceeded, MessageSizeExceeded.Error())
  458. client.kill()
  459. } else {
  460. client.sendResponse(response.Canned.FailReadErrorDataCmd, err.Error())
  461. client.kill()
  462. }
  463. server.log().WithError(err).Warn("Error reading data")
  464. client.resetTransaction()
  465. break
  466. }
  467. res := server.backend().Process(client.Envelope)
  468. if res.Code() < 300 {
  469. client.messagesSent++
  470. server.log().WithFields(map[string]interface{}{
  471. "helo": client.Helo,
  472. "remoteAddress": getRemoteAddr(client.conn),
  473. "success": true,
  474. }).Info("Received message")
  475. }
  476. client.sendResponse(res.String())
  477. client.state = ClientCmd
  478. if server.isShuttingDown() {
  479. client.state = ClientShutdown
  480. }
  481. client.resetTransaction()
  482. case ClientStartTLS:
  483. if !client.TLS && sc.StartTLSOn {
  484. tlsConfig, ok := server.tlsConfigStore.Load().(*tls.Config)
  485. if !ok {
  486. server.mainlog().Error("Failed to load *tls.Config")
  487. } else if err := client.upgradeToTLS(tlsConfig); err == nil {
  488. advertiseTLS = ""
  489. client.resetTransaction()
  490. } else {
  491. server.log().WithError(err).Warnf("[%s] Failed TLS handshake", client.RemoteIP)
  492. // Don't disconnect, let the client decide if it wants to continue
  493. }
  494. }
  495. // change to command state
  496. client.state = ClientCmd
  497. case ClientShutdown:
  498. // shutdown state
  499. client.sendResponse(response.Canned.ErrorShutdown)
  500. client.kill()
  501. }
  502. if client.bufout.Buffered() > 0 {
  503. if server.log().IsDebug() {
  504. server.log().Debugf("Writing response to client: \n%s", client.response.String())
  505. }
  506. err := server.flushResponse(client)
  507. if err != nil {
  508. server.log().WithError(err).Debug("Error writing response")
  509. return
  510. }
  511. }
  512. }
  513. }
  514. func (s *server) log() log.Logger {
  515. if l, ok := s.logStore.Load().(log.Logger); ok {
  516. return l
  517. }
  518. l, err := log.GetLogger(log.OutputStderr.String(), log.InfoLevel.String())
  519. if err == nil {
  520. s.logStore.Store(l)
  521. }
  522. return l
  523. }
  524. func (s *server) mainlog() log.Logger {
  525. if l, ok := s.mainlogStore.Load().(log.Logger); ok {
  526. return l
  527. }
  528. l, err := log.GetLogger(log.OutputStderr.String(), log.InfoLevel.String())
  529. if err == nil {
  530. s.mainlogStore.Store(l)
  531. }
  532. return l
  533. }