stun-server.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package stunserver
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "os"
  8. "os/signal"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "github.com/gravitl/netmaker/logger"
  13. "github.com/gravitl/netmaker/servercfg"
  14. "github.com/pkg/errors"
  15. "github.com/sirupsen/logrus"
  16. "gortc.io/stun"
  17. )
  18. // Server is RFC 5389 basic server implementation.
  19. //
  20. // Current implementation is UDP only and not utilizes FINGERPRINT mechanism,
  21. // nor ALTERNATE-SERVER, nor credentials mechanisms. It does not support
  22. // backwards compatibility with RFC 3489.
  23. type Server struct {
  24. Addr string
  25. Ctx context.Context
  26. }
  27. // Logger is used for logging formatted messages.
  28. type Logger interface {
  29. // Printf must have the same semantics as log.Printf.
  30. Printf(format string, args ...interface{})
  31. }
  32. var (
  33. defaultLogger = logrus.New()
  34. software = stun.NewSoftware("netmaker-stun")
  35. errNotSTUNMessage = errors.New("not stun message")
  36. )
  37. func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
  38. if !stun.IsMessage(b) {
  39. return errNotSTUNMessage
  40. }
  41. if _, err := req.Write(b); err != nil {
  42. return errors.Wrap(err, "failed to read message")
  43. }
  44. var (
  45. ip net.IP
  46. port int
  47. )
  48. switch a := addr.(type) {
  49. case *net.UDPAddr:
  50. ip = a.IP
  51. port = a.Port
  52. default:
  53. panic(fmt.Sprintf("unknown addr: %v", addr))
  54. }
  55. return res.Build(req,
  56. stun.BindingSuccess,
  57. software,
  58. &stun.XORMappedAddress{
  59. IP: ip,
  60. Port: port,
  61. },
  62. stun.Fingerprint,
  63. )
  64. }
  65. func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
  66. if c == nil {
  67. return nil
  68. }
  69. buf := make([]byte, 1024)
  70. n, addr, err := c.ReadFrom(buf)
  71. if err != nil {
  72. logger.Log(1, "ReadFrom: %v", err.Error())
  73. return nil
  74. }
  75. log.Printf("read %d bytes from %s\n", n, addr)
  76. if _, err = req.Write(buf[:n]); err != nil {
  77. logger.Log(1, "Write: %v", err.Error())
  78. return err
  79. }
  80. if err = basicProcess(addr, buf[:n], req, res); err != nil {
  81. if err == errNotSTUNMessage {
  82. return nil
  83. }
  84. logger.Log(1, "basicProcess: %v", err.Error())
  85. return nil
  86. }
  87. _, err = c.WriteTo(res.Raw, addr)
  88. if err != nil {
  89. logger.Log(1, "WriteTo: %v", err.Error())
  90. }
  91. return err
  92. }
  93. // Serve reads packets from connections and responds to BINDING requests.
  94. func (s *Server) serve(c net.PacketConn) error {
  95. var (
  96. res = new(stun.Message)
  97. req = new(stun.Message)
  98. )
  99. for {
  100. select {
  101. case <-s.Ctx.Done():
  102. logger.Log(0, "Shutting down stun server...")
  103. c.Close()
  104. return nil
  105. default:
  106. if err := s.serveConn(c, res, req); err != nil {
  107. logger.Log(1, "serve: %v", err.Error())
  108. continue
  109. }
  110. res.Reset()
  111. req.Reset()
  112. }
  113. }
  114. }
  115. // listenUDPAndServe listens on laddr and process incoming packets.
  116. func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
  117. c, err := net.ListenPacket(serverNet, laddr)
  118. if err != nil {
  119. return err
  120. }
  121. s := &Server{
  122. Addr: laddr,
  123. Ctx: ctx,
  124. }
  125. return s.serve(c)
  126. }
  127. func normalize(address string) string {
  128. if len(address) == 0 {
  129. address = "0.0.0.0"
  130. }
  131. if !strings.Contains(address, ":") {
  132. address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
  133. }
  134. return address
  135. }
  136. // Start - starts the stun server
  137. func Start(wg *sync.WaitGroup) {
  138. ctx, cancel := context.WithCancel(context.Background())
  139. go func(wg *sync.WaitGroup) {
  140. defer wg.Done()
  141. quit := make(chan os.Signal, 1)
  142. signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
  143. <-quit
  144. cancel()
  145. }(wg)
  146. normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
  147. logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
  148. err := listenUDPAndServe(ctx, "udp", normalized)
  149. if err != nil {
  150. logger.Log(0, "failed to start stun server: ", err.Error())
  151. }
  152. }