瀏覽代碼

- refactoring to the TLS code
- add support for tls on by default

flashmob 9 年之前
父節點
當前提交
6951122ad9
共有 1 個文件被更改,包括 29 次插入14 次删除
  1. 29 14
      goguerrilla.go

+ 29 - 14
goguerrilla.go

@@ -224,14 +224,38 @@ func main() {
 
 }
 
+func (server *SmtpdServer) upgradeToTls(client *Client) bool {
+	var tlsConn *tls.Conn
+	tlsConn = tls.Server(client.conn, server.tlsConfig)
+	err := tlsConn.Handshake() // not necessary to call here, but might as well
+	if err == nil {
+		client.conn = net.Conn(tlsConn)
+		client.bufin = bufio.NewReader(client.conn)
+		client.bufout = bufio.NewWriter(client.conn)
+		client.tls_on = true
+		return true;
+	} else {
+		server.logln(1, fmt.Sprintf("Could not TLS handshake:%v", err))
+		return false;
+	}
+
+}
+
 func (server *SmtpdServer) handleClient(client *Client) {
 	defer server.closeClient(client)
+	advertiseTls := "250-STARTTLS\r\n"
+	if server.Config.Is_tls_on {
+		if server.upgradeToTls(client) {
+			advertiseTls = ""
+		}
+	}
 	greeting := "220 " + server.Config.Host_name +
 		" SMTP Guerrilla-SMTPd #" +
 		strconv.FormatInt(client.clientId, 10) +
 		" (" + strconv.Itoa(len(server.sem)) + ") " + time.Now().Format(time.RFC1123Z)
-	advertiseTls := "250-STARTTLS\r\n"
-	if server.Config.Start_tls_on {
+
+	if !server.Config.Start_tls_on {
+		// STARTTLS turned off
 		advertiseTls = ""
 	}
 	for i := 0; i < 100; i++ {
@@ -346,19 +370,10 @@ func (server *SmtpdServer) handleClient(client *Client) {
 			client.state = 1
 		case 3:
 			// upgrade to TLS
-			var tlsConn *tls.Conn
-			tlsConn = tls.Server(client.conn, server.tlsConfig)
-			err := tlsConn.Handshake() // not necessary to call here, but might as well
-			if err == nil {
-				client.conn = net.Conn(tlsConn)
-				client.bufin = bufio.NewReader(client.conn)
-				client.bufout = bufio.NewWriter(client.conn)
-				client.tls_on = true
-			} else {
-				server.logln(1, fmt.Sprintf("Could not TLS handshake:%v", err))
+			if server.upgradeToTls(client) {
+				advertiseTls = ""
+				client.state = 1
 			}
-			advertiseTls = ""
-			client.state = 1
 		}
 		// Send a response back to the client
 		err := server.responseWrite(client)