session.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package sshd
  2. import (
  3. "fmt"
  4. "sort"
  5. "strings"
  6. "github.com/anmitsu/go-shlex"
  7. "github.com/armon/go-radix"
  8. "github.com/sirupsen/logrus"
  9. "golang.org/x/crypto/ssh"
  10. "golang.org/x/term"
  11. )
  12. type session struct {
  13. l *logrus.Entry
  14. c *ssh.ServerConn
  15. term *term.Terminal
  16. commands *radix.Tree
  17. exitChan chan bool
  18. }
  19. func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
  20. s := &session{
  21. commands: radix.NewFromMap(commands.ToMap()),
  22. l: l,
  23. c: conn,
  24. exitChan: make(chan bool),
  25. }
  26. s.commands.Insert("logout", &Command{
  27. Name: "logout",
  28. ShortDescription: "Ends the current session",
  29. Callback: func(a any, args []string, w StringWriter) error {
  30. s.Close()
  31. return nil
  32. },
  33. })
  34. go s.handleChannels(chans)
  35. return s
  36. }
  37. func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
  38. for newChannel := range chans {
  39. if newChannel.ChannelType() != "session" {
  40. s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
  41. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  42. continue
  43. }
  44. channel, requests, err := newChannel.Accept()
  45. if err != nil {
  46. s.l.WithError(err).Warn("could not accept channel")
  47. continue
  48. }
  49. go s.handleRequests(requests, channel)
  50. }
  51. }
  52. func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
  53. for req := range in {
  54. var err error
  55. switch req.Type {
  56. case "shell":
  57. if s.term == nil {
  58. s.term = s.createTerm(channel)
  59. err = req.Reply(true, nil)
  60. } else {
  61. err = req.Reply(false, nil)
  62. }
  63. case "pty-req":
  64. err = req.Reply(true, nil)
  65. case "window-change":
  66. err = req.Reply(true, nil)
  67. case "exec":
  68. var payload = struct{ Value string }{}
  69. cErr := ssh.Unmarshal(req.Payload, &payload)
  70. if cErr != nil {
  71. req.Reply(false, nil)
  72. return
  73. }
  74. req.Reply(true, nil)
  75. s.dispatchCommand(payload.Value, &stringWriter{channel})
  76. status := struct{ Status uint32 }{uint32(0)}
  77. channel.SendRequest("exit-status", false, ssh.Marshal(status))
  78. channel.Close()
  79. return
  80. default:
  81. s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
  82. err = req.Reply(false, nil)
  83. }
  84. if err != nil {
  85. s.l.WithError(err).Info("Error handling ssh session requests")
  86. s.Close()
  87. return
  88. }
  89. }
  90. }
  91. func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
  92. term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
  93. term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
  94. // key 9 is tab
  95. if key == 9 {
  96. cmds := matchCommand(s.commands, line)
  97. if len(cmds) == 1 {
  98. return cmds[0] + " ", len(cmds[0]) + 1, true
  99. }
  100. sort.Strings(cmds)
  101. term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
  102. }
  103. return "", 0, false
  104. }
  105. go s.handleInput(channel)
  106. return term
  107. }
  108. func (s *session) handleInput(channel ssh.Channel) {
  109. defer s.Close()
  110. w := &stringWriter{w: s.term}
  111. for {
  112. line, err := s.term.ReadLine()
  113. if err != nil {
  114. break
  115. }
  116. s.dispatchCommand(line, w)
  117. }
  118. }
  119. func (s *session) dispatchCommand(line string, w StringWriter) {
  120. args, err := shlex.Split(line, true)
  121. if err != nil {
  122. return
  123. }
  124. if len(args) == 0 {
  125. dumpCommands(s.commands, w)
  126. return
  127. }
  128. c, err := lookupCommand(s.commands, args[0])
  129. if err != nil {
  130. return
  131. }
  132. if c == nil {
  133. err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
  134. _ = err
  135. dumpCommands(s.commands, w)
  136. return
  137. }
  138. if checkHelpArgs(args) {
  139. s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
  140. return
  141. }
  142. _ = execCommand(c, args[1:], w)
  143. return
  144. }
  145. func (s *session) Close() {
  146. s.c.Close()
  147. s.exitChan <- true
  148. }