Browse Source

add context to proxy and stun

Abhishek Kondur 2 years ago
parent
commit
4d7691b71b
6 changed files with 140 additions and 83 deletions
  1. 12 2
      main.go
  2. 13 9
      netclient/functions/daemon.go
  3. 5 3
      nm-proxy/nm-proxy.go
  4. 59 41
      nm-proxy/server/server.go
  5. 12 8
      nm-proxy/stun/stun.go
  6. 39 20
      stun-server/stun-server.go

+ 12 - 2
main.go

@@ -173,14 +173,24 @@ func startControllers() {
 		logger.Log(0, "No Server Mode selected, so nothing is being served! Set Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) or MessageQueue (MESSAGEQUEUE_BACKEND) to 'true'.")
 	}
 	// starts the stun server
-	go stunserver.Start()
-	go nmproxy.Start(logic.ProxyMgmChan)
+	waitnetwork.Add(1)
+	go stunserver.Start(&waitnetwork)
+	waitnetwork.Add(1)
 	go func() {
+		defer waitnetwork.Done()
+		ctx, cancel := context.WithCancel(context.Background())
+		waitnetwork.Add(1)
+		go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost())
 		err := serverctl.SyncServerNetworkWithProxy()
 		if err != nil {
 			logger.Log(0, "failed to sync proxy with server interfaces: ", err.Error())
 		}
+		quit := make(chan os.Signal, 1)
+		signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+		<-quit
+		cancel()
 	}()
+
 	waitnetwork.Wait()
 }
 

+ 13 - 9
netclient/functions/daemon.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"net"
 	"net/http"
 	"os"
 	"os/signal"
@@ -33,7 +34,7 @@ import (
 
 var ProxyMgmChan = make(chan *manager.ManagerAction, 100)
 var messageCache = new(sync.Map)
-var ProxyStatus = "OFF"
+
 var serverSet map[string]bool
 
 var mqclient mqtt.Client
@@ -123,16 +124,19 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
 	}
 	wg.Add(1)
 	go Checkin(ctx, wg)
