server.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package sshd
  import (
  	"fmt"
  	"net"
  	"sync"

  	"github.com/armon/go-radix"
  	"github.com/sirupsen/logrus"
  	"golang.org/x/crypto/ssh"
  )
  9. type SSHServer struct {
  10. config *ssh.ServerConfig
  11. l *logrus.Entry
  12. // Map of user -> authorized keys
  13. trustedKeys map[string]map[string]bool
  14. // List of available commands
  15. helpCommand *Command
  16. commands *radix.Tree
  17. listener net.Listener
  18. conns map[int]*session
  19. counter int
  20. }
  21. // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
  22. func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
  23. s := &SSHServer{
  24. trustedKeys: make(map[string]map[string]bool),
  25. l: l,
  26. commands: radix.New(),
  27. conns: make(map[int]*session),
  28. }
  29. s.config = &ssh.ServerConfig{
  30. PublicKeyCallback: s.matchPubKey,
  31. //TODO: AuthLogCallback: s.authAttempt,
  32. //TODO: version string
  33. ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
  34. }
  35. s.RegisterCommand(&Command{
  36. Name: "help",
  37. ShortDescription: "prints available commands or help <command> for specific usage info",
  38. Callback: func(a interface{}, args []string, w StringWriter) error {
  39. return helpCallback(s.commands, args, w)
  40. },
  41. })
  42. return s, nil
  43. }
  44. func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
  45. private, err := ssh.ParsePrivateKey(hostPrivateKey)
  46. if err != nil {
  47. return fmt.Errorf("failed to parse private key: %s", err)
  48. }
  49. s.config.AddHostKey(private)
  50. return nil
  51. }
  52. func (s *SSHServer) ClearAuthorizedKeys() {
  53. s.trustedKeys = make(map[string]map[string]bool)
  54. }
  55. // AddAuthorizedKey adds an ssh public key for a user
  56. func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
  57. pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
  58. if err != nil {
  59. return err
  60. }
  61. tk, ok := s.trustedKeys[user]
  62. if !ok {
  63. tk = make(map[string]bool)
  64. s.trustedKeys[user] = tk
  65. }
  66. tk[string(pk.Marshal())] = true
  67. s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key")
  68. return nil
  69. }
  70. // RegisterCommand adds a command that can be run by a user, by default only `help` is available
  71. func (s *SSHServer) RegisterCommand(c *Command) {
  72. s.commands.Insert(c.Name, c)
  73. }
  74. // Run begins listening and accepting connections
  75. func (s *SSHServer) Run(addr string) error {
  76. var err error
  77. s.listener, err = net.Listen("tcp", addr)
  78. if err != nil {
  79. return err
  80. }
  81. s.l.WithField("sshListener", addr).Info("SSH server is listening")
  82. for {
  83. c, err := s.listener.Accept()
  84. if err != nil {
  85. s.l.WithError(err).Warn("Error in listener, shutting down")
  86. return nil
  87. }
  88. conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
  89. fp := ""
  90. if conn != nil {
  91. fp = conn.Permissions.Extensions["fp"]
  92. }
  93. if err != nil {
  94. l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
  95. if conn != nil {
  96. l = l.WithField("sshUser", conn.User())
  97. conn.Close()
  98. }
  99. if fp != "" {
  100. l = l.WithField("sshFingerprint", fp)
  101. }
  102. l.Warn("failed to handshake")
  103. continue
  104. }
  105. l := s.l.WithField("sshUser", conn.User())
  106. l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
  107. session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session"))
  108. s.counter++
  109. counter := s.counter
  110. s.conns[counter] = session
  111. go ssh.DiscardRequests(reqs)
  112. go func() {
  113. <-session.exitChan
  114. s.l.WithField("id", counter).Debug("closing conn")
  115. delete(s.conns, counter)
  116. }()
  117. }
  118. }
  119. func (s *SSHServer) Stop() {
  120. for _, c := range s.conns {
  121. c.Close()
  122. }
  123. if s.listener == nil {
  124. return
  125. }
  126. err := s.listener.Close()
  127. if err != nil {
  128. s.l.WithError(err).Warn("Failed to close the sshd listener")
  129. return
  130. }
  131. s.l.Info("SSH server stopped listening")
  132. return
  133. }
  134. func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
  135. pk := string(pubKey.Marshal())
  136. fp := ssh.FingerprintSHA256(pubKey)
  137. tk, ok := s.trustedKeys[c.User()]
  138. if !ok {
  139. return nil, fmt.Errorf("unknown user %s", c.User())
  140. }
  141. _, ok = tk[pk]
  142. if !ok {
  143. return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
  144. }
  145. return &ssh.Permissions{
  146. // Record the public key used for authentication.
  147. Extensions: map[string]string{
  148. "fp": fp,
  149. "user": c.User(),
  150. },
  151. }, nil
  152. }