stun-server.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package stunserver
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "os"
  7. "os/signal"
  8. "strings"
  9. "sync"
  10. "syscall"
  11. "github.com/gravitl/netmaker/logger"
  12. "github.com/gravitl/netmaker/servercfg"
  13. "github.com/pkg/errors"
  14. "gortc.io/stun"
  15. )
  16. // Server is RFC 5389 basic server implementation.
  17. //
  18. // Current implementation is UDP only and not utilizes FINGERPRINT mechanism,
  19. // nor ALTERNATE-SERVER, nor credentials mechanisms. It does not support
  20. // backwards compatibility with RFC 3489.
  21. type Server struct {
  22. Addr string
  23. Ctx context.Context
  24. }
  25. var (
  26. software = stun.NewSoftware("netmaker-stun")
  27. errNotSTUNMessage = errors.New("not stun message")
  28. )
  29. func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
  30. if !stun.IsMessage(b) {
  31. return errNotSTUNMessage
  32. }
  33. if _, err := req.Write(b); err != nil {
  34. return errors.Wrap(err, "failed to read message")
  35. }
  36. var (
  37. ip net.IP
  38. port int
  39. )
  40. switch a := addr.(type) {
  41. case *net.UDPAddr:
  42. ip = a.IP
  43. port = a.Port
  44. default:
  45. panic(fmt.Sprintf("unknown addr: %v", addr))
  46. }
  47. return res.Build(req,
  48. stun.BindingSuccess,
  49. software,
  50. &stun.XORMappedAddress{
  51. IP: ip,
  52. Port: port,
  53. },
  54. stun.Fingerprint,
  55. )
  56. }
  57. func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
  58. if c == nil {
  59. return nil
  60. }
  61. buf := make([]byte, 1024)
  62. n, addr, err := c.ReadFrom(buf)
  63. if err != nil {
  64. logger.Log(1, "ReadFrom: %v", err.Error())
  65. return nil
  66. }
  67. if _, err = req.Write(buf[:n]); err != nil {
  68. logger.Log(1, "Write: %v", err.Error())
  69. return err
  70. }
  71. if err = basicProcess(addr, buf[:n], req, res); err != nil {
  72. if err == errNotSTUNMessage {
  73. return nil
  74. }
  75. logger.Log(1, "basicProcess: %v", err.Error())
  76. return nil
  77. }
  78. _, err = c.WriteTo(res.Raw, addr)
  79. if err != nil {
  80. logger.Log(1, "WriteTo: %v", err.Error())
  81. }
  82. return err
  83. }
  84. // Serve reads packets from connections and responds to BINDING requests.
  85. func (s *Server) serve(c net.PacketConn) error {
  86. var (
  87. res = new(stun.Message)
  88. req = new(stun.Message)
  89. )
  90. for {
  91. select {
  92. case <-s.Ctx.Done():
  93. logger.Log(0, "Shutting down stun server...")
  94. c.Close()
  95. return nil
  96. default:
  97. if err := s.serveConn(c, res, req); err != nil {
  98. logger.Log(1, "serve: %v", err.Error())
  99. continue
  100. }
  101. res.Reset()
  102. req.Reset()
  103. }
  104. }
  105. }
  106. // listenUDPAndServe listens on laddr and process incoming packets.
  107. func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
  108. c, err := net.ListenPacket(serverNet, laddr)
  109. if err != nil {
  110. return err
  111. }
  112. s := &Server{
  113. Addr: laddr,
  114. Ctx: ctx,
  115. }
  116. return s.serve(c)
  117. }
  118. func normalize(address string) string {
  119. if len(address) == 0 {
  120. address = "0.0.0.0"
  121. }
  122. if !strings.Contains(address, ":") {
  123. address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
  124. }
  125. return address
  126. }
  127. // Start - starts the stun server
  128. func Start(wg *sync.WaitGroup) {
  129. ctx, cancel := context.WithCancel(context.Background())
  130. go func(wg *sync.WaitGroup) {
  131. defer wg.Done()
  132. quit := make(chan os.Signal, 1)
  133. signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
  134. <-quit
  135. cancel()
  136. }(wg)
  137. normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
  138. logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
  139. err := listenUDPAndServe(ctx, "udp", normalized)
  140. if err != nil {
  141. logger.Log(0, "failed to start stun server: ", err.Error())
  142. }
  143. }