Parcourir la source

fixed failing tests
updated test helper functions to permit easier troubleshooting when tests fail

Matthew R Kasun il y a 3 ans
Parent
commit
db934a1ccd

+ 20 - 20
controllers/dns_test.go

@@ -13,8 +13,8 @@ import (
 func TestGetAllDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("NoEntries", func(t *testing.T) {
 		entries, err := logic.GetAllDNS()
 		assert.Nil(t, err)
@@ -39,8 +39,8 @@ func TestGetAllDNS(t *testing.T) {
 func TestGetNodeDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("NoNodes", func(t *testing.T) {
 		dns, err := logic.GetNodeDNS("skynet")
 		assert.EqualError(t, err, "could not find any records")
@@ -64,14 +64,14 @@ func TestGetNodeDNS(t *testing.T) {
 func TestGetCustomDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
+	deleteAllNetworks(t)
 	t.Run("NoNetworks", func(t *testing.T) {
 		dns, err := logic.GetCustomDNS("skynet")
 		assert.EqualError(t, err, "could not find any records")
 		assert.Equal(t, []models.DNSEntry(nil), dns)
 	})
 	t.Run("NoNodes", func(t *testing.T) {
-		createNet()
+		createNet(t)
 		dns, err := logic.GetCustomDNS("skynet")
 		assert.EqualError(t, err, "could not find any records")
 		assert.Equal(t, []models.DNSEntry(nil), dns)
@@ -101,8 +101,8 @@ func TestGetCustomDNS(t *testing.T) {
 func TestGetDNSEntryNum(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("NoNodes", func(t *testing.T) {
 		num, err := logic.GetDNSEntryNum("myhost", "skynet")
 		assert.Nil(t, err)
@@ -120,8 +120,8 @@ func TestGetDNSEntryNum(t *testing.T) {
 func TestGetDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("NoEntries", func(t *testing.T) {
 		dns, err := logic.GetDNS("skynet")
 		assert.Nil(t, err)
@@ -163,8 +163,8 @@ func TestGetDNS(t *testing.T) {
 func TestCreateDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
 	dns, err := CreateDNS(entry)
 	assert.Nil(t, err)
@@ -174,7 +174,7 @@ func TestCreateDNS(t *testing.T) {
 func TestSetDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
+	deleteAllNetworks(t)
 	t.Run("NoNetworks", func(t *testing.T) {
 		err := logic.SetDNS()
 		assert.Nil(t, err)
@@ -184,7 +184,7 @@ func TestSetDNS(t *testing.T) {
 		assert.Equal(t, int64(0), info.Size())
 	})
 	t.Run("NoEntries", func(t *testing.T) {
-		createNet()
+		createNet(t)
 		err := logic.SetDNS()
 		assert.Nil(t, err)
 		info, err := os.Stat("./config/dnsconfig/netmaker.hosts")
@@ -221,8 +221,8 @@ func TestSetDNS(t *testing.T) {
 func TestGetDNSEntry(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	createTestNode()
 	entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
 	CreateDNS(entry)
@@ -278,8 +278,8 @@ func TestGetDNSEntry(t *testing.T) {
 func TestDeleteDNS(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	entry := models.DNSEntry{"10.0.0.2", "newhost", "skynet"}
 	CreateDNS(entry)
 	t.Run("EntryExists", func(t *testing.T) {
@@ -300,8 +300,8 @@ func TestDeleteDNS(t *testing.T) {
 func TestValidateDNSUpdate(t *testing.T) {
 	database.InitializeDatabase()
 	deleteAllDNS(t)
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	entry := models.DNSEntry{"10.0.0.2", "myhost", "skynet"}
 	t.Run("BadNetwork", func(t *testing.T) {
 		change := models.DNSEntry{"10.0.0.2", "myhost", "badnet"}

+ 34 - 19
controllers/network_test.go

@@ -18,7 +18,7 @@ type NetworkValidationTestCase struct {
 
 func TestCreateNetwork(t *testing.T) {
 	database.InitializeDatabase()
-	deleteAllNetworks()
+	deleteAllNetworks(t)
 
 	var network models.Network
 	network.NetID = "skynet"
@@ -30,7 +30,8 @@ func TestCreateNetwork(t *testing.T) {
 }
 func TestGetNetwork(t *testing.T) {
 	database.InitializeDatabase()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 
 	t.Run("GetExistingNetwork", func(t *testing.T) {
 		network, err := logic.GetNetwork("skynet")
@@ -46,10 +47,8 @@ func TestGetNetwork(t *testing.T) {
 
 func TestDeleteNetwork(t *testing.T) {
 	database.InitializeDatabase()
-	createNet()
-	//create nodes
-	t.Run("NetworkwithNodes", func(t *testing.T) {
-	})
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("DeleteExistingNetwork", func(t *testing.T) {
 		err := logic.DeleteNetwork("skynet")
 		assert.Nil(t, err)
@@ -58,12 +57,29 @@ func TestDeleteNetwork(t *testing.T) {
 		err := logic.DeleteNetwork("skynet")
 		assert.Nil(t, err)
 	})
+	createNet(t)
+	createTestNode()
+	t.Run("NetworkWithNodes", func(t *testing.T) {
+		err := logic.DeleteNetwork("skynet")
+		assert.Contains(t, err.Error(), "node check failed. All nodes must be deleted before deleting network")
+	})
+	t.Run("NetworkWithNoNodes", func(t *testing.T) {
+		nodes, err := logic.GetAllNodes()
+		assert.Nil(t, err)
+		for _, node := range nodes {
+			err := logic.DeleteNode(&node, true)
+			assert.Nil(t, err)
+		}
+		err = logic.DeleteNetwork("skynet")
+		assert.Nil(t, err)
+	})
+
 }
 
 func TestKeyUpdate(t *testing.T) {
 	t.Skip() //test is failing on last assert  --- not sure why
 	database.InitializeDatabase()
-	createNet()
+	createNet(t)
 	existing, err := logic.GetNetwork("skynet")
 	assert.Nil(t, err)
 	time.Sleep(time.Second * 1)
@@ -76,7 +92,7 @@ func TestKeyUpdate(t *testing.T) {
 
 func TestCreateKey(t *testing.T) {
 	database.InitializeDatabase()
-	createNet()
+	createNet(t)
 	keys, _ := logic.GetKeys("skynet")
 	for _, key := range keys {
 		logic.DeleteKey(key.Name, "skynet")
@@ -148,8 +164,8 @@ func TestCreateKey(t *testing.T) {
 
 func TestGetKeys(t *testing.T) {
 	database.InitializeDatabase()
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	network, err := logic.GetNetwork("skynet")
 	assert.Nil(t, err)
 	var key models.AccessKey
@@ -171,7 +187,7 @@ func TestGetKeys(t *testing.T) {
 }
 func TestDeleteKey(t *testing.T) {
 	database.InitializeDatabase()
-	createNet()
+	createNet(t)
 	network, err := logic.GetNetwork("skynet")
 	assert.Nil(t, err)
 	var key models.AccessKey
@@ -332,21 +348,20 @@ func TestValidateNetworkUpdate(t *testing.T) {
 	}
 }
 
-func deleteAllNetworks() {
-	deleteAllNodes()
-	nets, _ := logic.GetNetworks()
-	for _, net := range nets {
-		logic.DeleteNetwork(net.NetID)
-	}
+func deleteAllNetworks(t *testing.T) {
+	deleteAllNodes(t)
+	err := database.DeleteAllRecords("networks")
+	assert.Nil(t, err)
 }
 
-func createNet() {
+func createNet(t *testing.T) {
 	var network models.Network
 	network.NetID = "skynet"
 	network.AddressRange = "10.0.0.1/24"
 	network.DisplayName = "mynetwork"
 	_, err := logic.GetNetwork("skynet")
 	if err != nil {
-		logic.CreateNetwork(network)
+		err := logic.CreateNetwork(network)
+		assert.Nil(t, err)
 	}
 }

+ 11 - 13
controllers/node_test.go

@@ -14,8 +14,8 @@ func TestCreateEgressGateway(t *testing.T) {
 	gateway.Interface = "eth0"
 	gateway.Ranges = []string{"10.100.100.0/24"}
 	database.InitializeDatabase()
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("NoNodes", func(t *testing.T) {
 		node, err := logic.CreateEgressGateway(gateway)
 		assert.Equal(t, models.Node{}, node)
@@ -36,8 +36,8 @@ func TestCreateEgressGateway(t *testing.T) {
 func TestDeleteEgressGateway(t *testing.T) {
 	var gateway models.EgressGatewayRequest
 	database.InitializeDatabase()
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	createTestNode()
 	testnode := createTestNode()
 	gateway.Interface = "eth0"
@@ -79,8 +79,8 @@ func TestDeleteEgressGateway(t *testing.T) {
 
 func TestGetNetworkNodes(t *testing.T) {
 	database.InitializeDatabase()
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	t.Run("BadNet", func(t *testing.T) {
 		node, err := logic.GetNetworkNodes("badnet")
 		assert.Nil(t, err)
@@ -102,8 +102,8 @@ func TestGetNetworkNodes(t *testing.T) {
 }
 func TestUncordonNode(t *testing.T) {
 	database.InitializeDatabase()
-	deleteAllNetworks()
-	createNet()
+	deleteAllNetworks(t)
+	createNet(t)
 	node := createTestNode()
 	t.Run("BadNet", func(t *testing.T) {
 		resp, err := logic.UncordonNode("badnet", node.MacAddress)
@@ -144,11 +144,9 @@ func TestValidateEgressGateway(t *testing.T) {
 	})
 }
 
-func deleteAllNodes() {
-	nodes, _ := logic.GetAllNodes()
-	for _, node := range nodes {
-		logic.DeleteNode(&node, true)
-	}
+func deleteAllNodes(t *testing.T) {
+	err := database.DeleteAllRecords("nodes")
+	assert.Nil(t, err)
 }
 
 func createTestNode() *models.Node {

+ 5 - 5
logic/gateway.go

@@ -44,10 +44,10 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
 			postDownCmd = node.PostDown + "; " + postDownCmd
 		}
 	}
-	key, err := GetRecordKey(gateway.NodeID, gateway.NetID)
-	if err != nil {
-		return node, err
-	}
+	//key, err := GetRecordKey(gateway.NodeID, gateway.NetID)
+	//if err != nil {
+	//	return node, err
+	//}
 	node.PostUp = postUpCmd
 	node.PostDown = postDownCmd
 	node.SetLastModified()
@@ -56,7 +56,7 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
 	if err != nil {
 		return node, err
 	}
-	if err = database.Insert(key, string(nodeData), database.NODES_TABLE_NAME); err != nil {
+	if err = database.Insert(node.ID, string(nodeData), database.NODES_TABLE_NAME); err != nil {
 		return models.Node{}, err
 	}
 	if err = NetworkNodesUpdatePullChanges(node.Network); err != nil {

+ 12 - 14
logic/util.go

@@ -44,29 +44,26 @@ func SetNetworkServerPeers(node *models.Node) {
 // DeleteNode - deletes a node from database or moves into delete nodes table
 func DeleteNode(node *models.Node, exterminate bool) error {
 	var err error
-	node.SetID()
-	var key = node.ID
 	if !exterminate {
-		args := strings.Split(key, "###")
-		node, err := GetNode(args[0], args[1])
-		if err != nil {
-			return err
-		}
 		node.Action = models.NODE_DELETE
 		nodedata, err := json.Marshal(&node)
 		if err != nil {
 			return err
 		}
-		err = database.Insert(key, string(nodedata), database.DELETED_NODES_TABLE_NAME)
+		err = database.Insert(node.ID, string(nodedata), database.DELETED_NODES_TABLE_NAME)
 		if err != nil {
 			return err
 		}
 	} else {
-		if err := database.DeleteRecord(database.DELETED_NODES_TABLE_NAME, key); err != nil {
+		if err := database.DeleteRecord(database.DELETED_NODES_TABLE_NAME, node.ID); err != nil {
 			logger.Log(2, err.Error())
 		}
 	}
-	if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil {
+	if err = database.DeleteRecord(database.NODES_TABLE_NAME, node.ID); err != nil {
+		return err
+	}
+	macnet := node.MacAddress + "###" + node.Network
+	if err = database.DeleteRecord(database.UUID_MAP_TABLE_NAME, macnet); err != nil {
 		return err
 	}
 	if servercfg.IsDNSMode() {
@@ -118,18 +115,19 @@ func CreateNode(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-	uuid, err := GetUUID(key)
+	node.SetID()
+	nodebytes, err := json.Marshal(&node)
 	if err != nil {
 		return err
 	}
-	nodebytes, err := json.Marshal(&node)
+	uuidbytes, err := json.Marshal(&node.ID)
 	if err != nil {
 		return err
 	}
-	if err = database.Insert(key, uuid, database.UUID_MAP_TABLE_NAME); err != nil {
+	if err = database.Insert(key, string(uuidbytes), database.UUID_MAP_TABLE_NAME); err != nil {
 		return err
 	}
-	if err = database.Insert(uuid, string(nodebytes), database.NODES_TABLE_NAME); err != nil {
+	if err = database.Insert(node.ID, string(nodebytes), database.NODES_TABLE_NAME); err != nil {
 		return err
 	}
 	if node.IsPending != "yes" {

+ 7 - 1
netclient/ncutils/netclientutils.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"log"
 	"math/rand"
 	"net"
@@ -385,7 +386,7 @@ func RunCmds(commands []string, printerr bool) error {
 // FileExists - checks if file exists locally
 func FileExists(f string) bool {
 	info, err := os.Stat(f)
-	if os.IsNotExist(err) {
+	if errors.Is(err, fs.ErrNotExist) {
 		return false
 	}
 	if err != nil && strings.Contains(err.Error(), "not a directory") {
@@ -394,6 +395,11 @@ func FileExists(f string) bool {
 	if err != nil {
 		Log("error reading file: " + f + ", " + err.Error())
 	}
+	//needed to prevent panic accessing info.IsDir if true
+	if errors.Is(err, fs.ErrPermission) {
+		Log("error reading file: " + f + " " + err.Error())
+		return false
+	}
 	return !info.IsDir()
 }