Browse Source

refactor validation for node creation

Matthew R Kasun 4 years ago
parent
commit
646f613b93
5 changed files with 375 additions and 255 deletions
  1. 2 48
      controllers/common.go
  2. 144 23
      controllers/common_test.go
  3. 173 183
      controllers/nodeGrpcController.go
  4. 1 1
      controllers/nodeHttpController.go
  5. 55 0
      models/node.go

+ 2 - 48
controllers/common.go

@@ -59,61 +59,15 @@ func GetPeersList(networkName string) ([]models.PeersResponse, error) {
 }
 }
 
 
 func ValidateNodeCreate(networkName string, node models.Node) error {
 func ValidateNodeCreate(networkName string, node models.Node) error {
-
 	v := validator.New()
 	v := validator.New()
-	_ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool {
-		isIpv4 := functions.IsIpNet(node.Address)
-		empty := node.Address == ""
-		return (empty || isIpv4)
-	})
-	_ = v.RegisterValidation("address6_check", func(fl validator.FieldLevel) bool {
-		isIpv6 := functions.IsIpNet(node.Address6)
-		empty := node.Address6 == ""
-		return (empty || isIpv6)
-	})
-	_ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool {
-		//var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint)
-		isIp := functions.IsIpNet(node.Endpoint)
-		notEmptyCheck := node.Endpoint != ""
-		return (notEmptyCheck && isIp)
-	})
-	_ = v.RegisterValidation("localaddress_check", func(fl validator.FieldLevel) bool {
-		//var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint)
-		isIp := functions.IsIpNet(node.LocalAddress)
-		empty := node.LocalAddress == ""
-		return (empty || isIp)
-	})
-
 	_ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool {
 	_ = v.RegisterValidation("macaddress_unique", func(fl validator.FieldLevel) bool {
 		var isFieldUnique bool = functions.IsFieldUnique(networkName, "macaddress", node.MacAddress)
 		var isFieldUnique bool = functions.IsFieldUnique(networkName, "macaddress", node.MacAddress)
 		return isFieldUnique
 		return isFieldUnique
 	})
 	})
-
-	_ = v.RegisterValidation("macaddress_valid", func(fl validator.FieldLevel) bool {
-		_, err := net.ParseMAC(node.MacAddress)
-		return err == nil
-	})
-
-	_ = v.RegisterValidation("name_valid", func(fl validator.FieldLevel) bool {
-		isvalid := functions.NameInNodeCharSet(node.Name)
-		return isvalid
-	})
-
 	_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
 	_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
 		_, err := node.GetNetwork()
 		_, err := node.GetNetwork()
 		return err == nil
 		return err == nil
 	})
 	})
