Browse Source

Merge pull request #2026 from gravitl/gra-1172-zombies

updates to zombie processing
dcarns 2 years ago
parent
commit
73fbdfea0d

+ 0 - 12
controllers/dns_test.go

@@ -6,7 +6,6 @@ import (
 	"testing"
 
 	"github.com/google/uuid"
-	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/stretchr/testify/assert"
@@ -16,7 +15,6 @@ import (
 var dnsHost models.Host
 
 func TestGetAllDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -47,7 +45,6 @@ func TestGetAllDNS(t *testing.T) {
 }
 
 func TestGetNodeDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -94,7 +91,6 @@ func TestGetNodeDNS(t *testing.T) {
 	})
 }
 func TestGetCustomDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	t.Run("NoNetworks", func(t *testing.T) {
@@ -133,7 +129,6 @@ func TestGetCustomDNS(t *testing.T) {
 }
 
 func TestGetDNSEntryNum(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -152,7 +147,6 @@ func TestGetDNSEntryNum(t *testing.T) {
 	})
 }
 func TestGetDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -196,7 +190,6 @@ func TestGetDNS(t *testing.T) {
 }
 
 func TestCreateDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -207,7 +200,6 @@ func TestCreateDNS(t *testing.T) {
 }
 
 func TestSetDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	t.Run("NoNetworks", func(t *testing.T) {
@@ -255,7 +247,6 @@ func TestSetDNS(t *testing.T) {
 }
 
 func TestGetDNSEntry(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -285,7 +276,6 @@ func TestGetDNSEntry(t *testing.T) {
 }
 
 func TestDeleteDNS(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -307,7 +297,6 @@ func TestDeleteDNS(t *testing.T) {
 }
 
 func TestValidateDNSUpdate(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllDNS(t)
 	deleteAllNetworks()
 	createNet()
@@ -369,7 +358,6 @@ func TestValidateDNSUpdate(t *testing.T) {
 
 }
 func TestValidateDNSCreate(t *testing.T) {
-	database.InitializeDatabase()
 	_ = logic.DeleteDNS("mynode", "skynet")
 	t.Run("NoNetwork", func(t *testing.T) {
 		entry := models.DNSEntry{"10.0.0.2", "", "myhost", "badnet"}

+ 22 - 24
controllers/network_test.go

@@ -1,11 +1,13 @@
 package controller
 
 import (
+	"context"
 	"os"
 	"testing"
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/stretchr/testify/assert"
@@ -20,8 +22,27 @@ type NetworkValidationTestCase struct {
 
 var netHost models.Host
 
+func TestMain(m *testing.M) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	logic.CreateAdmin(&models.User{
+		UserName: "admin",
+		Password: "password",
+		IsAdmin:  true,
+		Networks: []string{},
+		Groups:   []string{},
+	})
+	peerUpdate := make(chan *models.Node)
+	go logic.ManageZombies(context.Background(), peerUpdate)
+	go func() {
+		for update := range peerUpdate {
+			//do nothing
+			logger.Log(3, "received node update", update.Action)
+		}
+	}()
+}
+
 func TestCreateNetwork(t *testing.T) {
-	initialize()
 	deleteAllNetworks()
 
 	var network models.Network
@@ -34,7 +55,6 @@ func TestCreateNetwork(t *testing.T) {
 	assert.Nil(t, err)
 }
 func TestGetNetwork(t *testing.T) {
-	initialize()
 	createNet()
 
 	t.Run("GetExistingNetwork", func(t *testing.T) {
@@ -50,7 +70,6 @@ func TestGetNetwork(t *testing.T) {
 }
 
 func TestDeleteNetwork(t *testing.T) {
-	initialize()
 	createNet()
 	//create nodes
 	t.Run("NetworkwithNodes", func(t *testing.T) {
@@ -66,7 +85,6 @@ func TestDeleteNetwork(t *testing.T) {
 }
 
 func TestCreateKey(t *testing.T) {
-	initialize()
 	createNet()
 	keys, _ := logic.GetKeys("skynet")
 	for _, key := range keys {
@@ -138,7 +156,6 @@ func TestCreateKey(t *testing.T) {
 }
 
 func TestGetKeys(t *testing.T) {
-	initialize()
 	deleteAllNetworks()
 	createNet()
 	network, err := logic.GetNetwork("skynet")
@@ -161,7 +178,6 @@ func TestGetKeys(t *testing.T) {
 	})
 }
 func TestDeleteKey(t *testing.T) {
-	initialize()
 	createNet()
 	network, err := logic.GetNetwork("skynet")
 	assert.Nil(t, err)
@@ -183,7 +199,6 @@ func TestDeleteKey(t *testing.T) {
 func TestSecurityCheck(t *testing.T) {
 	//these seem to work but not sure it the tests are really testing the functionality
 
-	initialize()
 	os.Setenv("MASTER_KEY", "secretkey")
 	t.Run("NoNetwork", func(t *testing.T) {
 		networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
@@ -214,7 +229,6 @@ func TestValidateNetwork(t *testing.T) {
 	//t.Skip()
 	//This functions is not called by anyone
 	//it panics as validation function 'display_name_valid' is not defined
-	initialize()
 	//yes := true
 	//no := false
 	//deleteNet(t)
@@ -291,7 +305,6 @@ func TestValidateNetwork(t *testing.T) {
 func TestIpv6Network(t *testing.T) {
 	//these seem to work but not sure it the tests are really testing the functionality
 
-	initialize()
 	os.Setenv("MASTER_KEY", "secretkey")
 	deleteAllNetworks()
 	createNet()
@@ -318,21 +331,6 @@ func deleteAllNetworks() {
 	}
 }
 
-func initialize() {
-	database.InitializeDatabase()
-	createAdminUser()
-}
-
-func createAdminUser() {
-	logic.CreateAdmin(&models.User{
-		UserName: "admin",
-		Password: "password",
-		IsAdmin:  true,
-		Networks: []string{},
-		Groups:   []string{},
-	})
-}
-
 func createNet() {
 	var network models.Network
 	network.NetID = "skynet"

+ 0 - 3
controllers/node_test.go

@@ -21,7 +21,6 @@ func TestCreateEgressGateway(t *testing.T) {
 	var gateway models.EgressGatewayRequest
 	gateway.Ranges = []string{"10.100.100.0/24"}
 	gateway.NetID = "skynet"
-	database.InitializeDatabase()
 	deleteAllNetworks()
 	createNet()
 	t.Run("NoNodes", func(t *testing.T) {
@@ -78,7 +77,6 @@ func TestCreateEgressGateway(t *testing.T) {
 }
 func TestDeleteEgressGateway(t *testing.T) {
 	var gateway models.EgressGatewayRequest
-	database.InitializeDatabase()
 	deleteAllNetworks()
 	createNet()
 	testnode := createTestNode()
@@ -110,7 +108,6 @@ func TestDeleteEgressGateway(t *testing.T) {
 }
 
 func TestGetNetworkNodes(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllNetworks()
 	createNet()
 	t.Run("BadNet", func(t *testing.T) {

+ 1 - 11
controllers/user_test.go

@@ -3,7 +3,6 @@ package controller
 import (
 	"testing"
 
-	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/stretchr/testify/assert"
@@ -18,7 +17,6 @@ func deleteAllUsers() {
 
 func TestHasAdmin(t *testing.T) {
 	//delete all current users
-	database.InitializeDatabase()
 	users, _ := logic.GetUsers()
 	for _, user := range users {
 		success, err := logic.DeleteUser(user.UserName)
@@ -48,7 +46,7 @@ func TestHasAdmin(t *testing.T) {
 	})
 	t.Run("multiple admins", func(t *testing.T) {
 		var user = models.User{"admin1", "password", nil, true, nil}
-		 err := logic.CreateUser(&user)
+		err := logic.CreateUser(&user)
 		assert.Nil(t, err)
 		found, err := logic.HasAdmin()
 		assert.Nil(t, err)
@@ -57,7 +55,6 @@ func TestHasAdmin(t *testing.T) {
 }
 
 func TestCreateUser(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	user := models.User{"admin", "password", nil, true, nil}
 	t.Run("NoUser", func(t *testing.T) {
@@ -72,7 +69,6 @@ func TestCreateUser(t *testing.T) {
 }
 
 func TestCreateAdmin(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	var user models.User
 	t.Run("NoAdmin", func(t *testing.T) {
@@ -90,7 +86,6 @@ func TestCreateAdmin(t *testing.T) {
 }
 
 func TestDeleteUser(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	t.Run("NonExistent User", func(t *testing.T) {
 		deleted, err := logic.DeleteUser("admin")
@@ -107,7 +102,6 @@ func TestDeleteUser(t *testing.T) {
 }
 
 func TestValidateUser(t *testing.T) {
-	database.InitializeDatabase()
 	var user models.User
 	t.Run("Valid Create", func(t *testing.T) {
 		user.UserName = "admin"
@@ -155,7 +149,6 @@ func TestValidateUser(t *testing.T) {
 }
 
 func TestGetUser(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	t.Run("NonExistantUser", func(t *testing.T) {
 		admin, err := logic.GetUser("admin")
@@ -172,7 +165,6 @@ func TestGetUser(t *testing.T) {
 }
 
 func TestGetUsers(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	t.Run("NonExistantUser", func(t *testing.T) {
 		admin, err := logic.GetUsers()
@@ -203,7 +195,6 @@ func TestGetUsers(t *testing.T) {
 }
 
 func TestUpdateUser(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	user := models.User{"admin", "password", nil, true, nil}
 	newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true, []string{}}
@@ -246,7 +237,6 @@ func TestUpdateUser(t *testing.T) {
 // }
 
 func TestVerifyAuthRequest(t *testing.T) {
-	database.InitializeDatabase()
 	deleteAllUsers()
 	var authRequest models.UserAuthParams
 	t.Run("EmptyUserName", func(t *testing.T) {

+ 22 - 8
functions/helpers_test.go

@@ -1,10 +1,12 @@
 package functions
 
 import (
+	"context"
 	"encoding/json"
 	"testing"
 
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 )
@@ -19,11 +21,27 @@ var (
 	}
 )
 
+func TestMain(m *testing.M) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	logic.CreateAdmin(&models.User{
+		UserName: "admin",
+		Password: "password",
+		IsAdmin:  true,
+		Networks: []string{},
+		Groups:   []string{},
+	})
+	peerUpdate := make(chan *models.Node)
+	go logic.ManageZombies(context.Background(), peerUpdate)
+	go func() {
+		for update := range peerUpdate {
+			//do nothing
+			logger.Log(3, "received node update", update.Action)
+		}
+	}()
+}
+
 func TestNetworkExists(t *testing.T) {
-	err := database.InitializeDatabase()
-	if err != nil {
-		t.Fatalf("error initilizing database: %s", err)
-	}
 	database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
 	defer database.CloseDB()
 	exists, err := logic.NetworkExists(testNetwork.NetID)
@@ -53,10 +71,6 @@ func TestNetworkExists(t *testing.T) {
 }
 
 func TestGetAllExtClients(t *testing.T) {
-	err := database.InitializeDatabase()
-	if err != nil {
-		t.Fatalf("error initilizing database: %s", err)
-	}
 	defer database.CloseDB()
 	database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, testExternalClient.ClientID)
 

+ 22 - 1
logic/host_test.go

@@ -1,17 +1,38 @@
 package logic
 
 import (
+	"context"
 	"net"
 	"testing"
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/matryer/is"
 )
 
-func TestCheckPorts(t *testing.T) {
+func TestMain(m *testing.M) {
 	database.InitializeDatabase()
+	defer database.CloseDB()
+	CreateAdmin(&models.User{
+		UserName: "admin",
+		Password: "password",
+		IsAdmin:  true,
+		Networks: []string{},
+		Groups:   []string{},
+	})
+	peerUpdate := make(chan *models.Node)
+	go ManageZombies(context.Background(), peerUpdate)
+	go func() {
+		for update := range peerUpdate {
+			//do nothing
+			logger.Log(3, "received node update", update.Action)
+		}
+	}()
+}
+
+func TestCheckPorts(t *testing.T) {
 	h := models.Host{
 		ID:              uuid.New(),
 		EndpointIP:      net.ParseIP("192.168.1.1"),

+ 1 - 0
logic/hosts.go

@@ -96,6 +96,7 @@ func CreateHost(h *models.Host) error {
 		return err
 	}
 	h.HostPass = string(hash)
+	checkForZombieHosts(h)
 	return UpsertHost(h)
 }
 

+ 1 - 1
logic/nodes.go

@@ -534,7 +534,7 @@ func createNode(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-	CheckZombies(node, host.MacAddress)
+	CheckZombies(node)
 
 	nodebytes, err := json.Marshal(&node)
 	if err != nil {

+ 5 - 1
logic/pro/networkuser_test.go

@@ -10,8 +10,12 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-func TestNetworkUserLogic(t *testing.T) {
+func TestMain(m *testing.M) {
 	database.InitializeDatabase()
+	defer database.CloseDB()
+}
+
+func TestNetworkUserLogic(t *testing.T) {
 	networkUser := promodels.NetworkUser{
 		ID: "helloworld",
 	}

+ 0 - 2
logic/pro/usergroups_test.go

@@ -3,13 +3,11 @@ package pro
 import (
 	"testing"
 
-	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models/promodels"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestUserGroupLogic(t *testing.T) {
-	database.InitializeDatabase()
 
 	t.Run("User Groups initialized successfully", func(t *testing.T) {
 		err := InitializeGroups()

+ 57 - 23
logic/zombie.go

@@ -2,7 +2,6 @@ package logic
 
 import (
 	"context"
-	"net"
 	"time"
 
 	"github.com/google/uuid"
@@ -18,15 +17,16 @@ const (
 )
 
 var (
-	zombies      []uuid.UUID
-	removeZombie chan uuid.UUID = make(chan (uuid.UUID), 10)
-	newZombie    chan uuid.UUID = make(chan (uuid.UUID), 10)
+	zombies       []uuid.UUID
+	hostZombies   []uuid.UUID
+	newZombie     chan uuid.UUID = make(chan (uuid.UUID), 10)
+	newHostZombie chan uuid.UUID = make(chan (uuid.UUID), 10)
 )
 
-// CheckZombies - checks if new node has same macaddress as existing node
+// CheckZombies - checks if new node has same hostid as existing node
 // if so, existing node is added to zombie node quarantine list
 // also cleans up nodes past their expiration date
-func CheckZombies(newnode *models.Node, mac net.HardwareAddr) {
+func CheckZombies(newnode *models.Node) {
 	nodes, err := GetNetworkNodes(newnode.Network)
 	if err != nil {
 		logger.Log(1, "Failed to retrieve network nodes", newnode.Network, err.Error())
@@ -44,6 +44,35 @@ func CheckZombies(newnode *models.Node, mac net.HardwareAddr) {
 	}
 }
 
+// checkForZombieHosts - checks if new host has the same macAddress as an existing host
+// if true, existing host is added to host zombie collection
+func checkForZombieHosts(h *models.Host) {
+	hosts, err := GetAllHosts()
+	if err != nil {
+		logger.Log(3, "errror retrieving all hosts", err.Error())
+	}
+	for _, existing := range hosts {
+		if existing.ID == h.ID {
+			//probably an unnecessary check as new host should not be in database yet, but just in case
+			//skip self
+			continue
+		}
+		if existing.MacAddress.String() == h.MacAddress.String() {
+			//add to hostZombies
+			newHostZombie <- existing.ID
+			//add all nodes belonging to host to zombile list
+			for _, node := range existing.Nodes {
+				id, err := uuid.Parse(node)
+				if err != nil {
+					logger.Log(3, "error parsing uuid from host.Nodes", err.Error())
+					continue
+				}
+				newHostZombie <- id
+			}
+		}
+	}
+}
+
 // ManageZombies - goroutine which adds/removes/deletes nodes from the zombie node quarantine list
 func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) {
 	logger.Log(2, "Zombie management started")
@@ -51,24 +80,12 @@ func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) {
 	for {
 		select {
 		case <-ctx.Done():
+			close(peerUpdate)
 			return
 		case id := <-newZombie:
-			logger.Log(1, "adding", id.String(), "to zombie quaratine list")
 			zombies = append(zombies, id)
-		case id := <-removeZombie:
-			found := false
-			if len(zombies) > 0 {
-				for i := len(zombies) - 1; i >= 0; i-- {
-					if zombies[i] == id {
-						logger.Log(1, "removing zombie from quaratine list", zombies[i].String())
-						zombies = append(zombies[:i], zombies[i+1:]...)
-						found = true
-					}
-				}
-			}
-			if !found {
-				logger.Log(3, "no zombies found")
-			}
+		case id := <-newHostZombie:
+			hostZombies = append(hostZombies, id)
 		case <-time.After(time.Second * ZOMBIE_TIMEOUT):
 			logger.Log(3, "checking for zombie nodes")
 			if len(zombies) > 0 {
@@ -92,6 +109,23 @@ func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) {
 					}
 				}
 			}
+			if len(hostZombies) > 0 {
+				logger.Log(3, "checking host zombies")
+				for i := len(hostZombies) - 1; i >= 0; i-- {
+					host, err := GetHost(hostZombies[i].String())
+					if err != nil {
+						logger.Log(1, "error retrieving zombie host", err.Error())
+						logger.Log(1, "deleting ", host.ID.String(), " from zombie list")
+						zombies = append(zombies[:i], zombies[i+1:]...)
+						continue
+					}
+					if len(host.Nodes) == 0 {
+						if err := RemoveHost(host); err != nil {
+							logger.Log(0, "error deleting zombie host", host.ID.String(), err.Error())
+						}
+					}
+				}
+			}
 		}
 	}
 }
@@ -115,10 +149,10 @@ func InitializeZombies() {
 			}
 			if node.HostID == othernode.HostID {
 				if node.LastCheckIn.After(othernode.LastCheckIn) {
-					zombies = append(zombies, othernode.ID)
+					newZombie <- othernode.ID
 					logger.Log(1, "adding", othernode.ID.String(), "to zombie list")
 				} else {
-					zombies = append(zombies, node.ID)
+					newZombie <- node.ID
 					logger.Log(1, "adding", node.ID.String(), "to zombie list")
 				}
 			}

+ 1 - 1
models/network_test.go

@@ -2,7 +2,7 @@ package models
 
 // moved from controllers need work
 //func TestUpdateNetwork(t *testing.T) {
-//	database.InitializeDatabase()
+//	initialize()
 //	createNet()
 //	network := getNet()
 //	t.Run("NetID", func(t *testing.T) {