Forráskód Böngészése

Add locking around ssh conns to avoid concurrent map access on reload (#447)

Nathan Brown 4 éve
szülő
commit
a0735dd7d5
1 módosított fájl, 31 hozzáadás és 12 törlés
  1. 31 12
      sshd/server.go

+ 31 - 12
sshd/server.go

@@ -1,8 +1,10 @@
 package sshd
 
 import (
+	"errors"
 	"fmt"
 	"net"
+	"sync"
 
 	"github.com/armon/go-radix"
 	"github.com/sirupsen/logrus"
@@ -20,8 +22,11 @@ type SSHServer struct {
 	helpCommand *Command
 	commands    *radix.Tree
 	listener    net.Listener
-	conns       map[int]*session
-	counter     int
+
+	// Locks the conns/counter to avoid concurrent map access
+	connsLock sync.Mutex
+	conns     map[int]*session
+	counter   int
 }
 
 // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
@@ -97,11 +102,24 @@ func (s *SSHServer) Run(addr string) error {
 	}
 
 	s.l.WithField("sshListener", addr).Info("SSH server is listening")
+
+	// Run loops until there is an error
+	s.run()
+	s.closeSessions()
+
+	s.l.Info("SSH server stopped listening")
+	// We don't return an error because run logs for us
+	return nil
+}
+
+func (s *SSHServer) run() {
 	for {
 		c, err := s.listener.Accept()
 		if err != nil {
-			s.l.WithError(err).Warn("Error in listener, shutting down")
-			return nil
+			if !errors.Is(err, net.ErrClosed) {
+				s.l.WithError(err).Warn("Error in listener, shutting down")
+			}
+			return
 		}
 
 		conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
@@ -127,37 +145,38 @@ func (s *SSHServer) Run(addr string) error {
 		l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
 
 		session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
+		s.connsLock.Lock()
 		s.counter++
 		counter := s.counter
 		s.conns[counter] = session
+		s.connsLock.Unlock()
 
 		go ssh.DiscardRequests(reqs)
 		go func() {
 			<-session.exitChan
 			s.l.WithField("id", counter).Debug("closing conn")
+			s.connsLock.Lock()
 			delete(s.conns, counter)
+			s.connsLock.Unlock()
 		}()
 	}
 }
 
 func (s *SSHServer) Stop() {
-	// Close the listener first, to prevent any new connections being accepted.
+	// Close the listener, this will cause all session to terminate as well, see SSHServer.Run
 	if s.listener != nil {
 		if err := s.listener.Close(); err != nil {
 			s.l.WithError(err).Warn("Failed to close the sshd listener")
-		} else {
-			s.l.Info("SSH server stopped listening")
 		}
 	}
+}
 
-	// Force close all existing connections.
-	// TODO I believe this has a slight race if the listener has just accepted
-	// a connection. Can fix by moving this to the goroutine that's accepting.
+func (s *SSHServer) closeSessions() {
+	s.connsLock.Lock()
 	for _, c := range s.conns {
 		c.Close()
 	}
-
-	return
+	s.connsLock.Unlock()
 }
 
 func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {