Browse Source

Revert "adjusted main to use one single context"

This reverts commit 92d0d12e8fb87084424faff10cf1d50ee9a834ec.
0xdcarns 2 years ago
parent
commit
2749e7311b
5 changed files with 105 additions and 56 deletions
  1. 11 4
      controllers/controller.go
  2. 30 0
      logic/nodes.go
  3. 35 17
      main.go
  4. 0 5
      mq/mq.go
  5. 29 30
      stun-server/stun-server.go

+ 11 - 4
controllers/controller.go

@@ -4,8 +4,11 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
+	"os"
+	"os/signal"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
+	"syscall"
 	"time"
 	"time"
 
 
 	"github.com/gorilla/handlers"
 	"github.com/gorilla/handlers"
@@ -30,7 +33,7 @@ var HttpHandlers = []interface{}{
 }
 }
 
 
 // HandleRESTRequests - handles the rest requests
 // HandleRESTRequests - handles the rest requests
-func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
+func HandleRESTRequests(wg *sync.WaitGroup) {
 	defer wg.Done()
 	defer wg.Done()
 
 
 	r := mux.NewRouter()
 	r := mux.NewRouter()
@@ -56,14 +59,18 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
 	}()
 	}()
 	logger.Log(0, "REST Server successfully started on port ", port, " (REST)")
 	logger.Log(0, "REST Server successfully started on port ", port, " (REST)")
 
 
+	// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
+	// Ignore other incoming signals
+	ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt)
+	defer stop()
+
 	// Block main routine until a signal is received
 	// Block main routine until a signal is received
 	// As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running
 	// As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running
 	<-ctx.Done()
 	<-ctx.Done()
+
 	// After receiving CTRL+C Properly stop the server
 	// After receiving CTRL+C Properly stop the server
 	logger.Log(0, "Stopping the REST server...")
 	logger.Log(0, "Stopping the REST server...")
-	if err := srv.Shutdown(context.TODO()); err != nil {
-		logger.Log(0, "REST shutdown error occurred -", err.Error())
-	}
 	logger.Log(0, "REST Server closed.")
 	logger.Log(0, "REST Server closed.")
 	logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
 	logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
+	srv.Shutdown(context.TODO())
 }
 }

+ 30 - 0
logic/nodes.go

@@ -1,6 +1,7 @@
 package logic
 package logic
 
 
 import (
 import (
+	"context"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -420,6 +421,35 @@ func updateProNodeACLS(node *models.Node) error {
 	return nil
 	return nil
 }
 }
 
 
+func PurgePendingNodes(ctx context.Context) {
+	ticker := time.NewTicker(NodePurgeCheckTime)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			nodes, err := GetAllNodes()
+			if err != nil {
+				logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error())
+				continue
+			}
+			for _, node := range nodes {
+				if node.PendingDelete {
+					modified := node.LastModified
+					if time.Since(modified) > NodePurgeTime {
+						if err := DeleteNode(&node, true); err != nil {
+							logger.Log(0, "failed to purge node", node.ID.String(), err.Error())
+						} else {
+							logger.Log(0, "purged node ", node.ID.String())
+						}
+					}
+				}
+			}
+		}
+	}
+}
+
 // createNode - creates a node in database
 // createNode - creates a node in database
 func createNode(node *models.Node) error {
 func createNode(node *models.Node) error {
 	host, err := GetHost(node.HostID.String())
 	host, err := GetHost(node.HostID.String())

+ 35 - 17
main.go

@@ -36,16 +36,12 @@ func main() {
 	setupConfig(*absoluteConfigPath)
 	setupConfig(*absoluteConfigPath)
 	servercfg.SetVersion(version)
 	servercfg.SetVersion(version)
 	fmt.Println(models.RetrieveLogo()) // print the logo
 	fmt.Println(models.RetrieveLogo()) // print the logo
-	initialize()                       // initial db and acls
+	// fmt.Println(models.ProLogo())
+	initialize() // initial db and acls; gen cert if required
 	setGarbageCollection()
 	setGarbageCollection()
 	setVerbosity()
 	setVerbosity()
 	defer database.CloseDB()
 	defer database.CloseDB()
-	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt)
-	defer stop()
-	var waitGroup sync.WaitGroup
-	startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun
-	<-ctx.Done()
-	waitGroup.Wait()
+	startControllers() // start the api endpoint and mq
 }
 }
 
 
 func setupConfig(absoluteConfigPath string) {
 func setupConfig(absoluteConfigPath string) {
@@ -114,7 +110,8 @@ func initialize() { // Client Mode Prereq Check
 	}
 	}
 }
 }
 
 