-	_ = v.RegisterValidation("pubkey_check", func(fl validator.FieldLevel) bool {
-		notEmptyCheck := node.PublicKey != ""
-		isBase64 := functions.IsBase64(node.PublicKey)
-		return (notEmptyCheck && isBase64)
-	})
-	_ = v.RegisterValidation("password_check", func(fl validator.FieldLevel) bool {
-		notEmptyCheck := node.Password != ""
-		goodLength := len(node.Password) > 5
-		return (notEmptyCheck && goodLength)
-	})
-
 	err := v.Struct(node)
 	err := v.Struct(node)
 
 
 	if err != nil {
 	if err != nil {
@@ -124,7 +78,7 @@ func ValidateNodeCreate(networkName string, node models.Node) error {
 	return err
 	return err
 }
 }
 
 
-func ValidateNodeUpdate(networkName string, node models.Node) error {
+func ValidateNodeUpdate(networkName string, node models.NodeUpdate) error {
 
 
 	v := validator.New()
 	v := validator.New()
 	_ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool {
 	_ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool {
@@ -188,7 +142,7 @@ func ValidateNodeUpdate(networkName string, node models.Node) error {
 	return err
 	return err
 }
 }
 
 
-func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) {
+func UpdateNode(nodechange models.NodeUpdate, node models.Node) (models.Node, error) {
 	//Question: Is there a better way  of doing  this than a bunch of "if" statements? probably...
 	//Question: Is there a better way  of doing  this than a bunch of "if" statements? probably...
 	//Eventually, lets have a better way to check if any of the fields are filled out...
 	//Eventually, lets have a better way to check if any of the fields are filled out...
 	queryMac := node.MacAddress
 	queryMac := node.MacAddress

+ 144 - 23
controllers/common_test.go

@@ -13,6 +13,12 @@ type NodeValidationTC struct {
 	errorMessage string
 	errorMessage string
 }
 }
 
 
+type NodeValidationUpdateTC struct {
+	testname     string
+	node         models.NodeUpdate
+	errorMessage string
+}
+
 func TestCreateNode(t *testing.T) {
 func TestCreateNode(t *testing.T) {
 }
 }
 func TestDeleteNode(t *testing.T) {
 func TestDeleteNode(t *testing.T) {
@@ -43,28 +49,28 @@ func TestValidateNodeCreate(t *testing.T) {
 			node: models.Node{
 			node: models.Node{
 				Address: "256.0.0.1",
 				Address: "256.0.0.1",
 			},
 			},
-			errorMessage: "Field validation for 'Address' failed on the 'address_check' tag",
+			errorMessage: "Field validation for 'Address' failed on the 'ipv4' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "BadAddress6",
 			testname: "BadAddress6",
 			node: models.Node{
 			node: models.Node{
 				Address6: "2607::abcd:efgh::1",
 				Address6: "2607::abcd:efgh::1",
 			},
 			},
-			errorMessage: "Field validation for 'Address6' failed on the 'address6_check' tag",
+			errorMessage: "Field validation for 'Address6' failed on the 'ipv6' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "BadLocalAddress",
 			testname: "BadLocalAddress",
 			node: models.Node{
 			node: models.Node{
 				LocalAddress: "10.0.200.300",
 				LocalAddress: "10.0.200.300",
 			},
 			},
-			errorMessage: "Field validation for 'LocalAddress' failed on the 'localaddress_check' tag",
+			errorMessage: "Field validation for 'LocalAddress' failed on the 'ip' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "InvalidName",
 			testname: "InvalidName",
 			node: models.Node{
 			node: models.Node{
 				Name: "mynode*",
 				Name: "mynode*",
 			},
 			},
-			errorMessage: "Field validation for 'Name' failed on the 'name_valid' tag",
+			errorMessage: "Field validation for 'Name' failed on the 'alphanum' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "NameTooLong",
 			testname: "NameTooLong",
@@ -88,18 +94,32 @@ func TestValidateNodeCreate(t *testing.T) {
 			errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag",
 			errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
-			testname: "PublicKeyInvalid",
+			testname: "PublicKeyEmpty",
 			node: models.Node{
 			node: models.Node{
 				PublicKey: "",
 				PublicKey: "",
 			},
 			},
-			errorMessage: "Field validation for 'PublicKey' failed on the 'pubkey_check' tag",
+			errorMessage: "Field validation for 'PublicKey' failed on the 'required' tag",
+		},
+		NodeValidationTC{
+			testname: "PublicKeyInvalid",
+			node: models.Node{
+				PublicKey: "junk%key",
+			},
+			errorMessage: "Field validation for 'PublicKey' failed on the 'base64' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "EndpointInvalid",
 			testname: "EndpointInvalid",
 			node: models.Node{
 			node: models.Node{
 				Endpoint: "10.2.0.300",
 				Endpoint: "10.2.0.300",
 			},
 			},
-			errorMessage: "Field validation for 'Endpoint' failed on the 'endpoint_check' tag",
+			errorMessage: "Field validation for 'Endpoint' failed on the 'ip' tag",
+		},
+		NodeValidationTC{
+			testname: "EndpointEmpty",
+			node: models.Node{
+				Endpoint: "",
+			},
+			errorMessage: "Field validation for 'Endpoint' failed on the 'required' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "PersistentKeepaliveMax",
 			testname: "PersistentKeepaliveMax",
@@ -113,7 +133,7 @@ func TestValidateNodeCreate(t *testing.T) {
 			node: models.Node{
 			node: models.Node{
 				MacAddress: "01:02:03:04:05",
 				MacAddress: "01:02:03:04:05",
 			},
 			},
-			errorMessage: "Field validation for 'MacAddress' failed on the 'macaddress_valid' tag",
+			errorMessage: "Field validation for 'MacAddress' failed on the 'mac' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "MacAddressMissing",
 			testname: "MacAddressMissing",
@@ -127,14 +147,14 @@ func TestValidateNodeCreate(t *testing.T) {
 			node: models.Node{
 			node: models.Node{
 				Password: "",
 				Password: "",
 			},
 			},
-			errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
+			errorMessage: "Field validation for 'Password' failed on the 'required' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "ShortPassword",
 			testname: "ShortPassword",
 			node: models.Node{
 			node: models.Node{
 				Password: "1234",
 				Password: "1234",
 			},
 			},
-			errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
+			errorMessage: "Field validation for 'Password' failed on the 'min' tag",
 		},
 		},
 		NodeValidationTC{
 		NodeValidationTC{
 			testname: "NoNetwork",
 			testname: "NoNetwork",
@@ -170,18 +190,119 @@ func TestValidateNodeCreate(t *testing.T) {
 }
 }
 func TestValidateNodeUpdate(t *testing.T) {
 func TestValidateNodeUpdate(t *testing.T) {
 	//cases
 	//cases
-	t.Run("BlankAddress", func(t *testing.T) {
-	})
-	t.Run("BlankAddress6", func(t *testing.T) {
-	})
-	t.Run("Blank", func(t *testing.T) {
-	})
+	cases := []NodeValidationUpdateTC{
+		NodeValidationUpdateTC{
+			testname: "BadAddress",
+			node: models.NodeUpdate{
+				Address: "256.0.0.1",
+			},
+			errorMessage: "Field validation for 'Address' failed on the 'address_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "BadAddress6",
+			node: models.NodeUpdate{
+				Address6: "2607::abcd:efgh::1",
+			},
+			errorMessage: "Field validation for 'Address6' failed on the 'address6_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "BadLocalAddress",
+			node: models.NodeUpdate{
+				LocalAddress: "10.0.200.300",
+			},
+			errorMessage: "Field validation for 'LocalAddress' failed on the 'localaddress_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "InvalidName",
+			node: models.NodeUpdate{
+				Name: "mynode*",
+			},
+			errorMessage: "Field validation for 'Name' failed on the 'name_valid' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "NameTooLong",
+			node: models.NodeUpdate{
+				Name: "mynodexmynode",
+			},
+			errorMessage: "Field validation for 'Name' failed on the 'max' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "ListenPortMin",
+			node: models.NodeUpdate{
+				ListenPort: 1023,
+			},
+			errorMessage: "Field validation for 'ListenPort' failed on the 'min' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "ListenPortMax",
+			node: models.NodeUpdate{
+				ListenPort: 65536,
+			},
+			errorMessage: "Field validation for 'ListenPort' failed on the 'max' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "PublicKeyInvalid",
+			node: models.NodeUpdate{
+				PublicKey: "",
+			},
+			errorMessage: "Field validation for 'PublicKey' failed on the 'pubkey_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "EndpointInvalid",
+			node: models.NodeUpdate{
+				Endpoint: "10.2.0.300",
+			},
+			errorMessage: "Field validation for 'Endpoint' failed on the 'endpoint_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "PersistentKeepaliveMax",
+			node: models.NodeUpdate{
+				PersistentKeepalive: 1001,
+			},
+			errorMessage: "Field validation for 'PersistentKeepalive' failed on the 'max' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "MacAddressInvalid",
+			node: models.NodeUpdate{
+				MacAddress: "01:02:03:04:05",
+			},
+			errorMessage: "Field validation for 'MacAddress' failed on the 'macaddress_valid' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "MacAddressMissing",
+			node: models.NodeUpdate{
+				MacAddress: "",
+			},
+			errorMessage: "Field validation for 'MacAddress' failed on the 'required' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "EmptyPassword",
+			node: models.NodeUpdate{
+				Password: "",
+			},
+			errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "ShortPassword",
+			node: models.NodeUpdate{
+				Password: "1234",
+			},
+			errorMessage: "Field validation for 'Password' failed on the 'password_check' tag",
+		},
+		NodeValidationUpdateTC{
+			testname: "NoNetwork",
+			node: models.NodeUpdate{
+				Network: "badnet",
+			},
+			errorMessage: "Field validation for 'Network' failed on the 'network_exists' tag",
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.testname, func(t *testing.T) {
+			err := ValidateNodeUpdate("skynet", tc.node)
+			assert.NotNil(t, err)
+			assert.Contains(t, err.Error(), tc.errorMessage)
+		})
+	}
 
 
-	//	for _, tc := range cases {
-	//		t.Run(tc.testname, func(t *testing.T) {
-	//			err := ValidateNodeUpdate(tc.node)
-	//			assert.NotNil(t, err)
-	//			assert.Contains(t, err.Error(), tc.errorMessage)
-	//		})
-	//  }
 }
 }

+ 173 - 183
controllers/nodeGrpcController.go

@@ -1,12 +1,13 @@
 package controller
 package controller
 
 
 import (
 import (
-        "context"
+	"context"
 	"fmt"
 	"fmt"
 	"strconv"
 	"strconv"
+
+	"github.com/gravitl/netmaker/functions"
 	nodepb "github.com/gravitl/netmaker/grpc"
 	nodepb "github.com/gravitl/netmaker/grpc"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
-	"github.com/gravitl/netmaker/functions"
 	"go.mongodb.org/mongo-driver/mongo"
 	"go.mongodb.org/mongo-driver/mongo"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 	"google.golang.org/grpc/status"
@@ -15,12 +16,12 @@ import (
 type NodeServiceServer struct {
 type NodeServiceServer struct {
 	NodeDB *mongo.Collection
 	NodeDB *mongo.Collection
 	nodepb.UnimplementedNodeServiceServer
 	nodepb.UnimplementedNodeServiceServer
-
 }
 }
+
 func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeReq) (*nodepb.ReadNodeRes, error) {
 func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeReq) (*nodepb.ReadNodeRes, error) {
 	// convert string id (from proto) to mongoDB ObjectId
 	// convert string id (from proto) to mongoDB ObjectId
 	macaddress := req.GetMacaddress()
 	macaddress := req.GetMacaddress()
-        networkName := req.GetNetwork()
+	networkName := req.GetNetwork()
 	network, _ := functions.GetParentNetwork(networkName)
 	network, _ := functions.GetParentNetwork(networkName)
 
 
 	node, err := GetNode(macaddress, networkName)
 	node, err := GetNode(macaddress, networkName)
@@ -30,31 +31,30 @@ func (s *NodeServiceServer) ReadNode(ctx context.Context, req *nodepb.ReadNodeRe
 	}
 	}
 
 
 	/*
 	/*
-	if node == nil {
-		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find node with Mac Address %s: %v", req.GetMacaddress(), err))
-	}
+		if node == nil {
+			return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find node with Mac Address %s: %v", req.GetMacaddress(), err))
+		}
 	*/
 	*/
 	// Cast to ReadNodeRes type
 	// Cast to ReadNodeRes type
 	response := &nodepb.ReadNodeRes{
 	response := &nodepb.ReadNodeRes{
 		Node: &nodepb.Node{
 		Node: &nodepb.Node{
-			Macaddress: node.MacAddress,
-			Name:    node.Name,
-			Address:  node.Address,
-			Endpoint:  node.Endpoint,
-			Password:  node.Password,
-			Nodenetwork:  node.Network,
-			Interface:  node.Interface,
-			Localaddress:  node.LocalAddress,
-			Postdown:  node.PostDown,
-			Postup:  node.PostUp,
-			Checkininterval:  node.CheckInInterval,
-			Ispending:  node.IsPending,
-			Publickey:  node.PublicKey,
-			Listenport:  node.ListenPort,
-			Keepalive:  node.PersistentKeepalive,
-                        Islocal:  *network.IsLocal,
-                        Localrange:  network.LocalRange,
-
+			Macaddress:      node.MacAddress,
+			Name:            node.Name,
+			Address:         node.Address,
+			Endpoint:        node.Endpoint,
+			Password:        node.Password,
+			Nodenetwork:     node.Network,
+			Interface:       node.Interface,
+			Localaddress:    node.LocalAddress,
+			Postdown:        node.PostDown,
+			Postup:          node.PostUp,
+			Checkininterval: node.CheckInInterval,
+			Ispending:       node.IsPending,
+			Publickey:       node.PublicKey,
+			Listenport:      node.ListenPort,
+			Keepalive:       node.PersistentKeepalive,
+			Islocal:         *network.IsLocal,
+			Localrange:      network.LocalRange,
 		},
 		},
 	}
 	}
 	return response, nil
 	return response, nil
@@ -67,54 +67,52 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.CreateNo
 	// Now we have to convert this into a NodeItem type to convert into BSON
 	// Now we have to convert this into a NodeItem type to convert into BSON
 	node := models.Node{
 	node := models.Node{
 		// ID:       primitive.NilObjectID,
 		// ID:       primitive.NilObjectID,
-                        MacAddress: data.GetMacaddress(),
-                        LocalAddress: data.GetLocaladdress(),
-                        Name:    data.GetName(),
-                        Address:  data.GetAddress(),
-                        AccessKey:  data.GetAccesskey(),
-                        Endpoint:  data.GetEndpoint(),
-                        PersistentKeepalive:  data.GetKeepalive(),
-                        Password:  data.GetPassword(),
-                        Interface:  data.GetInterface(),
-                        Network:  data.GetNodenetwork(),
-                        IsPending:  data.GetIspending(),
-                        PublicKey:  data.GetPublickey(),
-                        ListenPort:  data.GetListenport(),
+		MacAddress:          data.GetMacaddress(),
+		LocalAddress:        data.GetLocaladdress(),
+		Name:                data.GetName(),
+		Address:             data.GetAddress(),
+		AccessKey:           data.GetAccesskey(),
+		Endpoint:            data.GetEndpoint(),
+		PersistentKeepalive: data.GetKeepalive(),
+		Password:            data.GetPassword(),
+		Interface:           data.GetInterface(),
+		Network:             data.GetNodenetwork(),
+		IsPending:           data.GetIspending(),
+		PublicKey:           data.GetPublickey(),
+		ListenPort:          data.GetListenport(),
 	}
 	}
 
 
-        err := ValidateNodeCreate(node.Network, node)
+	err := ValidateNodeCreate(node.Network, node)
 
 
-        if err != nil {
-                // return internal gRPC error to be handled later
-                return nil, err
-        }
+	if err != nil {
+		// return internal gRPC error to be handled later
+		return nil, err
+	}
 
 
-        //Check to see if key is valid
-        //TODO: Triple inefficient!!! This is the third call to the DB we make for networks
-        validKey := functions.IsKeyValid(node.Network, node.AccessKey)
-        network, err := functions.GetParentNetwork(node.Network)
-        if err != nil {
-                return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find network: %v", err))
-        } else {
+	//Check to see if key is valid
+	//TODO: Triple inefficient!!! This is the third call to the DB we make for networks
+	validKey := functions.IsKeyValid(node.Network, node.AccessKey)
+	network, err := functions.GetParentNetwork(node.Network)
+	if err != nil {
+		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not find network: %v", err))
+	} else {
 		fmt.Println("Creating node in network " + network.NetID)
 		fmt.Println("Creating node in network " + network.NetID)
 		fmt.Println("Network is local? " + strconv.FormatBool(*network.IsLocal))
 		fmt.Println("Network is local? " + strconv.FormatBool(*network.IsLocal))
 		fmt.Println("Range if local: " + network.LocalRange)
 		fmt.Println("Range if local: " + network.LocalRange)
 	}
 	}
 
 
-
-
-        if !validKey {
-                //Check to see if network will allow manual sign up
-                //may want to switch this up with the valid key check and avoid a DB call that way.
-                if *network.AllowManualSignUp {
-                        node.IsPending = true
-                } else  {
-	                return nil, status.Errorf(
-		                codes.Internal,
+	if !validKey {
+		//Check to see if network will allow manual sign up
+		//may want to switch this up with the valid key check and avoid a DB call that way.
+		if *network.AllowManualSignUp {
+			node.IsPending = true
+		} else {
+			return nil, status.Errorf(
+				codes.Internal,
 				fmt.Sprintf("Invalid key, and network does not allow no-key signups"),
 				fmt.Sprintf("Invalid key, and network does not allow no-key signups"),
 			)
 			)
-                }
-        }
+		}
+	}
 
 
 	node, err = CreateNode(node, node.Network)
 	node, err = CreateNode(node, node.Network)
 
 
@@ -128,118 +126,114 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.CreateNo
 	// return the node in a CreateNodeRes type
 	// return the node in a CreateNodeRes type
 	response := &nodepb.CreateNodeRes{
 	response := &nodepb.CreateNodeRes{
 		Node: &nodepb.Node{
 		Node: &nodepb.Node{
-                        Macaddress: node.MacAddress,
-                        Localaddress: node.LocalAddress,
-                        Name:    node.Name,
-                        Address:  node.Address,
-                        Endpoint:  node.Endpoint,
-                        Password:  node.Password,
-                        Interface:  node.Interface,
-                        Nodenetwork:  node.Network,
-                        Ispending:  node.IsPending,
-                        Publickey:  node.PublicKey,
-                        Listenport:  node.ListenPort,
-                        Keepalive:  node.PersistentKeepalive,
-                        Islocal:  *network.IsLocal,
-                        Localrange:  network.LocalRange,
+			Macaddress:   node.MacAddress,
+			Localaddress: node.LocalAddress,
+			Name:         node.Name,
+			Address:      node.Address,
+			Endpoint:     node.Endpoint,
+			Password:     node.Password,
+			Interface:    node.Interface,
+			Nodenetwork:  node.Network,
+			Ispending:    node.IsPending,
+			Publickey:    node.PublicKey,
+			Listenport:   node.ListenPort,
+			Keepalive:    node.PersistentKeepalive,
+			Islocal:      *network.IsLocal,
+			Localrange:   network.LocalRange,
 		},
 		},
 	}
 	}
-        err = SetNetworkNodesLastModified(node.Network)
-        if err != nil {
-                return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
-        }
+	err = SetNetworkNodesLastModified(node.Network)
+	if err != nil {
+		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
+	}
 
 
 	return response, nil
 	return response, nil
 }
 }
 
 
 func (s *NodeServiceServer) CheckIn(ctx context.Context, req *nodepb.CheckInReq) (*nodepb.CheckInRes, error) {
 func (s *NodeServiceServer) CheckIn(ctx context.Context, req *nodepb.CheckInReq) (*nodepb.CheckInRes, error) {
 	// Get the protobuf node type from the protobuf request type
 	// Get the protobuf node type from the protobuf request type
-        // Essentially doing req.Node to access the struct with a nil check
+	// Essentially doing req.Node to access the struct with a nil check
 	data := req.GetNode()
 	data := req.GetNode()
 	//postchanges := req.GetPostchanges()
 	//postchanges := req.GetPostchanges()
 	// Now we have to convert this into a NodeItem type to convert into BSON
 	// Now we have to convert this into a NodeItem type to convert into BSON
-        node := models.Node{
-                // ID:       primitive.NilObjectID,
-                        MacAddress: data.GetMacaddress(),
-                        Address:  data.GetAddress(),
-                        Endpoint:  data.GetEndpoint(),
-                        Network:  data.GetNodenetwork(),
-                        Password:  data.GetPassword(),
-                        LocalAddress:  data.GetLocaladdress(),
-                        ListenPort:  data.GetListenport(),
-                        PersistentKeepalive:  data.GetKeepalive(),
-                        PublicKey:  data.GetPublickey(),
-        }
+	node := models.Node{
+		// ID:       primitive.NilObjectID,
+		MacAddress:          data.GetMacaddress(),
+		Address:             data.GetAddress(),
+		Endpoint:            data.GetEndpoint(),
+		Network:             data.GetNodenetwork(),
+		Password:            data.GetPassword(),
+		LocalAddress:        data.GetLocaladdress(),
+		ListenPort:          data.GetListenport(),
+		PersistentKeepalive: data.GetKeepalive(),
+		PublicKey:           data.GetPublickey(),
+	}
 
 
 	checkinresponse, err := NodeCheckIn(node, node.Network)
 	checkinresponse, err := NodeCheckIn(node, node.Network)
 
 
-        if err != nil {
-                // return internal gRPC error to be handled later
+	if err != nil {
+		// return internal gRPC error to be handled later
 		if checkinresponse == (models.CheckInResponse{}) || !checkinresponse.IsPending {
 		if checkinresponse == (models.CheckInResponse{}) || !checkinresponse.IsPending {
-                return nil, status.Errorf(
-                        codes.Internal,
-                        fmt.Sprintf("Internal error: %v", err),
-                )
+			return nil, status.Errorf(
+				codes.Internal,
+				fmt.Sprintf("Internal error: %v", err),
+			)
 		}
 		}
-        }
-        // return the node in a CreateNodeRes type
-        response := &nodepb.CheckInRes{
-                Checkinresponse: &nodepb.CheckInResponse{
-                        Success:  checkinresponse.Success,
-                        Needpeerupdate:  checkinresponse.NeedPeerUpdate,
-                        Needdelete:  checkinresponse.NeedDelete,
-                        Needconfigupdate:  checkinresponse.NeedConfigUpdate,
-                        Needkeyupdate:  checkinresponse.NeedKeyUpdate,
-                        Nodemessage:  checkinresponse.NodeMessage,
-                        Ispending:  checkinresponse.IsPending,
-                },
-        }
-        return response, nil
+	}
+	// return the node in a CreateNodeRes type
+	response := &nodepb.CheckInRes{
+		Checkinresponse: &nodepb.CheckInResponse{
+			Success:          checkinresponse.Success,
+			Needpeerupdate:   checkinresponse.NeedPeerUpdate,
+			Needdelete:       checkinresponse.NeedDelete,
+			Needconfigupdate: checkinresponse.NeedConfigUpdate,
+			Needkeyupdate:    checkinresponse.NeedKeyUpdate,
+			Nodemessage:      checkinresponse.NodeMessage,
+			Ispending:        checkinresponse.IsPending,
+		},
+	}
+	return response, nil
 }
 }
 
 
-
 func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNodeReq) (*nodepb.UpdateNodeRes, error) {
 func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNodeReq) (*nodepb.UpdateNodeRes, error) {
 	// Get the node data from the request
 	// Get the node data from the request
-        data := req.GetNode()
-        // Now we have to convert this into a NodeItem type to convert into BSON
-        nodechange := models.Node{
-                // ID:       primitive.NilObjectID,
-                        MacAddress: data.GetMacaddress(),
-                        Name:    data.GetName(),
-                        Address:  data.GetAddress(),
-                        LocalAddress:  data.GetLocaladdress(),
-                        Endpoint:  data.GetEndpoint(),
-                        Password:  data.GetPassword(),
-                        PersistentKeepalive:  data.GetKeepalive(),
-                        Network:  data.GetNodenetwork(),
-                        Interface:  data.GetInterface(),
-                        PostDown:  data.GetPostdown(),
-                        PostUp:  data.GetPostup(),
-                        IsPending:  data.GetIspending(),
-                        PublicKey:  data.GetPublickey(),
-                        ListenPort:  data.GetListenport(),
-        }
-
+	data := req.GetNode()
+	// Now we have to convert this into a NodeItem type to convert into BSON
+	nodechange := models.NodeUpdate{
+		// ID:       primitive.NilObjectID,
+		MacAddress:          data.GetMacaddress(),
+		Name:                data.GetName(),
+		Address:             data.GetAddress(),
+		LocalAddress:        data.GetLocaladdress(),
+		Endpoint:            data.GetEndpoint(),
+		Password:            data.GetPassword(),
+		PersistentKeepalive: data.GetKeepalive(),
+		Network:             data.GetNodenetwork(),
+		Interface:           data.GetInterface(),
+		PostDown:            data.GetPostdown(),
+		PostUp:              data.GetPostup(),
+		IsPending:           data.GetIspending(),
+		PublicKey:           data.GetPublickey(),
+		ListenPort:          data.GetListenport(),
+	}
 
 
 	// Convert the Id string to a MongoDB ObjectId
 	// Convert the Id string to a MongoDB ObjectId
 	macaddress := nodechange.MacAddress
 	macaddress := nodechange.MacAddress
 	networkName := nodechange.Network
 	networkName := nodechange.Network
-        network, _ := functions.GetParentNetwork(networkName)
-
+	network, _ := functions.GetParentNetwork(networkName)
 
 
 	err := ValidateNodeUpdate(networkName, nodechange)
 	err := ValidateNodeUpdate(networkName, nodechange)
-        if err != nil {
-                return nil, err
-        }
-
-        node, err := functions.GetNodeByMacAddress(networkName, macaddress)
-        if err != nil {
-               return nil, status.Errorf(
-                        codes.NotFound,
-                        fmt.Sprintf("Could not find node with supplied Mac Address: %v", err),
-                )
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
+	node, err := functions.GetNodeByMacAddress(networkName, macaddress)
+	if err != nil {
+		return nil, status.Errorf(
+			codes.NotFound,
+			fmt.Sprintf("Could not find node with supplied Mac Address: %v", err),
+		)
+	}
 
 
 	newnode, err := UpdateNode(nodechange, node)
 	newnode, err := UpdateNode(nodechange, node)
 
 
@@ -251,23 +245,22 @@ func (s *NodeServiceServer) UpdateNode(ctx context.Context, req *nodepb.UpdateNo
 	}
 	}
 	return &nodepb.UpdateNodeRes{
 	return &nodepb.UpdateNodeRes{
 		Node: &nodepb.Node{
 		Node: &nodepb.Node{
-                        Macaddress: newnode.MacAddress,
-                        Localaddress: newnode.LocalAddress,
-                        Name:    newnode.Name,
-                        Address:  newnode.Address,
-                        Endpoint:  newnode.Endpoint,
-                        Password:  newnode.Password,
-                        Interface:  newnode.Interface,
-                        Postdown:  newnode.PostDown,
-                        Postup:  newnode.PostUp,
-                        Nodenetwork:  newnode.Network,
-                        Ispending:  newnode.IsPending,
-                        Publickey:  newnode.PublicKey,
-                        Listenport:  newnode.ListenPort,
-                        Keepalive:  newnode.PersistentKeepalive,
-                        Islocal:  *network.IsLocal,
-                        Localrange:  network.LocalRange,
-
+			Macaddress:   newnode.MacAddress,
+			Localaddress: newnode.LocalAddress,
+			Name:         newnode.Name,
+			Address:      newnode.Address,
+			Endpoint:     newnode.Endpoint,
+			Password:     newnode.Password,
+			Interface:    newnode.Interface,
+			Postdown:     newnode.PostDown,
+			Postup:       newnode.PostUp,
+			Nodenetwork:  newnode.Network,
+			Ispending:    newnode.IsPending,
+			Publickey:    newnode.PublicKey,
+			Listenport:   newnode.ListenPort,
+			Keepalive:    newnode.PersistentKeepalive,
+			Islocal:      *network.IsLocal,
+			Localrange:   network.LocalRange,
 		},
 		},
 	}, nil
 	}, nil
 }
 }
@@ -287,12 +280,11 @@ func (s *NodeServiceServer) DeleteNode(ctx context.Context, req *nodepb.DeleteNo
 
 
 	fmt.Println("updating network last modified of " + req.GetNetworkName())
 	fmt.Println("updating network last modified of " + req.GetNetworkName())
 	err = SetNetworkNodesLastModified(req.GetNetworkName())
 	err = SetNetworkNodesLastModified(req.GetNetworkName())
-        if err != nil {
+	if err != nil {
 		fmt.Println("Error updating Network")
 		fmt.Println("Error updating Network")
 		fmt.Println(err)
 		fmt.Println(err)
 		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
 		return nil, status.Errorf(codes.NotFound, fmt.Sprintf("Could not update network last modified date: %v", err))
-        }
-
+	}
 
 
 	return &nodepb.DeleteNodeRes{
 	return &nodepb.DeleteNodeRes{
 		Success: true,
 		Success: true,
@@ -310,34 +302,32 @@ func (s *NodeServiceServer) GetPeers(req *nodepb.GetPeersReq, stream nodepb.Node
 		return status.Errorf(codes.Internal, fmt.Sprintf("Unknown internal error: %v", err))
 		return status.Errorf(codes.Internal, fmt.Sprintf("Unknown internal error: %v", err))
 	}
 	}
 	// cursor.Next() returns a boolean, if false there are no more items and loop will break
 	// cursor.Next() returns a boolean, if false there are no more items and loop will break
-        for i := 0; i < len(peers); i++ {
+	for i := 0; i < len(peers); i++ {
 
 
 		// If no error is found send node over stream
 		// If no error is found send node over stream
 		stream.Send(&nodepb.GetPeersRes{
 		stream.Send(&nodepb.GetPeersRes{
 			Peers: &nodepb.PeersResponse{
 			Peers: &nodepb.PeersResponse{
-                            Address:  peers[i].Address,
-                            Endpoint:  peers[i].Endpoint,
-                            Gatewayrange:  peers[i].GatewayRange,
-                            Isgateway:  peers[i].IsGateway,
-                            Publickey:  peers[i].PublicKey,
-                            Keepalive:  peers[i].KeepAlive,
-                            Listenport:  peers[i].ListenPort,
-                            Localaddress:  peers[i].LocalAddress,
+				Address:      peers[i].Address,
+				Endpoint:     peers[i].Endpoint,
+				Gatewayrange: peers[i].GatewayRange,
+				Isgateway:    peers[i].IsGateway,
+				Publickey:    peers[i].PublicKey,
+				Keepalive:    peers[i].KeepAlive,
+				Listenport:   peers[i].ListenPort,
+				Localaddress: peers[i].LocalAddress,
 			},
 			},
 		})
 		})
 	}
 	}
 
 
 	node, err := functions.GetNodeByMacAddress(req.GetNetwork(), req.GetMacaddress())
 	node, err := functions.GetNodeByMacAddress(req.GetNetwork(), req.GetMacaddress())
-       if err != nil {
-                return status.Errorf(codes.Internal, fmt.Sprintf("Could not get node: %v", err))
-        }
-
+	if err != nil {
+		return status.Errorf(codes.Internal, fmt.Sprintf("Could not get node: %v", err))
+	}
 
 
 	err = TimestampNode(node, false, true, false)
 	err = TimestampNode(node, false, true, false)
-        if err != nil {
-                return status.Errorf(codes.Internal, fmt.Sprintf("Internal error occurred: %v", err))
-        }
-
+	if err != nil {
+		return status.Errorf(codes.Internal, fmt.Sprintf("Internal error occurred: %v", err))
+	}
 
 
 	return nil
 	return nil
 }
 }

+ 1 - 1
controllers/nodeHttpController.go

@@ -689,7 +689,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	var nodechange models.Node
+	var nodechange models.NodeUpdate
 
 
 	// we decode our body request params
 	// we decode our body request params
 	_ = json.NewDecoder(r.Body).Decode(&nodechange)
 	_ = json.NewDecoder(r.Body).Decode(&nodechange)

+ 55 - 0
models/node.go

@@ -18,6 +18,38 @@ var seededRand *rand.Rand = rand.New(
 
 
 //node struct
 //node struct
 type Node struct {
 type Node struct {
+	ID                  primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"`
+	Address             string             `json:"address" bson:"address" validate:"omitempty,ipv4"`
+	Address6            string             `json:"address6" bson:"address6" validate:"omitempty,ipv6"`
+	LocalAddress        string             `json:"localaddress" bson:"localaddress" validate:"omitempty,ip"`
+	Name                string             `json:"name" bson:"name" validate:"omitempty,alphanum,max=12"`
+	ListenPort          int32              `json:"listenport" bson:"listenport" validate:"omitempty,numeric,min=1024,max=65535"`
+	PublicKey           string             `json:"publickey" bson:"publickey" validate:"required,base64"`
+	Endpoint            string             `json:"endpoint" bson:"endpoint" validate:"required,ip"`
+	PostUp              string             `json:"postup" bson:"postup"`
+	PostDown            string             `json:"postdown" bson:"postdown"`
+	AllowedIPs          string             `json:"allowedips" bson:"allowedips"`
+	PersistentKeepalive int32              `json:"persistentkeepalive" bson:"persistentkeepalive" validate:"omitempty,numeric,max=1000"`
+	SaveConfig          *bool              `json:"saveconfig" bson:"saveconfig"`
+	AccessKey           string             `json:"accesskey" bson:"accesskey"`
+	Interface           string             `json:"interface" bson:"interface"`
+	LastModified        int64              `json:"lastmodified" bson:"lastmodified"`
+	KeyUpdateTimeStamp  int64              `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"`
+	ExpirationDateTime  int64              `json:"expdatetime" bson:"expdatetime"`
+	LastPeerUpdate      int64              `json:"lastpeerupdate" bson:"lastpeerupdate"`
+	LastCheckIn         int64              `json:"lastcheckin" bson:"lastcheckin"`
+	MacAddress          string             `json:"macaddress" bson:"macaddress" validate:"required,mac,macaddress_unique"`
+	CheckInInterval     int32              `json:"checkininterval" bson:"checkininterval"`
+	Password            string             `json:"password" bson:"password" validate:"required,min=6"`
+	Network             string             `json:"network" bson:"network" validate:"network_exists"`
+	IsPending           bool               `json:"ispending" bson:"ispending"`
+	IsGateway           bool               `json:"isgateway" bson:"isgateway"`
+	GatewayRange        string             `json:"gatewayrange" bson:"gatewayrange"`
+	PostChanges         string             `json:"postchanges" bson:"postchanges"`
+}
+
+//node update struct --- only validations are different
+type NodeUpdate struct {
 	ID                  primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"`
 	ID                  primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"`
 	Address             string             `json:"address" bson:"address" validate:"address_check"`
 	Address             string             `json:"address" bson:"address" validate:"address_check"`
 	Address6            string             `json:"address6" bson:"address6" validate:"address6_check"`
 	Address6            string             `json:"address6" bson:"address6" validate:"address6_check"`
@@ -48,6 +80,29 @@ type Node struct {
 	PostChanges         string             `json:"postchanges" bson:"postchanges"`
 	PostChanges         string             `json:"postchanges" bson:"postchanges"`
 }
 }
 
 
+//Duplicated function for NodeUpdates
+func (node *NodeUpdate) GetNetwork() (Network, error) {
+
+	var network Network
+
+	collection := mongoconn.NetworkDB
+	//collection := mongoconn.Client.Database("netmaker").Collection("networks")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+
+	filter := bson.M{"netid": node.Network}
+	err := collection.FindOne(ctx, filter).Decode(&network)
+
+	defer cancel()
+
+	if err != nil {
+		//log.Fatal(err)
+		return network, err
+	}
+
+	return network, err
+}
+
 //TODO: Contains a fatal error return. Need to change
 //TODO: Contains a fatal error return. Need to change
 //Used in contexts where it's not the Parent network.
 //Used in contexts where it's not the Parent network.
 func (node *Node) GetNetwork() (Network, error) {
 func (node *Node) GetNetwork() (Network, error) {