server.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package sshd
import (
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/armon/go-radix"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
  11. type SSHServer struct {
  12. config *ssh.ServerConfig
  13. l *logrus.Entry
  14. // Map of user -> authorized keys
  15. trustedKeys map[string]map[string]bool
  16. // List of available commands
  17. helpCommand *Command
  18. commands *radix.Tree
  19. listener net.Listener
  20. // Locks the conns/counter to avoid concurrent map access
  21. connsLock sync.Mutex
  22. conns map[int]*session
  23. counter int
  24. }
  25. // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
  26. func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
  27. s := &SSHServer{
  28. trustedKeys: make(map[string]map[string]bool),
  29. l: l,
  30. commands: radix.New(),
  31. conns: make(map[int]*session),
  32. }
  33. s.config = &ssh.ServerConfig{
  34. PublicKeyCallback: s.matchPubKey,
  35. //TODO: AuthLogCallback: s.authAttempt,
  36. //TODO: version string
  37. ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
  38. }
  39. s.RegisterCommand(&Command{
  40. Name: "help",
  41. ShortDescription: "prints available commands or help <command> for specific usage info",
  42. Callback: func(a interface{}, args []string, w StringWriter) error {
  43. return helpCallback(s.commands, args, w)
  44. },
  45. })
  46. return s, nil
  47. }
  48. func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
  49. private, err := ssh.ParsePrivateKey(hostPrivateKey)
  50. if err != nil {
  51. return fmt.Errorf("failed to parse private key: %s", err)
  52. }
  53. s.config.AddHostKey(private)
  54. return nil
  55. }
  56. func (s *SSHServer) ClearAuthorizedKeys() {
  57. s.trustedKeys = make(map[string]map[string]bool)
  58. }
  59. // AddAuthorizedKey adds an ssh public key for a user
  60. func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
  61. pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
  62. if err != nil {
  63. return err
  64. }
  65. tk, ok := s.trustedKeys[user]
  66. if !ok {
  67. tk = make(map[string]bool)
  68. s.trustedKeys[user] = tk
  69. }
  70. tk[string(pk.Marshal())] = true
  71. s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
  72. return nil
  73. }
  74. // RegisterCommand adds a command that can be run by a user, by default only `help` is available
  75. func (s *SSHServer) RegisterCommand(c *Command) {
  76. s.commands.Insert(c.Name, c)
  77. }
  78. // Run begins listening and accepting connections
  79. func (s *SSHServer) Run(addr string) error {
  80. var err error
  81. s.listener, err = net.Listen("tcp", addr)
  82. if err != nil {
  83. return err
  84. }
  85. s.l.WithField("sshListener", addr).Info("SSH server is listening")
  86. // Run loops until there is an error
  87. s.run()
  88. s.closeSessions()
  89. s.l.Info("SSH server stopped listening")
  90. // We don't return an error because run logs for us
  91. return nil
  92. }
  93. func (s *SSHServer) run() {
  94. for {
  95. c, err := s.listener.Accept()
  96. if err != nil {
  97. if !errors.Is(err, net.ErrClosed) {
  98. s.l.WithError(err).Warn("Error in listener, shutting down")
  99. }
  100. return
  101. }
  102. conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
  103. fp := ""
  104. if conn != nil {
  105. fp = conn.Permissions.Extensions["fp"]
  106. }
  107. if err != nil {
  108. l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
  109. if conn != nil {
  110. l = l.WithField("sshUser", conn.User())
  111. conn.Close()
  112. }
  113. if fp != "" {
  114. l = l.WithField("sshFingerprint", fp)
  115. }
  116. l.Warn("failed to handshake")
  117. continue
  118. }
  119. l := s.l.WithField("sshUser", conn.User())
  120. l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
  121. session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
  122. s.connsLock.Lock()
  123. s.counter++
  124. counter := s.counter
  125. s.conns[counter] = session
  126. s.connsLock.Unlock()
  127. go ssh.DiscardRequests(reqs)
  128. go func() {
  129. <-session.exitChan
  130. s.l.WithField("id", counter).Debug("closing conn")
  131. s.connsLock.Lock()
  132. delete(s.conns, counter)
  133. s.connsLock.Unlock()
  134. }()
  135. }
  136. }
  137. func (s *SSHServer) Stop() {
  138. // Close the listener, this will cause all session to terminate as well, see SSHServer.Run
  139. if s.listener != nil {
  140. if err := s.listener.Close(); err != nil {
  141. s.l.WithError(err).Warn("Failed to close the sshd listener")
  142. }
  143. }
  144. }
  145. func (s *SSHServer) closeSessions() {
  146. s.connsLock.Lock()
  147. for _, c := range s.conns {
  148. c.Close()
  149. }
  150. s.connsLock.Unlock()
  151. }
  152. func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
  153. pk := string(pubKey.Marshal())
  154. fp := ssh.FingerprintSHA256(pubKey)
  155. tk, ok := s.trustedKeys[c.User()]
  156. if !ok {
  157. return nil, fmt.Errorf("unknown user %s", c.User())
  158. }
  159. _, ok = tk[pk]
  160. if !ok {
  161. return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
  162. }
  163. return &ssh.Permissions{
  164. // Record the public key used for authentication.
  165. Extensions: map[string]string{
  166. "fp": fp,
  167. "user": c.User(),
  168. },
  169. }, nil
  170. }