-func startControllers(wg *sync.WaitGroup, ctx context.Context) {
+func startControllers() {
+	var waitnetwork sync.WaitGroup
 	if servercfg.IsDNSMode() {
 	if servercfg.IsDNSMode() {
 		err := logic.SetDNS()
 		err := logic.SetDNS()
 		if err != nil {
 		if err != nil {
@@ -130,13 +127,13 @@ func startControllers(wg *sync.WaitGroup, ctx context.Context) {
 				logger.FatalLog("Unable to Set host. Exiting...", err.Error())
 				logger.FatalLog("Unable to Set host. Exiting...", err.Error())
 			}
 			}
 		}
 		}
-		wg.Add(1)
-		go controller.HandleRESTRequests(wg, ctx)
+		waitnetwork.Add(1)
+		go controller.HandleRESTRequests(&waitnetwork)
 	}
 	}
 	//Run MessageQueue
 	//Run MessageQueue
 	if servercfg.IsMessageQueueBackend() {
 	if servercfg.IsMessageQueueBackend() {
-		wg.Add(1)
-		go runMessageQueue(wg, ctx)
+		waitnetwork.Add(1)
+		go runMessageQueue(&waitnetwork)
 	}
 	}
 
 
 	if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() {
 	if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() {
@@ -144,17 +141,34 @@ func startControllers(wg *sync.WaitGroup, ctx context.Context) {
 	}
 	}
 
 
 	// starts the stun server
 	// starts the stun server
-	wg.Add(1)
-	go stunserver.Start(wg, ctx)
+	waitnetwork.Add(1)
+	go stunserver.Start(&waitnetwork)
+	if servercfg.IsProxyEnabled() {
+
+		waitnetwork.Add(1)
+		go func() {
+			defer waitnetwork.Done()
+			_, cancel := context.WithCancel(context.Background())
+			waitnetwork.Add(1)
+
+			//go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost())
+			quit := make(chan os.Signal, 1)
+			signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+			<-quit
+			cancel()
+		}()
+	}
+
+	waitnetwork.Wait()
 }
 }
 
 
 // Should we be using a context vice a waitgroup????????????
 // Should we be using a context vice a waitgroup????????????
-func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) {
+func runMessageQueue(wg *sync.WaitGroup) {
 	defer wg.Done()
 	defer wg.Done()
 	brokerHost, secure := servercfg.GetMessageQueueEndpoint()
 	brokerHost, secure := servercfg.GetMessageQueueEndpoint()
 	logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure))
 	logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure))
 	mq.SetupMQTT()
 	mq.SetupMQTT()
-	defer mq.CloseClient()
+	ctx, cancel := context.WithCancel(context.Background())
 	go mq.Keepalive(ctx)
 	go mq.Keepalive(ctx)
 	go func() {
 	go func() {
 		peerUpdate := make(chan *models.Node)
 		peerUpdate := make(chan *models.Node)
@@ -165,7 +179,11 @@ func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) {
 			}
 			}
 		}
 		}
 	}()
 	}()
-	<-ctx.Done()
+	go logic.PurgePendingNodes(ctx)
+	quit := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+	<-quit
+	cancel()
 	logger.Log(0, "Message Queue shutting down")
 	logger.Log(0, "Message Queue shutting down")
 }
 }
 
 

+ 0 - 5
mq/mq.go

@@ -100,8 +100,3 @@ func Keepalive(ctx context.Context) {
 func IsConnected() bool {
 func IsConnected() bool {
 	return mqclient != nil && mqclient.IsConnected()
 	return mqclient != nil && mqclient.IsConnected()
 }
 }
-
-// CloseClient - function to close the mq connection from server
-func CloseClient() {
-	mqclient.Disconnect(250)
-}

+ 29 - 30
stun-server/stun-server.go

@@ -4,8 +4,11 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"os"
+	"os/signal"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
+	"syscall"
 
 
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
@@ -20,6 +23,7 @@ import (
 // backwards compatibility with RFC 3489.
 // backwards compatibility with RFC 3489.
 type Server struct {
 type Server struct {
 	Addr string
 	Addr string
+	Ctx  context.Context
 }
 }
 
 
 var (
 var (
@@ -56,58 +60,48 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
 	)
 	)
 }
 }
 
 
