session.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package sshd
import (
	"fmt"
	"io"
	"sort"
	"strings"
	"sync"

	"github.com/anmitsu/go-shlex"
	"github.com/armon/go-radix"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"
)
  12. type session struct {
  13. l *logrus.Entry
  14. c *ssh.ServerConn
  15. term *terminal.Terminal
  16. commands *radix.Tree
  17. exitChan chan bool
  18. }
  19. func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
  20. s := &session{
  21. commands: radix.NewFromMap(commands.ToMap()),
  22. l: l,
  23. c: conn,
  24. exitChan: make(chan bool),
  25. }
  26. s.commands.Insert("logout", &Command{
  27. Name: "logout",
  28. ShortDescription: "Ends the current session",
  29. Callback: func(a interface{}, args []string, w StringWriter) error {
  30. s.Close()
  31. return nil
  32. },
  33. })
  34. go s.handleChannels(chans)
  35. return s
  36. }
  37. func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
  38. for newChannel := range chans {
  39. if newChannel.ChannelType() != "session" {
  40. s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
  41. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  42. continue
  43. }
  44. channel, requests, err := newChannel.Accept()
  45. if err != nil {
  46. s.l.WithError(err).Warn("could not accept channel")
  47. continue
  48. }
  49. go s.handleRequests(requests, channel)
  50. }
  51. }
  52. func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
  53. for req := range in {
  54. var err error
  55. //TODO: maybe support window sizing?
  56. switch req.Type {
  57. case "shell":
  58. if s.term == nil {
  59. s.term = s.createTerm(channel)
  60. err = req.Reply(true, nil)
  61. } else {
  62. err = req.Reply(false, nil)
  63. }
  64. case "pty-req":
  65. err = req.Reply(true, nil)
  66. case "window-change":
  67. err = req.Reply(true, nil)
  68. case "exec":
  69. var payload = struct{ Value string }{}
  70. cErr := ssh.Unmarshal(req.Payload, &payload)
  71. if cErr == nil {
  72. s.dispatchCommand(payload.Value, &stringWriter{channel})
  73. } else {
  74. //TODO: log it
  75. }
  76. channel.Close()
  77. return
  78. default:
  79. s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request")
  80. err = req.Reply(false, nil)
  81. }
  82. if err != nil {
  83. s.l.WithError(err).Info("Error handling ssh session requests")
  84. s.Close()
  85. return
  86. }
  87. }
  88. }
  89. func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
  90. //TODO: PS1 with nebula cert name
  91. term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
  92. term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
  93. // key 9 is tab
  94. if key == 9 {
  95. cmds := matchCommand(s.commands, line)
  96. if len(cmds) == 1 {
  97. return cmds[0] + " ", len(cmds[0]) + 1, true
  98. }
  99. sort.Strings(cmds)
  100. term.Write([]byte(strings.Join(cmds, "\n") + "\n\n"))
  101. }
  102. return "", 0, false
  103. }
  104. go s.handleInput(channel)
  105. return term
  106. }
  107. func (s *session) handleInput(channel ssh.Channel) {
  108. defer s.Close()
  109. w := &stringWriter{w: s.term}
  110. for {
  111. line, err := s.term.ReadLine()
  112. if err != nil {
  113. //TODO: log
  114. break
  115. }
  116. s.dispatchCommand(line, w)
  117. }
  118. }
  119. func (s *session) dispatchCommand(line string, w StringWriter) {
  120. args, err := shlex.Split(line, true)
  121. if err != nil {
  122. //todo: LOG IT
  123. return
  124. }
  125. if len(args) == 0 {
  126. dumpCommands(s.commands, w)
  127. return
  128. }
  129. c, err := lookupCommand(s.commands, args[0])
  130. if err != nil {
  131. //TODO: handle the error
  132. return
  133. }
  134. if c == nil {
  135. err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
  136. //TODO: log error
  137. _ = err
  138. dumpCommands(s.commands, w)
  139. return
  140. }
  141. if checkHelpArgs(args) {
  142. s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w)
  143. return
  144. }
  145. err = execCommand(c, args[1:], w)
  146. if err != nil {
  147. //TODO: log the error
  148. }
  149. return
  150. }
  151. func (s *session) Close() {
  152. s.c.Close()
  153. s.exitChan <- true
  154. }