-	if ProxyStatus == "OFF" {
-		ProxyStatus = "ON"
-		go nmproxy.Start(ProxyMgmChan)
-	} else {
-		log.Println("Proxy already running...")
+
+	if len(networks) != 0 {
+		cfg := config.ClientConfig{}
+		cfg.Network = networks[0]
+		cfg.ReadConfig()
+		apiHost, _, err := net.SplitHostPort(cfg.Server.API)
+		if err == nil {
+			go nmproxy.Start(ctx, ProxyMgmChan, apiHost)
+		}
 	}
 
-	go func() {
+	go func(networks []string) {
 
-		networks, _ := ncutils.GetSystemNetworks()
 		for _, network := range networks {
 			logger.Log(0, "Collecting interface and peers info to configure proxy...")
 			cfg := config.ClientConfig{}
@@ -153,7 +157,7 @@ func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
 
 		}
 
-	}()
+	}(networks)
 	return cancel
 }
 func GetNodeInfo(cfg *config.ClientConfig) (models.NodeGet, error) {

+ 5 - 3
nm-proxy/nm-proxy.go

@@ -1,6 +1,7 @@
 package nmproxy
 
 import (
+	"context"
 	"log"
 	"net"
 	"os"
@@ -17,11 +18,11 @@ import (
    2. Delete - remove close all conns for the interface,cleanup
 
 */
-func Start(mgmChan chan *manager.ManagerAction) {
+func Start(ctx context.Context, mgmChan chan *manager.ManagerAction, apiServerAddr string) {
 	log.Println("Starting Proxy...")
 	common.IsHostNetwork = (os.Getenv("HOST_NETWORK") == "" || os.Getenv("HOST_NETWORK") == "on")
 	go manager.StartProxyManager(mgmChan)
-	hInfo := stun.GetHostInfo()
+	hInfo := stun.GetHostInfo(apiServerAddr)
 	stun.Host = hInfo
 	log.Printf("HOSTINFO: %+v", hInfo)
 	if IsPublicIP(hInfo.PrivIp) {
@@ -32,7 +33,8 @@ func Start(mgmChan chan *manager.ManagerAction) {
 	if err != nil {
 		log.Fatal("failed to create proxy: ", err)
 	}
-	server.NmProxyServer.Listen()
+	server.NmProxyServer.Listen(ctx)
+
 }
 
 // IsPublicIP indicates whether IP is public or not.

+ 59 - 41
nm-proxy/server/server.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"log"
 	"net"
@@ -32,63 +33,80 @@ type ProxyServer struct {
 }
 
 // Proxy.Listen - begins listening for packets
-func (p *ProxyServer) Listen() {
+func (p *ProxyServer) Listen(ctx context.Context) {
 
 	// Buffer with indicated body size
 	buffer := make([]byte, 1532)
 	for {
-		// Read Packet
-		n, source, err := p.Server.ReadFromUDP(buffer)
-		if err != nil { // in future log errors?
-			log.Println("RECV ERROR: ", err)
-			continue
-		}
-		var srcPeerKeyHash, dstPeerKeyHash string
-		n, srcPeerKeyHash, dstPeerKeyHash = packet.ExtractInfo(buffer, n)
-		//log.Printf("--------> RECV PKT [DSTPORT: %d], [SRCKEYHASH: %s], SourceIP: [%s] \n", localWgPort, srcPeerKeyHash, source.IP.String())
-		if common.IsRelay && dstPeerKeyHash != "" && srcPeerKeyHash != "" {
-			if _, ok := common.WgIfaceKeyMap[dstPeerKeyHash]; !ok {
-
-				log.Println("----------> Relaying######")
-				// check for routing map and forward to right proxy
-				if remoteMap, ok := common.RelayPeerMap[srcPeerKeyHash]; ok {
-					if conf, ok := remoteMap[dstPeerKeyHash]; ok {
-						log.Printf("--------> Relaying PKT [ SourceIP: %s:%d ], [ SourceKeyHash: %s ], [ DstIP: %s:%d ], [ DstHashKey: %s ] \n",
-							source.IP.String(), source.Port, srcPeerKeyHash, conf.Endpoint.String(), conf.Endpoint.Port, dstPeerKeyHash)
-						_, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint)
-						if err != nil {
-							log.Println("Failed to send to remote: ", err)
-						}
-					}
-				} else {
-					if remoteMap, ok := common.RelayPeerMap[dstPeerKeyHash]; ok {
+
+		select {
+		case <-ctx.Done():
+			log.Println("--------->### Shutting down Proxy.....")
+			// clean up proxy connections
+			for iface, peers := range common.WgIFaceMap {
+				log.Println("########------------>  CLEANING UP: ", iface)
+				for _, peerI := range peers {
+					peerI.Proxy.Cancel()
+				}
+			}
+			// close server connection
+			NmProxyServer.Server.Close()
+			return
+		default:
+			// Read Packet
+			n, source, err := p.Server.ReadFromUDP(buffer)
+			if err != nil { // in future log errors?
+				log.Println("RECV ERROR: ", err)
+				continue
+			}
+			var srcPeerKeyHash, dstPeerKeyHash string
+			n, srcPeerKeyHash, dstPeerKeyHash = packet.ExtractInfo(buffer, n)
+			//log.Printf("--------> RECV PKT [DSTPORT: %d], [SRCKEYHASH: %s], SourceIP: [%s] \n", localWgPort, srcPeerKeyHash, source.IP.String())
+			if common.IsRelay && dstPeerKeyHash != "" && srcPeerKeyHash != "" {
+				if _, ok := common.WgIfaceKeyMap[dstPeerKeyHash]; !ok {
+
+					log.Println("----------> Relaying######")
+					// check for routing map and forward to right proxy
+					if remoteMap, ok := common.RelayPeerMap[srcPeerKeyHash]; ok {
 						if conf, ok := remoteMap[dstPeerKeyHash]; ok {
-							log.Printf("--------> Relaying BACK TO RELAYED NODE PKT [ SourceIP: %s ], [ SourceKeyHash: %s ], [ DstIP: %s ], [ DstHashKey: %s ] \n",
-								source.String(), srcPeerKeyHash, conf.Endpoint.String(), dstPeerKeyHash)
+							log.Printf("--------> Relaying PKT [ SourceIP: %s:%d ], [ SourceKeyHash: %s ], [ DstIP: %s:%d ], [ DstHashKey: %s ] \n",
+								source.IP.String(), source.Port, srcPeerKeyHash, conf.Endpoint.String(), conf.Endpoint.Port, dstPeerKeyHash)
 							_, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint)
 							if err != nil {
 								log.Println("Failed to send to remote: ", err)
 							}
 						}
+					} else {
+						if remoteMap, ok := common.RelayPeerMap[dstPeerKeyHash]; ok {
+							if conf, ok := remoteMap[dstPeerKeyHash]; ok {
+								log.Printf("--------> Relaying BACK TO RELAYED NODE PKT [ SourceIP: %s ], [ SourceKeyHash: %s ], [ DstIP: %s ], [ DstHashKey: %s ] \n",
+									source.String(), srcPeerKeyHash, conf.Endpoint.String(), dstPeerKeyHash)
+								_, err = NmProxyServer.Server.WriteToUDP(buffer[:n+32], conf.Endpoint)
+								if err != nil {
+									log.Println("Failed to send to remote: ", err)
+								}
+							}
+						}
 					}
-				}
 
+				}
 			}
-		}
 
-		if peerInfo, ok := common.PeerKeyHashMap[srcPeerKeyHash]; ok {
-			if peers, ok := common.WgIFaceMap[peerInfo.Interface]; ok {
-				if peerI, ok := peers[peerInfo.PeerKey]; ok {
-					log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s   [[ RECV PKT [SRCKEYHASH: %s], [DSTKEYHASH: %s], SourceIP: [%s] ]]\n",
-						peerI.Proxy.LocalConn.RemoteAddr(), peerI.Proxy.LocalConn.LocalAddr(),
-						fmt.Sprintf("%s:%d", source.IP.String(), source.Port), srcPeerKeyHash, dstPeerKeyHash, source.IP.String())
-					_, err = peerI.Proxy.LocalConn.Write(buffer[:n])
-					if err != nil {
-						log.Println("Failed to proxy to Wg local interface: ", err)
-						continue
-					}
+			if peerInfo, ok := common.PeerKeyHashMap[srcPeerKeyHash]; ok {
+				if peers, ok := common.WgIFaceMap[peerInfo.Interface]; ok {
+					if peerI, ok := peers[peerInfo.PeerKey]; ok {
+						log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s   [[ RECV PKT [SRCKEYHASH: %s], [DSTKEYHASH: %s], SourceIP: [%s] ]]\n",
+							peerI.Proxy.LocalConn.RemoteAddr(), peerI.Proxy.LocalConn.LocalAddr(),
+							fmt.Sprintf("%s:%d", source.IP.String(), source.Port), srcPeerKeyHash, dstPeerKeyHash, source.IP.String())
+						_, err = peerI.Proxy.LocalConn.Write(buffer[:n])
+						if err != nil {
+							log.Println("Failed to proxy to Wg local interface: ", err)
+							continue
+						}
 
+					}
 				}
+
 			}
 
 		}

+ 12 - 8
nm-proxy/stun/stun.go

@@ -20,11 +20,12 @@ type HostInfo struct {
 
 var Host HostInfo
 
-func GetHostInfo() (info HostInfo) {
+func GetHostInfo(stunHostAddr string) (info HostInfo) {
 
-	s, err := net.ResolveUDPAddr("udp", "stun.nm.134.209.115.146.nip.io:3478")
+	s, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:3478", stunHostAddr))
 	if err != nil {
 		log.Println("Resolve: ", err)
+		return
 	}
 	l := &net.UDPAddr{
 		IP:   net.ParseIP(""),
@@ -32,13 +33,14 @@ func GetHostInfo() (info HostInfo) {
 	}
 	conn, err := net.DialUDP("udp", l, s)
 	if err != nil {
-		log.Fatal(err)
+		log.Println(err)
+		return
 	}
 	defer conn.Close()
-	fmt.Printf("%+v\n", conn.LocalAddr())
 	c, err := stun.NewClient(conn)
 	if err != nil {
-		panic(err)
+		log.Println(err)
+		return
 	}
 	defer c.Close()
 	re := strings.Split(conn.LocalAddr().String(), ":")
@@ -49,17 +51,19 @@ func GetHostInfo() (info HostInfo) {
 	// Sending request to STUN server, waiting for response message.
 	if err := c.Do(message, func(res stun.Event) {
 		if res.Error != nil {
-			panic(res.Error)
+			log.Println("stun error: ", res.Error)
+			return
 		}
 		// Decoding XOR-MAPPED-ADDRESS attribute from message.
 		var xorAddr stun.XORMappedAddress
 		if err := xorAddr.GetFrom(res.Message); err != nil {
-			panic(err)
+			log.Println("stun error: ", res.Error)
+			return
 		}
 		info.PublicIp = xorAddr.IP
 		info.PubPort = xorAddr.Port
 	}); err != nil {
-		panic(err)
+		log.Println("stun error: ", err)
 	}
 	return
 }

+ 39 - 20
stun-server/stun-server.go

@@ -1,10 +1,15 @@
 package stunserver
 
 import (
+	"context"
 	"fmt"
 	"log"
 	"net"
+	"os"
+	"os/signal"
 	"strings"
+	"sync"
+	"syscall"
 
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/servercfg"
@@ -19,9 +24,8 @@ import (
 // nor ALTERNATE-SERVER, nor credentials mechanisms. It does not support
 // backwards compatibility with RFC 3489.
 type Server struct {
-	Addr         string
-	LogAllErrors bool
-	log          Logger
+	Addr string
+	Ctx  context.Context
 }
 
 // Logger is used for logging formatted messages.
@@ -72,54 +76,62 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
 	buf := make([]byte, 1024)
 	n, addr, err := c.ReadFrom(buf)
 	if err != nil {
-		s.log.Printf("ReadFrom: %v", err)
+		logger.Log(1, "ReadFrom: %v", err.Error())
 		return nil
 	}
 	log.Printf("read %d bytes from %s\n", n, addr)
 	if _, err = req.Write(buf[:n]); err != nil {
-		s.log.Printf("Write: %v", err)
+		logger.Log(1, "Write: %v", err.Error())
 		return err
 	}
 	if err = basicProcess(addr, buf[:n], req, res); err != nil {
 		if err == errNotSTUNMessage {
 			return nil
 		}
-		s.log.Printf("basicProcess: %v", err)
+		logger.Log(1, "basicProcess: %v", err.Error())
 		return nil
 	}
 	_, err = c.WriteTo(res.Raw, addr)
 	if err != nil {
-		s.log.Printf("WriteTo: %v", err)
+		logger.Log(1, "WriteTo: %v", err.Error())
 	}
 	return err
 }
 
 // Serve reads packets from connections and responds to BINDING requests.
-func (s *Server) Serve(c net.PacketConn) error {
+func (s *Server) serve(c net.PacketConn) error {
 	var (
 		res = new(stun.Message)
 		req = new(stun.Message)
 	)
 	for {
-		if err := s.serveConn(c, res, req); err != nil {
-			s.log.Printf("serve: %v", err)
-			return err
+		select {
+		case <-s.Ctx.Done():
+			logger.Log(0, "Shutting down stun server...")
+			c.Close()
+			return nil
+		default:
+			if err := s.serveConn(c, res, req); err != nil {
+				logger.Log(1, "serve: %v", err.Error())
+				continue
+			}
+			res.Reset()
+			req.Reset()
 		}
-		res.Reset()
-		req.Reset()
 	}
 }
 
-// ListenUDPAndServe listens on laddr and process incoming packets.
-func ListenUDPAndServe(serverNet, laddr string) error {
+// listenUDPAndServe listens on laddr and process incoming packets.
+func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
 	c, err := net.ListenPacket(serverNet, laddr)
 	if err != nil {
 		return err
 	}
 	s := &Server{
-		log: defaultLogger,
+		Addr: laddr,
+		Ctx:  ctx,
 	}
-	return s.Serve(c)
+	return s.serve(c)
 }
 
 func normalize(address string) string {
@@ -132,11 +144,18 @@ func normalize(address string) string {
 	return address
 }
 
-func Start() {
-
+func Start(wg *sync.WaitGroup) {
+	defer wg.Done()
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		quit := make(chan os.Signal, 1)
+		signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+		<-quit
+		cancel()
+	}()
 	normalized := normalize(fmt.Sprintf("0.0.0.0:%s", servercfg.GetStunPort()))
 	logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
-	err := ListenUDPAndServe("udp", normalized)
+	err := listenUDPAndServe(ctx, "udp", normalized)
 	if err != nil {
 		logger.Log(0, "failed to start stun server: ", err.Error())
 	}