-func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error {
+func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
 	if c == nil {
 	if c == nil {
 		return nil
 		return nil
 	}
 	}
-	go func(ctx context.Context) {
-		<-ctx.Done()
-		if c != nil {
-			// kill connection on server shutdown
-			c.Close()
-		}
-	}(ctx)
-
 	buf := make([]byte, 1024)
 	buf := make([]byte, 1024)
-	n, addr, err := c.ReadFrom(buf) // this be blocky af
+	n, addr, err := c.ReadFrom(buf)
 	if err != nil {
 	if err != nil {
-		if !strings.Contains(err.Error(), "use of closed network connection") {
-			logger.Log(1, "STUN read error:", err.Error())
-		}
+		logger.Log(1, "ReadFrom: %v", err.Error())
 		return nil
 		return nil
 	}
 	}
-
 	if _, err = req.Write(buf[:n]); err != nil {
 	if _, err = req.Write(buf[:n]); err != nil {
-		logger.Log(1, "STUN write error:", err.Error())
+		logger.Log(1, "Write: %v", err.Error())
 		return err
 		return err
 	}
 	}
 	if err = basicProcess(addr, buf[:n], req, res); err != nil {
 	if err = basicProcess(addr, buf[:n], req, res); err != nil {
 		if err == errNotSTUNMessage {
 		if err == errNotSTUNMessage {
 			return nil
 			return nil
 		}
 		}
-		logger.Log(1, "STUN process error:", err.Error())
+		logger.Log(1, "basicProcess: %v", err.Error())
 		return nil
 		return nil
 	}
 	}
 	_, err = c.WriteTo(res.Raw, addr)
 	_, err = c.WriteTo(res.Raw, addr)
 	if err != nil {
 	if err != nil {
-		logger.Log(1, "STUN response write error", err.Error())
+		logger.Log(1, "WriteTo: %v", err.Error())
 	}
 	}
 	return err
 	return err
 }
 }
 
 
 // Serve reads packets from connections and responds to BINDING requests.
 // Serve reads packets from connections and responds to BINDING requests.
-func (s *Server) serve(c net.PacketConn, ctx context.Context) error {
+func (s *Server) serve(c net.PacketConn) error {
 	var (
 	var (
 		res = new(stun.Message)
 		res = new(stun.Message)
 		req = new(stun.Message)
 		req = new(stun.Message)
 	)
 	)
 	for {
 	for {
 		select {
 		select {
-		case <-ctx.Done():
-			logger.Log(0, "shut down STUN server")
+		case <-s.Ctx.Done():
+			logger.Log(0, "Shutting down stun server...")
+			c.Close()
 			return nil
 			return nil
 		default:
 		default:
-			if err := s.serveConn(c, res, req, ctx); err != nil {
+			if err := s.serveConn(c, res, req); err != nil {
 				logger.Log(1, "serve: %v", err.Error())
 				logger.Log(1, "serve: %v", err.Error())
 				continue
 				continue
 			}
 			}
@@ -125,8 +119,9 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error {
 	}
 	}
 	s := &Server{
 	s := &Server{
 		Addr: laddr,
 		Addr: laddr,
+		Ctx:  ctx,
 	}
 	}
-	return s.serve(c, ctx)
+	return s.serve(c)
 }
 }
 
 
 func normalize(address string) string {
 func normalize(address string) string {
@@ -140,15 +135,19 @@ func normalize(address string) string {
 }
 }
 
 
 // Start - starts the stun server
 // Start - starts the stun server
-func Start(wg *sync.WaitGroup, ctx context.Context) {
-	defer wg.Done()
+func Start(wg *sync.WaitGroup) {
+	ctx, cancel := context.WithCancel(context.Background())
+	go func(wg *sync.WaitGroup) {
+		defer wg.Done()
+		quit := make(chan os.Signal, 1)
+		signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+		<-quit
+		cancel()
+	}(wg)
 	normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
 	normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort()))
 	logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
 	logger.Log(0, "netmaker-stun listening on", normalized, "via udp")
-	if err := listenUDPAndServe(ctx, "udp", normalized); err != nil {
-		if strings.Contains(err.Error(), "closed network connection") {
-			logger.Log(0, "shutdown STUN server")
-		} else {
-			logger.Log(0, "server: ", err.Error())
-		}
+	err := listenUDPAndServe(ctx, "udp", normalized)
+	if err != nil {
+		logger.Log(0, "failed to start stun server: ", err.Error())
 	}
 	}
 }
 }