Browse Source

refactoring for ee

afeiszli 2 years ago
parent
commit
b670755cce

+ 0 - 1
controllers/controller.go

@@ -25,7 +25,6 @@ var HttpHandlers = []interface{}{
 	serverHandlers,
 	extClientHandlers,
 	ipHandlers,
-	metricHandlers,
 	loggerHandlers,
 	userGroupsHandlers,
 	networkUsersHandlers,

+ 17 - 17
controllers/dns.go

@@ -16,13 +16,13 @@ import (
 
 func dnsHandlers(r *mux.Router) {
 
-	r.HandleFunc("/api/dns", securityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
-	r.HandleFunc("/api/dns/adm/{network}/nodes", securityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
-	r.HandleFunc("/api/dns/adm/{network}/custom", securityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
-	r.HandleFunc("/api/dns/adm/{network}", securityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
-	r.HandleFunc("/api/dns/{network}", securityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
-	r.HandleFunc("/api/dns/adm/pushdns", securityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
-	r.HandleFunc("/api/dns/{network}/{domain}", securityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
+	r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))).Methods("GET")
+	r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(false, http.HandlerFunc(getNodeDNS))).Methods("GET")
+	r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(false, http.HandlerFunc(getCustomDNS))).Methods("GET")
+	r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(false, http.HandlerFunc(getDNS))).Methods("GET")
+	r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(false, http.HandlerFunc(createDNS))).Methods("POST")
+	r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(false, http.HandlerFunc(pushDNS))).Methods("POST")
+	r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(false, http.HandlerFunc(deleteDNS))).Methods("DELETE")
 }
 
 // swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS
@@ -44,7 +44,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -68,7 +68,7 @@ func getAllDNS(w http.ResponseWriter, r *http.Request) {
 	dns, err := logic.GetAllDNS()
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -98,7 +98,7 @@ func getCustomDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get custom DNS entries for network [%s]: %v", network, err.Error()))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -128,7 +128,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error()))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -160,7 +160,7 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("invalid DNS entry %+v: %v", entry, err))
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -168,14 +168,14 @@ func createDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("Failed to create DNS entry %+v: %v", entry, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	err = logic.SetDNS()
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("Failed to set DNS entries on file: %v", err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(1, "new DNS record added:", entry.Name)
@@ -221,7 +221,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
 
 	if err != nil {
 		logger.Log(0, "failed to delete dns entry: ", entrytext)
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(1, "deleted dns entry: ", entrytext)
@@ -229,7 +229,7 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("Failed to set DNS entries on file: %v", err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	json.NewEncoder(w).Encode(entrytext + " deleted.")
@@ -287,7 +287,7 @@ func pushDNS(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("Failed to set DNS entries on file: %v", err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver")

+ 33 - 33
controllers/ext_client.go

@@ -21,13 +21,13 @@ import (
 
 func extClientHandlers(r *mux.Router) {
 
-	r.HandleFunc("/api/extclients", securityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET")
-	r.HandleFunc("/api/extclients/{network}", securityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET")
-	r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET")
-	r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", netUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET")
-	r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT")
-	r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE")
-	r.HandleFunc("/api/extclients/{network}/{nodeid}", netUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST")
+	r.HandleFunc("/api/extclients", logic.SecurityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET")
+	r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET")
+	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET")
+	r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET")
+	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT")
+	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.NetUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE")
+	r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.NetUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST")
 }
 
 func checkIngressExists(nodeID string) bool {
@@ -62,7 +62,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get ext clients for network [%s]: %v", network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -96,16 +96,16 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
 	if marshalErr != nil {
 		logger.Log(0, "error unmarshalling networks: ",
 			marshalErr.Error())
-		returnErrorResponse(w, r, formatError(marshalErr, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "internal"))
 		return
 	}
 	clients := []models.ExtClient{}
 	var err error
-	if networksSlice[0] == ALL_NETWORK_ACCESS {
+	if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
 		clients, err = functions.GetAllExtClients()
 		if err != nil && !database.IsEmptyRecord(err) {
 			logger.Log(0, "failed to get all extclients: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	} else {
@@ -146,7 +146,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
 			clientid, network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -177,7 +177,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get extclient for [%s] on network [%s]: %v",
 			clientid, networkid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -185,14 +185,14 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", client.IngressGatewayID, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	network, err := logic.GetParentNetwork(client.Network)
 	if err != nil {
 		logger.Log(1, r.Header.Get("user"), "Could not retrieve Ingress Gateway Network", client.Network)
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -258,7 +258,7 @@ Endpoint = %s
 		bytes, err := qrcode.Encode(config, qrcode.Medium, 220)
 		if err != nil {
 			logger.Log(1, r.Header.Get("user"), "failed to encode qr code: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		w.Header().Set("Content-Type", "image/png")
@@ -266,7 +266,7 @@ Endpoint = %s
 		_, err = w.Write(bytes)
 		if err != nil {
 			logger.Log(1, r.Header.Get("user"), "response writer error (qr) ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		return
@@ -280,7 +280,7 @@ Endpoint = %s
 		_, err := fmt.Fprint(w, config)
 		if err != nil {
 			logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		}
 		return
 	}
@@ -310,7 +310,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		err := errors.New("ingress does not exist")
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -329,7 +329,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	extclient.IngressGatewayEndpoint = node.Endpoint + ":" + strconv.FormatInt(int64(node.ListenPort), 10)
@@ -345,7 +345,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -355,7 +355,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil {
 			logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access")
 			logic.DeleteExtClient(networkName, extclient.ClientID)
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		if !isAdmin {
@@ -400,7 +400,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	clientid := params["clientid"]
@@ -410,7 +410,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v",
 				clientid, network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
@@ -418,13 +418,13 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to fetch  ext client record key [%s] from db for client [%s], network [%s]: %v",
 				key, clientid, network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil {
 		logger.Log(0, "error unmarshalling extclient: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -435,7 +435,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		userID := r.Header.Get("user")
 		_, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName)
 		if !doesOwn {
-			returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
 			return
 		}
 	}
@@ -457,7 +457,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to update ext client [%s], network [%s]: %v",
 				clientid, network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(0, r.Header.Get("user"), "updated ext client", newExtClient.ClientID)
@@ -497,14 +497,14 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
 		err = errors.New("Could not delete extclient " + params["clientid"])
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get ingress gateway node [%s] info: %v", extclient.IngressGatewayID, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -513,7 +513,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
 		userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"]
 		_, doesOwn := doesUserOwnClient(userID, clientID, networkName)
 		if !doesOwn {
-			returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
 			return
 		}
 	}
@@ -531,7 +531,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete extclient [%s],network [%s]: %v", clientid, network, err))
 		err = errors.New("Could not delete extclient " + params["clientid"])
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -542,7 +542,7 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
 
 	logger.Log(0, r.Header.Get("user"),
 		"Deleted extclient client", params["clientid"], "from network", params["network"])
-	returnSuccessResponse(w, r, params["clientid"]+" deleted.")
+	logic.ReturnSuccessResponse(w, r, params["clientid"]+" deleted.")
 }
 
 func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) {

+ 9 - 10
controllers/limits.go

@@ -4,7 +4,6 @@ import (
 	"net/http"
 
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/ee"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 )
@@ -23,32 +22,32 @@ func checkFreeTierLimits(limit_choice int, next http.Handler) http.HandlerFunc {
 			Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks",
 		}
 
-		if ee.Limits.FreeTier { // check that free tier limits not exceeded
+		if logic.Free_Tier && logic.Is_EE { // check that free tier limits not exceeded
 			if limit_choice == networks_l {
 				currentNetworks, err := logic.GetNetworks()
-				if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= ee.Limits.Networks {
-					returnErrorResponse(w, r, errorResponse)
+				if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= logic.Networks_Limit {
+					logic.ReturnErrorResponse(w, r, errorResponse)
 					return
 				}
 			} else if limit_choice == node_l {
 				nodes, err := logic.GetAllNodes()
-				if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= ee.Limits.Nodes {
+				if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= logic.Node_Limit {
 					errorResponse.Message = "free tier limits exceeded on nodes"
-					returnErrorResponse(w, r, errorResponse)
+					logic.ReturnErrorResponse(w, r, errorResponse)
 					return
 				}
 			} else if limit_choice == users_l {
 				users, err := logic.GetUsers()
-				if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= ee.Limits.Users {
+				if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= logic.Users_Limit {
 					errorResponse.Message = "free tier limits exceeded on users"
-					returnErrorResponse(w, r, errorResponse)
+					logic.ReturnErrorResponse(w, r, errorResponse)
 					return
 				}
 			} else if limit_choice == clients_l {
 				clients, err := logic.GetAllExtClients()
-				if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= ee.Limits.Clients {
+				if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= logic.Clients_Limit {
 					errorResponse.Message = "free tier limits exceeded on external clients"
-					returnErrorResponse(w, r, errorResponse)
+					logic.ReturnErrorResponse(w, r, errorResponse)
 					return
 				}
 			}

+ 2 - 1
controllers/logger.go

@@ -7,10 +7,11 @@ import (
 
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
 )
 
 func loggerHandlers(r *mux.Router) {
-	r.HandleFunc("/api/logs", securityCheck(true, http.HandlerFunc(getLogs))).Methods("GET")
+	r.HandleFunc("/api/logs", logic.SecurityCheck(true, http.HandlerFunc(getLogs))).Methods("GET")
 }
 
 func getLogs(w http.ResponseWriter, r *http.Request) {

+ 43 - 49
controllers/network.go

@@ -17,26 +17,20 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 )
 
-// ALL_NETWORK_ACCESS - represents all networks
-const ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
-
-// NO_NETWORKS_PRESENT - represents no networks
-const NO_NETWORKS_PRESENT = "THIS_USER_HAS_NONE"
-
 func networkHandlers(r *mux.Router) {
-	r.HandleFunc("/api/networks", securityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET")
-	r.HandleFunc("/api/networks", securityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST")
-	r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET")
-	r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT")
-	r.HandleFunc("/api/networks/{networkname}/nodelimit", securityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT")
-	r.HandleFunc("/api/networks/{networkname}", securityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE")
-	r.HandleFunc("/api/networks/{networkname}/keyupdate", securityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST")
-	r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
-	r.HandleFunc("/api/networks/{networkname}/keys", securityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
-	r.HandleFunc("/api/networks/{networkname}/keys/{name}", securityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
+	r.HandleFunc("/api/networks", logic.SecurityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET")
+	r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST")
+	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET")
+	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT")
+	r.HandleFunc("/api/networks/{networkname}/nodelimit", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT")
+	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))).Methods("DELETE")
+	r.HandleFunc("/api/networks/{networkname}/keyupdate", logic.SecurityCheck(true, http.HandlerFunc(keyUpdate))).Methods("POST")
+	r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(createAccessKey))).Methods("POST")
+	r.HandleFunc("/api/networks/{networkname}/keys", logic.SecurityCheck(false, http.HandlerFunc(getAccessKeys))).Methods("GET")
+	r.HandleFunc("/api/networks/{networkname}/keys/{name}", logic.SecurityCheck(false, http.HandlerFunc(deleteAccessKey))).Methods("DELETE")
 	// ACLs
-	r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT")
-	r.HandleFunc("/api/networks/{networkname}/acls", securityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET")
+	r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkACL))).Methods("PUT")
+	r.HandleFunc("/api/networks/{networkname}/acls", logic.SecurityCheck(true, http.HandlerFunc(getNetworkACL))).Methods("GET")
 }
 
 // swagger:route GET /api/networks networks getNetworks
@@ -58,16 +52,16 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
 	if marshalErr != nil {
 		logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ",
 			marshalErr.Error())
-		returnErrorResponse(w, r, formatError(marshalErr, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "badrequest"))
 		return
 	}
 	allnetworks := []models.Network{}
 	var err error
-	if networksSlice[0] == ALL_NETWORK_ACCESS {
+	if networksSlice[0] == logic.ALL_NETWORK_ACCESS {
 		allnetworks, err = logic.GetNetworks()
 		if err != nil && !database.IsEmptyRecord(err) {
 			logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	} else {
@@ -110,7 +104,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v",
 			netname, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if !servercfg.IsDisplayKeys() {
@@ -140,7 +134,7 @@ func keyUpdate(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to update keys for network [%s]: %v",
 			netname, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(2, r.Header.Get("user"), "updated key on network", netname)
@@ -182,7 +176,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	var newNetwork models.Network
@@ -190,7 +184,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -203,7 +197,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to update network: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -231,7 +225,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to update network [%s] ipv4 addresses: %v",
 					network.NetID, err.Error()))
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -241,7 +235,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to update network [%s] ipv6 addresses: %v",
 					network.NetID, err.Error()))
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -251,7 +245,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to update network [%s] local addresses: %v",
 					network.NetID, err.Error()))
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -261,7 +255,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to update network [%s] hole punching: %v",
 					network.NetID, err.Error()))
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -271,7 +265,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to get network [%s] nodes: %v",
 					network.NetID, err.Error()))
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		for _, node := range nodes {
@@ -305,7 +299,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get network [%s] nodes: %v",
 				network.NetID, err.Error()))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -315,7 +309,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	if networkChange.NodeLimit != 0 {
@@ -324,7 +318,7 @@ func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
 		if err != nil {
 			logger.Log(0, r.Header.Get("user"),
 				"error marshalling resp: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "badrequest"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 			return
 		}
 		database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME)
@@ -354,21 +348,21 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	err = json.NewDecoder(r.Body).Decode(&networkACLChange)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	newNetACL, err := networkACLChange.Save(acls.ContainerID(netname))
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to update ACLs for network [%s]: %v", netname, err))
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "updated ACLs for network", netname)
@@ -412,7 +406,7 @@ func getNetworkACL(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(2, r.Header.Get("user"), "fetched acl for network", netname)
@@ -445,7 +439,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
 		}
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete network [%s]: %v", network, err))
-		returnErrorResponse(w, r, formatError(err, errtype))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "deleted network", network)
@@ -475,7 +469,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -483,7 +477,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
 		err := errors.New("IPv4 or IPv6 CIDR required")
 		logger.Log(0, r.Header.Get("user"), "failed to create network: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -491,7 +485,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to create network: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -504,7 +498,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
 			}
 			logger.Log(0, r.Header.Get("user"), "failed to create network: ",
 				err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -537,28 +531,28 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to get network info: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	err = json.NewDecoder(r.Body).Decode(&accesskey)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	key, err := logic.CreateAccessKey(accesskey, network)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to create access key: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
 	// do not allow access key creations view API with user names
 	if _, err = logic.GetUser(key.Name); err == nil {
 		logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user"))
-		returnErrorResponse(w, r, formatError(fmt.Errorf("cannot create access key with user name"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("cannot create access key with user name"), "badrequest"))
 		logic.DeleteKey(key.Name, network.NetID)
 		return
 	}
@@ -587,7 +581,7 @@ func getAccessKeys(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get keys for network [%s]: %v",
 			network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if !servercfg.IsDisplayKeys() {
@@ -621,7 +615,7 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to delete key [%s] for network [%s]: %v",
 			keyname, netname, err))
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)

+ 4 - 4
controllers/network_test.go

@@ -182,24 +182,24 @@ func TestSecurityCheck(t *testing.T) {
 	initialize()
 	os.Setenv("MASTER_KEY", "secretkey")
 	t.Run("NoNetwork", func(t *testing.T) {
-		networks, username, err := SecurityCheck(false, "", "Bearer secretkey")
+		networks, username, err := logic.UserPermissions(false, "", "Bearer secretkey")
 		assert.Nil(t, err)
 		t.Log(networks, username)
 	})
 	t.Run("WithNetwork", func(t *testing.T) {
-		networks, username, err := SecurityCheck(false, "skynet", "Bearer secretkey")
+		networks, username, err := logic.UserPermissions(false, "skynet", "Bearer secretkey")
 		assert.Nil(t, err)
 		t.Log(networks, username)
 	})
 	t.Run("BadNet", func(t *testing.T) {
 		t.Skip()
-		networks, username, err := SecurityCheck(false, "badnet", "Bearer secretkey")
+		networks, username, err := logic.UserPermissions(false, "badnet", "Bearer secretkey")
 		assert.NotNil(t, err)
 		t.Log(err)
 		t.Log(networks, username)
 	})
 	t.Run("BadToken", func(t *testing.T) {
-		networks, username, err := SecurityCheck(false, "skynet", "Bearer badkey")
+		networks, username, err := logic.UserPermissions(false, "skynet", "Bearer badkey")
 		assert.NotNil(t, err)
 		t.Log(err)
 		t.Log(networks, username)

+ 33 - 33
controllers/networkusers.go

@@ -14,13 +14,13 @@ import (
 )
 
 func networkUsersHandlers(r *mux.Router) {
-	r.HandleFunc("/api/networkusers", securityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET")
-	r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET")
-	r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET")
-	r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST")
-	r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT")
-	r.HandleFunc("/api/networkusers/data/{networkuser}/me", netUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET")
-	r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE")
+	r.HandleFunc("/api/networkusers", logic.SecurityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET")
+	r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET")
+	r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET")
+	r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST")
+	r.HandleFunc("/api/networkusers/{network}", logic.SecurityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT")
+	r.HandleFunc("/api/networkusers/data/{networkuser}/me", logic.NetUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET")
+	r.HandleFunc("/api/networkusers/{network}/{networkuser}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE")
 }
 
 // == RETURN TYPES ==
@@ -52,18 +52,18 @@ func getNetworkUserData(w http.ResponseWriter, r *http.Request) {
 
 	networks, err := logic.GetNetworks()
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	if networkUserName == "" {
-		returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
 		return
 	}
 
 	u, err := logic.GetUser(networkUserName)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(errors.New("could not find user"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("could not find user"), "badrequest"))
 		return
 	}
 
@@ -151,7 +151,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) {
 
 	networks, err := logic.GetNetworks()
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -160,7 +160,7 @@ func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) {
 	for i := range networks {
 		netusers, err := pro.GetNetworkUsers(networks[i].NetID)
 		if err != nil {
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		for _, v := range netusers {
@@ -181,13 +181,13 @@ func getNetworkUsers(w http.ResponseWriter, r *http.Request) {
 
 	_, err := logic.GetNetwork(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	netusers, err := pro.GetNetworkUsers(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -203,19 +203,19 @@ func getNetworkUser(w http.ResponseWriter, r *http.Request) {
 
 	_, err := logic.GetNetwork(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	netuserToGet := params["networkuser"]
 	if netuserToGet == "" {
-		returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("netuserToGet"), "badrequest"))
 		return
 	}
 
 	netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet))
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	w.WriteHeader(http.StatusOK)
@@ -230,7 +230,7 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) {
 
 	network, err := logic.GetNetwork(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	var networkuser promodels.NetworkUser
@@ -238,13 +238,13 @@ func createNetworkUser(w http.ResponseWriter, r *http.Request) {
 	// we decode our body request params
 	err = json.NewDecoder(r.Body).Decode(&networkuser)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	err = pro.CreateNetworkUser(&network, &networkuser)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -260,7 +260,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
 
 	network, err := logic.GetNetwork(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	var networkuser promodels.NetworkUser
@@ -268,38 +268,38 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
 	// we decode our body request params
 	err = json.NewDecoder(r.Body).Decode(&networkuser)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) {
-		returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
 		return
 	}
 	if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS {
-		returnErrorResponse(w, r, formatError(errors.New("invalid user access level provided"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user access level provided"), "badrequest"))
 		return
 	}
 
 	if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 {
-		returnErrorResponse(w, r, formatError(errors.New("negative user limit provided"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("negative user limit provided"), "badrequest"))
 		return
 	}
 
 	u, err := logic.GetUser(string(networkuser.ID))
 	if err != nil {
-		returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest"))
 		return
 	}
 
 	if !pro.IsUserAllowed(&network, u.UserName, u.Groups) {
-		returnErrorResponse(w, r, formatError(errors.New("user must be in allowed groups or users"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user must be in allowed groups or users"), "badrequest"))
 		return
 	}
 
 	if networkuser.AccessLevel == pro.NET_ADMIN {
 		currentUser, err := logic.GetUser(string(networkuser.ID))
 		if err != nil {
-			returnErrorResponse(w, r, formatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest"))
 			return
 		}
 
@@ -316,7 +316,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
 					UserName: currentUser.UserName,
 				},
 			); err != nil {
-				returnErrorResponse(w, r, formatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest"))
+				logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest"))
 				return
 			}
 		}
@@ -324,7 +324,7 @@ func updateNetworkUser(w http.ResponseWriter, r *http.Request) {
 
 	err = pro.UpdateNetworkUser(netname, &networkuser)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -340,18 +340,18 @@ func deleteNetworkUser(w http.ResponseWriter, r *http.Request) {
 
 	_, err := logic.GetNetwork(netname)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	netuserToDelete := params["networkuser"]
 	if netuserToDelete == "" {
-		returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
 		return
 	}
 
 	if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 

+ 53 - 54
controllers/node.go

@@ -8,7 +8,6 @@ import (
 
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/functions"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/pro"
@@ -30,8 +29,8 @@ func nodeHandlers(r *mux.Router) {
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", authorize(false, true, "user", http.HandlerFunc(deleteRelay))).Methods("DELETE")
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", authorize(false, true, "user", http.HandlerFunc(createEgressGateway))).Methods("POST")
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods("DELETE")
-	r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST")
-	r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE")
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST")
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE")
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST")
 	r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST")
 	r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET")
@@ -66,19 +65,19 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 		errorResponse.Message = decoderErr.Error()
 		logger.Log(0, request.Header.Get("user"), "error decoding request body: ",
 			decoderErr.Error())
-		returnErrorResponse(response, request, errorResponse)
+		logic.ReturnErrorResponse(response, request, errorResponse)
 		return
 	} else {
 		errorResponse.Code = http.StatusBadRequest
 		if authRequest.ID == "" {
 			errorResponse.Message = "W1R3: ID can't be empty"
 			logger.Log(0, request.Header.Get("user"), errorResponse.Message)
-			returnErrorResponse(response, request, errorResponse)
+			logic.ReturnErrorResponse(response, request, errorResponse)
 			return
 		} else if authRequest.Password == "" {
 			errorResponse.Message = "W1R3: Password can't be empty"
 			logger.Log(0, request.Header.Get("user"), errorResponse.Message)
-			returnErrorResponse(response, request, errorResponse)
+			logic.ReturnErrorResponse(response, request, errorResponse)
 			return
 		} else {
 			var err error
@@ -89,7 +88,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 				errorResponse.Message = err.Error()
 				logger.Log(0, request.Header.Get("user"),
 					fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err))
-				returnErrorResponse(response, request, errorResponse)
+				logic.ReturnErrorResponse(response, request, errorResponse)
 				return
 			}
 
@@ -99,7 +98,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 				errorResponse.Message = err.Error()
 				logger.Log(0, request.Header.Get("user"),
 					"error validating user password: ", err.Error())
-				returnErrorResponse(response, request, errorResponse)
+				logic.ReturnErrorResponse(response, request, errorResponse)
 				return
 			} else {
 				tokenString, err := logic.CreateJWT(authRequest.ID, authRequest.MacAddress, result.Network)
@@ -109,7 +108,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 					errorResponse.Message = "Could not create Token"
 					logger.Log(0, request.Header.Get("user"),
 						fmt.Sprintf("%s: %v", errorResponse.Message, err))
-					returnErrorResponse(response, request, errorResponse)
+					logic.ReturnErrorResponse(response, request, errorResponse)
 					return
 				}
 
@@ -128,7 +127,7 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 					errorResponse.Message = err.Error()
 					logger.Log(0, request.Header.Get("user"),
 						"error marshalling resp: ", err.Error())
-					returnErrorResponse(response, request, errorResponse)
+					logic.ReturnErrorResponse(response, request, errorResponse)
 					return
 				}
 				response.WriteHeader(http.StatusOK)
@@ -149,7 +148,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
 			errorResponse := models.ErrorResponse{
 				Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
 			}
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		} else {
 			token = tokenSplit[1]
@@ -161,7 +160,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
 			errorResponse := models.ErrorResponse{
 				Code: http.StatusNotFound, Message: "no networks",
 			}
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		for _, network := range networks {
@@ -177,7 +176,7 @@ func nodeauth(next http.Handler) http.HandlerFunc {
 			errorResponse := models.ErrorResponse{
 				Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.",
 			}
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		next.ServeHTTP(w, r)
@@ -194,16 +193,16 @@ func nodeauth(next http.Handler) http.HandlerFunc {
 func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Handler) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		var errorResponse = models.ErrorResponse{
-			Code: http.StatusUnauthorized, Message: unauthorized_msg,
+			Code: http.StatusUnauthorized, Message: logic.Unauthorized_Msg,
 		}
 
 		var params = mux.Vars(r)
 
-		networkexists, _ := functions.NetworkExists(params["network"])
+		networkexists, _ := logic.NetworkExists(params["network"])
 		//check that the request is for a valid network
 		//if (networkCheck && !networkexists) || err != nil {
 		if networkCheck && !networkexists {
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		} else {
 			w.Header().Set("Content-Type", "application/json")
@@ -220,7 +219,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 			if len(tokenSplit) > 1 {
 				authToken = tokenSplit[1]
 			} else {
-				returnErrorResponse(w, r, errorResponse)
+				logic.ReturnErrorResponse(w, r, errorResponse)
 				return
 			}
 			//check if node instead of user
@@ -236,7 +235,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 			var nodeID = ""
 			username, networks, isadmin, errN := logic.VerifyUserToken(authToken)
 			if errN != nil {
-				returnErrorResponse(w, r, errorResponse)
+				logic.ReturnErrorResponse(w, r, errorResponse)
 				return
 			}
 
@@ -269,7 +268,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 					} else {
 						node, err := logic.GetNodeByID(nodeID)
 						if err != nil {
-							returnErrorResponse(w, r, errorResponse)
+							logic.ReturnErrorResponse(w, r, errorResponse)
 							return
 						}
 						isAuthorized = (node.Network == params["network"])
@@ -287,7 +286,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 				}
 			}
 			if !isAuthorized {
-				returnErrorResponse(w, r, errorResponse)
+				logic.ReturnErrorResponse(w, r, errorResponse)
 				return
 			} else {
 				//If authorized, this function passes along it's request and output to the appropriate route function.
@@ -324,7 +323,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching nodes on network %s: %v", networkName, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -358,7 +357,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
 	if err != nil && r.Header.Get("ismasterkey") != "yes" {
 		logger.Log(0, r.Header.Get("user"),
 			"error fetching user info: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	var nodes []models.Node
@@ -366,7 +365,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
 		nodes, err = logic.GetAllNodes()
 		if err != nil {
 			logger.Log(0, "error fetching all nodes info: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	} else {
@@ -374,7 +373,7 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
 		if err != nil {
 			logger.Log(0, r.Header.Get("user"),
 				"error fetching nodes: ", err.Error())
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -418,7 +417,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -426,7 +425,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -470,7 +469,7 @@ func getLastModified(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching network [%s] info: %v", networkName, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(2, r.Header.Get("user"), "called last modified")
@@ -498,12 +497,12 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 		Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.",
 	}
 	networkName := params["network"]
-	networkexists, err := functions.NetworkExists(networkName)
+	networkexists, err := logic.NetworkExists(networkName)
 
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to fetch network [%s] info: %v", networkName, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	} else if !networkexists {
 		errorResponse = models.ErrorResponse{
@@ -511,7 +510,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 		}
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("network [%s] does not exist", networkName))
-		returnErrorResponse(w, r, errorResponse)
+		logic.ReturnErrorResponse(w, r, errorResponse)
 		return
 	}
 
@@ -521,7 +520,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 	err = json.NewDecoder(r.Body).Decode(&node)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -531,14 +530,14 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get network [%s] info: %v", node.Network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	node.NetworkSettings, err = logic.GetNetworkSettings(node.Network)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get network [%s] settings: %v", node.Network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey)
@@ -554,7 +553,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 			logger.Log(0, r.Header.Get("user"),
 				fmt.Sprintf("failed to create node on network [%s]: %s",
 					node.Network, errorResponse.Message))
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 	}
@@ -569,17 +568,17 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 	key, keyErr := logic.RetrievePublicTrafficKey()
 	if keyErr != nil {
 		logger.Log(0, "error retrieving key: ", keyErr.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if key == nil {
 		logger.Log(0, "error: server traffic key is nil")
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if node.TrafficKeys.Mine == nil {
 		logger.Log(0, "error: node traffic key is nil")
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	node.TrafficKeys = models.TrafficKeys{
@@ -592,7 +591,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create node on network [%s]: %s",
 				node.Network, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -609,7 +608,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 		if !updatedUserNode { // user was found but not updated, so delete node
 			logger.Log(0, "failed to add node to user", keyName)
 			logic.DeleteNodeByID(&node, true)
-			returnErrorResponse(w, r, formatError(err, "internal"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 	}
@@ -618,7 +617,7 @@ func createNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching wg peers config for node [ %s ]: %v", node.ID, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -656,7 +655,7 @@ func uncordonNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to uncordon node [%s]: %v", node.Name, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "uncordoned node", node.Name)
@@ -686,7 +685,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 	err := json.NewDecoder(r.Body).Decode(&gateway)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	gateway.NetID = params["network"]
@@ -696,7 +695,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v",
 				gateway.NodeID, gateway.NetID, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -728,7 +727,7 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v",
 				nodeid, netid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -762,7 +761,7 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v",
 				nodeid, netid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -794,7 +793,7 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete ingress gateway on node [%s] on network [%s]: %v",
 				nodeid, netid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -828,7 +827,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -837,7 +836,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 	err = json.NewDecoder(r.Body).Decode(&newNode)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	relayupdate := false
@@ -885,7 +884,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to update node info [ %s ] info: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if relayupdate {
@@ -932,20 +931,20 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	if isServer(&node) {
 		err := fmt.Errorf("cannot delete server node")
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete node [ %s ]: %v", nodeid, err))
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	if r.Header.Get("ismaster") != "yes" {
 		username := r.Header.Get("user")
 		if username != "" && !doesUserOwnNode(username, params["network"], nodeid) {
-			returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "badrequest"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "badrequest"))
 			return
 		}
 	}
@@ -954,11 +953,11 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 
 	err = logic.DeleteNodeByID(&node, false)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
-	returnSuccessResponse(w, r, nodeid+" deleted.")
+	logic.ReturnSuccessResponse(w, r, nodeid+" deleted.")
 
 	logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"])
 	runUpdates(&node, false)

+ 3 - 3
controllers/relay.go

@@ -30,7 +30,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
 	err := json.NewDecoder(r.Body).Decode(&relay)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	relay.NetID = params["network"]
@@ -39,7 +39,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err))
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID)
@@ -73,7 +73,7 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
 	updatenodes, node, err := logic.DeleteRelay(netid, nodeid)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid)

+ 4 - 3
controllers/response_test.go

@@ -7,12 +7,13 @@ import (
 	"net/http/httptest"
 	"testing"
 
+	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestFormatError(t *testing.T) {
-	response := formatError(errors.New("this is a sample error"), "badrequest")
+	response := logic.FormatError(errors.New("this is a sample error"), "badrequest")
 	assert.Equal(t, http.StatusBadRequest, response.Code)
 	assert.Equal(t, "this is a sample error", response.Message)
 }
@@ -20,7 +21,7 @@ func TestFormatError(t *testing.T) {
 func TestReturnSuccessResponse(t *testing.T) {
 	var response models.SuccessResponse
 	handler := func(rw http.ResponseWriter, r *http.Request) {
-		returnSuccessResponse(rw, r, "This is a test message")
+		logic.ReturnSuccessResponse(rw, r, "This is a test message")
 	}
 	req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
 	w := httptest.NewRecorder()
@@ -42,7 +43,7 @@ func TestReturnErrorResponse(t *testing.T) {
 	errMessage.Code = http.StatusUnauthorized
 	errMessage.Message = "You are not authorized to access this endpoint"
 	handler := func(rw http.ResponseWriter, r *http.Request) {
-		returnErrorResponse(rw, r, errMessage)
+		logic.ReturnErrorResponse(rw, r, errMessage)
 	}
 	req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
 	w := httptest.NewRecorder()

+ 11 - 59
controllers/server.go

@@ -10,7 +10,6 @@ import (
 	"strings"
 
 	"github.com/gorilla/mux"
-	"github.com/gravitl/netmaker/ee"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
@@ -22,82 +21,35 @@ import (
 
 func serverHandlers(r *mux.Router) {
 	// r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST")
-	r.HandleFunc("/api/server/getconfig", securityCheckServer(false, http.HandlerFunc(getConfig))).Methods("GET")
-	r.HandleFunc("/api/server/removenetwork/{network}", securityCheckServer(true, http.HandlerFunc(removeNetwork))).Methods("DELETE")
+	r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))).Methods("GET")
 	r.HandleFunc("/api/server/register", authorize(true, false, "node", http.HandlerFunc(register))).Methods("POST")
 	r.HandleFunc("/api/server/getserverinfo", authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods("GET")
 }
 
-//Security check is middleware for every function and just checks to make sure that its the master calling
-//Only admin should have access to all these network-level actions
-//or maybe some Users once implemented
-func securityCheckServer(adminonly bool, next http.Handler) http.HandlerFunc {
+// allowUsers - allow all authenticated (valid) users - only used by getConfig, may be able to remove during refactor
+func allowUsers(next http.Handler) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		var errorResponse = models.ErrorResponse{
-			Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.",
+			Code: http.StatusInternalServerError, Message: logic.Unauthorized_Msg,
 		}
-
 		bearerToken := r.Header.Get("Authorization")
-
 		var tokenSplit = strings.Split(bearerToken, " ")
 		var authToken = ""
 		if len(tokenSplit) < 2 {
-			errorResponse = models.ErrorResponse{
-				Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
-			}
-			returnErrorResponse(w, r, errorResponse)
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		} else {
 			authToken = tokenSplit[1]
 		}
-		//all endpoints here require master so not as complicated
-		//still might not be a good  way of doing this
-		user, _, isadmin, err := logic.VerifyUserToken(authToken)
-		errorResponse = models.ErrorResponse{
-			Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.",
-		}
-		if !adminonly && (err != nil || user == "") {
-			returnErrorResponse(w, r, errorResponse)
-			return
-		}
-		if adminonly && !isadmin && !authenticateMaster(authToken) {
-			returnErrorResponse(w, r, errorResponse)
+		user, _, _, err := logic.VerifyUserToken(authToken)
+		if err != nil || user == "" {
+			logic.ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		next.ServeHTTP(w, r)
 	}
 }
 
-// swagger:route DELETE /api/server/removenetwork/{network} nodes removeNetwork
-//
-// Remove a network from the server.
-//
-//		Schemes: https
-//
-// 		Security:
-//   		oauth
-//
-//		Responses:
-//			200: stringJSONResponse
-func removeNetwork(w http.ResponseWriter, r *http.Request) {
-	// Set header
-	w.Header().Set("Content-Type", "application/json")
-
-	// get params
-	var params = mux.Vars(r)
-	network := params["network"]
-	err := logic.DeleteNetwork(network)
-	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to delete network [%s]: %v", network, err))
-		json.NewEncoder(w).Encode(fmt.Sprintf("could not remove network %s from server", network))
-		return
-	}
-	logger.Log(1, r.Header.Get("user"),
-		fmt.Sprintf("deleted network [%s]: %v", network, err))
-	json.NewEncoder(w).Encode(fmt.Sprintf("network %s removed from server", network))
-}
-
 // swagger:route GET /api/server/getserverinfo nodes getServerInfo
 //
 // Get the server configuration.
@@ -138,7 +90,7 @@ func getConfig(w http.ResponseWriter, r *http.Request) {
 
 	scfg := servercfg.GetServerConfig()
 	scfg.IsEE = "no"
-	if ee.IsEnterprise() {
+	if logic.Is_EE {
 		scfg.IsEE = "yes"
 	}
 	json.NewEncoder(w).Encode(scfg)
@@ -166,7 +118,7 @@ func register(w http.ResponseWriter, r *http.Request) {
 		errorResponse := models.ErrorResponse{
 			Code: http.StatusBadRequest, Message: err.Error(),
 		}
-		returnErrorResponse(w, r, errorResponse)
+		logic.ReturnErrorResponse(w, r, errorResponse)
 		return
 	}
 	cert, ca, err := genCerts(&request.Key, &request.CommonName)
@@ -175,7 +127,7 @@ func register(w http.ResponseWriter, r *http.Request) {
 		errorResponse := models.ErrorResponse{
 			Code: http.StatusNotFound, Message: err.Error(),
 		}
-		returnErrorResponse(w, r, errorResponse)
+		logic.ReturnErrorResponse(w, r, errorResponse)
 		return
 	}
 	//x509.Certificate.PublicKey is an interface therefore json encoding/decoding result in a string value rather than a []byte

+ 34 - 34
controllers/user.go

@@ -25,13 +25,13 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/users/adm/hasadmin", hasAdmin).Methods("GET")
 	r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST")
 	r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST")
-	r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT")
-	r.HandleFunc("/api/users/networks/{username}", securityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT")
-	r.HandleFunc("/api/users/{username}/adm", securityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT")
-	r.HandleFunc("/api/users/{username}", securityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST")
-	r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE")
-	r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET")
-	r.HandleFunc("/api/users", securityCheck(true, http.HandlerFunc(getUsers))).Methods("GET")
+	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT")
+	r.HandleFunc("/api/users/networks/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT")
+	r.HandleFunc("/api/users/{username}/adm", logic.SecurityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT")
+	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST")
+	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE")
+	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET")
+	r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods("GET")
 	r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
 	r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
 	r.HandleFunc("/api/oauth/node-handler", socketHandler)
@@ -59,7 +59,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	}
 
 	if !servercfg.IsBasicAuthEnabled() {
-		returnErrorResponse(response, request, formatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
+		logic.ReturnErrorResponse(response, request, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
 		return
 	}
 
@@ -69,7 +69,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	if decoderErr != nil {
 		logger.Log(0, "error decoding request body: ",
 			decoderErr.Error())
-		returnErrorResponse(response, request, errorResponse)
+		logic.ReturnErrorResponse(response, request, errorResponse)
 		return
 	}
 	username := authRequest.UserName
@@ -77,14 +77,14 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	if err != nil {
 		logger.Log(0, username, "user validation failed: ",
 			err.Error())
-		returnErrorResponse(response, request, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(response, request, logic.FormatError(err, "badrequest"))
 		return
 	}
 
 	if jwt == "" {
 		// very unlikely that err is !nil and no jwt returned, but handle it anyways.
 		logger.Log(0, username, "jwt token is empty")
-		returnErrorResponse(response, request, formatError(errors.New("no token returned"), "internal"))
+		logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("no token returned"), "internal"))
 		return
 	}
 
@@ -102,7 +102,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	if jsonError != nil {
 		logger.Log(0, username,
 			"error marshalling resp: ", err.Error())
-		returnErrorResponse(response, request, errorResponse)
+		logic.ReturnErrorResponse(response, request, errorResponse)
 		return
 	}
 	logger.Log(2, username, "was authenticated")
@@ -128,7 +128,7 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
 	hasadmin, err := logic.HasAdmin()
 	if err != nil {
 		logger.Log(0, "failed to check for admin: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -171,7 +171,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
 
 	if err != nil {
 		logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched)
@@ -197,7 +197,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
 
 	if err != nil {
 		logger.Log(0, "failed to fetch users: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -226,12 +226,12 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
 
 		logger.Log(0, admin.UserName, "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
 	if !servercfg.IsBasicAuthEnabled() {
-		returnErrorResponse(w, r, formatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("basic auth is disabled"), "badrequest"))
 		return
 	}
 
@@ -239,7 +239,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, admin.UserName, "failed to create admin: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -266,7 +266,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, user.UserName, "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -274,7 +274,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, user.UserName, "error creating new user: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, user.UserName, "was created")
@@ -302,7 +302,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username,
 			"failed to update user networks: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	var userchange models.User
@@ -311,7 +311,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username, "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{
@@ -324,7 +324,7 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username,
 			"failed to update user networks: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, username, "status was updated")
@@ -352,13 +352,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username,
 			"failed to update user info: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if auth.IsOauthUser(&user) == nil {
 		err := fmt.Errorf("cannot update user info for oauth user %s", username)
 		logger.Log(0, err.Error())
-		returnErrorResponse(w, r, formatError(err, "forbidden"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
 		return
 	}
 	var userchange models.User
@@ -367,7 +367,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username, "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	userchange.Networks = nil
@@ -375,7 +375,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username,
 			"failed to update user info: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, username, "was updated")
@@ -401,13 +401,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 	username := params["username"]
 	user, err := GetUserInternal(username)
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	if auth.IsOauthUser(&user) != nil {
 		err := fmt.Errorf("cannot update user info for oauth user %s", username)
 		logger.Log(0, err.Error())
-		returnErrorResponse(w, r, formatError(err, "forbidden"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
 		return
 	}
 	var userchange models.User
@@ -416,18 +416,18 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username, "error decoding request body: ",
 			err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	if !user.IsAdmin {
 		logger.Log(0, username, "not an admin user")
-		returnErrorResponse(w, r, formatError(errors.New("not a admin user"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("not a admin user"), "badrequest"))
 	}
 	user, err = logic.UpdateUser(userchange, user)
 	if err != nil {
 		logger.Log(0, username,
 			"failed to update user (admin) info: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, username, "was updated (admin)")
@@ -458,12 +458,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		logger.Log(0, username,
 			"failed to delete user: ", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	} else if !success {
 		err := errors.New("delete unsuccessful")
 		logger.Log(0, username, err.Error())
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 

+ 11 - 9
controllers/usergroups.go

@@ -3,18 +3,20 @@ package controller
 import (
 	"encoding/json"
 	"errors"
-	"github.com/gravitl/netmaker/logger"
 	"net/http"
 
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/logic/pro"
 	"github.com/gravitl/netmaker/models/promodels"
 )
 
 func userGroupsHandlers(r *mux.Router) {
-	r.HandleFunc("/api/usergroups", securityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET")
-	r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST")
-	r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE")
+	r.HandleFunc("/api/usergroups", logic.SecurityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET")
+	r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST")
+	r.HandleFunc("/api/usergroups/{usergroup}", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE")
 }
 
 func getUserGroups(w http.ResponseWriter, r *http.Request) {
@@ -23,7 +25,7 @@ func getUserGroups(w http.ResponseWriter, r *http.Request) {
 
 	userGroups, err := pro.GetUserGroups()
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 	// Returns all the groups in JSON format
@@ -39,13 +41,13 @@ func createUserGroup(w http.ResponseWriter, r *http.Request) {
 	logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup)
 
 	if newGroup == "" {
-		returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
 		return
 	}
 
 	err := pro.InsertUserGroup(promodels.UserGroupName(newGroup))
 	if err != nil {
-		returnErrorResponse(w, r, formatError(err, "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
@@ -58,12 +60,12 @@ func deleteUserGroup(w http.ResponseWriter, r *http.Request) {
 	logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete)
 
 	if groupToDelete == "" {
-		returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no group name provided"), "badrequest"))
 		return
 	}
 
 	if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil {
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 

+ 9 - 8
controllers/metrics.go → ee/ee_controllers/metrics.go

@@ -1,4 +1,4 @@
-package controller
+package ee_controllers
 
 import (
 	"encoding/json"
@@ -10,10 +10,11 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-func metricHandlers(r *mux.Router) {
-	r.HandleFunc("/api/metrics/{network}/{nodeid}", securityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET")
-	r.HandleFunc("/api/metrics/{network}", securityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET")
-	r.HandleFunc("/api/metrics", securityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET")
+// MetricHandlers - How we handle EE Metrics
+func MetricHandlers(r *mux.Router) {
+	r.HandleFunc("/api/metrics/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET")
+	r.HandleFunc("/api/metrics/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET")
+	r.HandleFunc("/api/metrics", logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET")
 }
 
 // get the metrics of a given node
@@ -28,7 +29,7 @@ func getNodeMetrics(w http.ResponseWriter, r *http.Request) {
 	metrics, err := logic.GetMetrics(nodeID)
 	if err != nil {
 		logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -49,7 +50,7 @@ func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) {
 	networkNodes, err := logic.GetNetworkNodes(network)
 	if err != nil {
 		logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
@@ -79,7 +80,7 @@ func getAllMetrics(w http.ResponseWriter, r *http.Request) {
 	allNodes, err := logic.GetAllNodes()
 	if err != nil {
 		logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error())
-		returnErrorResponse(w, r, formatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 

+ 54 - 0
ee/initialize.go

@@ -0,0 +1,54 @@
+//go:build ee
+// +build ee
+
+package ee
+
+import (
+	controller "github.com/gravitl/netmaker/controllers"
+	"github.com/gravitl/netmaker/ee/ee_controllers"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+)
+
+// InitEE - Initialize EE Logic
+func InitEE() {
+	SetIsEnterprise()
+	models.SetLogo(retrieveEELogo())
+	controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers)
+	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
+		// == License Handling ==
+		ValidateLicense()
+		if Limits.FreeTier {
+			logger.Log(0, "proceeding with Free Tier license")
+		} else {
+			logger.Log(0, "proceeding with Paid Tier license")
+		}
+		// == End License Handling ==
+		AddLicenseHooks()
+	})
+}
+
+func setControllerLimits() {
+	logic.Node_Limit = Limits.Nodes
+	logic.Users_Limit = Limits.Users
+	logic.Clients_Limit = Limits.Clients
+	logic.Free_Tier = Limits.FreeTier
+	logic.Is_EE = true
+}
+
+func retrieveEELogo() string {
+	return `              
+ __   __     ______     ______   __    __     ______     __  __     ______     ______    
+/\ "-.\ \   /\  ___\   /\__  _\ /\ "-./  \   /\  __ \   /\ \/ /    /\  ___\   /\  == \   
+\ \ \-.  \  \ \  __\   \/_/\ \/ \ \ \-./\ \  \ \  __ \  \ \  _"-.  \ \  __\   \ \  __<   
+ \ \_\\"\_\  \ \_____\    \ \_\  \ \_\ \ \_\  \ \_\ \_\  \ \_\ \_\  \ \_____\  \ \_\ \_\ 
+  \/_/ \/_/   \/_____/     \/_/   \/_/  \/_/   \/_/\/_/   \/_/\/_/   \/_____/   \/_/ /_/ 
+                                                                                         																							 
+                                   ___    ___   ____                        
+           ____  ____  ____       / _ \  / _ \ / __ \       ____  ____  ____
+          /___/ /___/ /___/      / ___/ / , _// /_/ /      /___/ /___/ /___/
+         /___/ /___/ /___/      /_/    /_/|_| \____/      /___/ /___/ /___/ 
+                                                                            
+`
+}

+ 63 - 28
ee/license.go

@@ -1,7 +1,11 @@
+//go:build ee
+// +build ee
+
 package ee
 
 import (
 	"bytes"
+	"crypto/rand"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -11,11 +15,20 @@ import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
-	"github.com/gravitl/netmaker/logic/pro"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/crypto/nacl/box"
+)
+
+const (
+	db_license_key = "netmaker-id-key-pair"
 )
 
+type apiServerConf struct {
+	PrivateKey []byte `json:"private_key" binding:"required"`
+	PublicKey  []byte `json:"public_key" binding:"required"`
+}
+
 // AddLicenseHooks - adds the validation and cache clear hooks
 func AddLicenseHooks() {
 	logic.AddHook(ValidateLicense)
@@ -39,7 +52,7 @@ func ValidateLicense() error {
 		logger.FatalLog(errValidation.Error())
 	}
 
-	tempPubKey, tempPrivKey, err := pro.FetchApiServerKeys()
+	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
 		logger.FatalLog(errValidation.Error())
 	}
@@ -88,11 +101,59 @@ func ValidateLicense() error {
 	if Limits.FreeTier {
 		Limits.Networks = 3
 	}
+	setControllerLimits()
 
 	logger.Log(0, "License validation succeeded!")
 	return nil
 }
 
+// FetchApiServerKeys - fetches netmaker license keys for identification
+// as well as secure communication with API
+// if none present, it generates a new pair
+func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
+	var returnData = apiServerConf{}
+	currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key)
+	if err != nil && !database.IsEmptyRecord(err) {
+		return nil, nil, err
+	} else if database.IsEmptyRecord(err) { // need to generate a new identifier pair
+		pub, priv, err = box.GenerateKey(rand.Reader)
+		if err != nil {
+			return nil, nil, err
+		}
+		pubBytes, err := ncutils.ConvertKeyToBytes(pub)
+		if err != nil {
+			return nil, nil, err
+		}
+		privBytes, err := ncutils.ConvertKeyToBytes(priv)
+		if err != nil {
+			return nil, nil, err
+		}
+		returnData.PrivateKey = privBytes
+		returnData.PublicKey = pubBytes
+		record, err := json.Marshal(&returnData)
+		if err != nil {
+			return nil, nil, err
+		}
+		if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil {
+			return nil, nil, err
+		}
+	} else {
+		if err = json.Unmarshal([]byte(currentData), &returnData); err != nil {
+			return nil, nil, err
+		}
+		priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey)
+		if err != nil {
+			return nil, nil, err
+		}
+		pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+
+	return pub, priv, nil
+}
+
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
 	decodedPubKey := base64decode(licensePubKeyEncoded)
 	return ncutils.ConvertBytesToKey(decodedPubKey)
@@ -179,32 +240,6 @@ func ClearLicenseCache() error {
 	return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key)
 }
 
-// AddServerIDIfNotPresent - add's current server ID to DB if not present
-func AddServerIDIfNotPresent() error {
-	currentNodeID := servercfg.GetNodeID()
-	currentServerIDs := serverIDs{}
-
-	record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key)
-	if err != nil && !database.IsEmptyRecord(err) {
-		return err
-	} else if err == nil {
-		if err = json.Unmarshal([]byte(record), &currentServerIDs); err != nil {
-			return err
-		}
-	}
-
-	if !logic.StringSliceContains(currentServerIDs.ServerIDs, currentNodeID) {
-		currentServerIDs.ServerIDs = append(currentServerIDs.ServerIDs, currentNodeID)
-		data, err := json.Marshal(&currentServerIDs)
-		if err != nil {
-			return err
-		}
-		return database.Insert(server_id_key, string(data), database.SERVERCONF_TABLE_NAME)
-	}
-
-	return nil
-}
-
 func getServerCount() int {
 	if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil {
 		currentServerIDs := serverIDs{}

+ 1 - 1
ee/util.go

@@ -49,6 +49,6 @@ func getCurrentServerLimit() (limits LicenseLimits) {
 	if err == nil {
 		limits.Users = len(users)
 	}
-	limits.Servers = getServerCount()
+	limits.Servers = logic.GetServerCount()
 	return
 }

+ 0 - 11
functions/helpers.go

@@ -8,17 +8,6 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-// NetworkExists - check if network exists
-func NetworkExists(name string) (bool, error) {
-
-	var network string
-	var err error
-	if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
-		return false, err
-	}
-	return len(network) > 0, nil
-}
-
 // NameInDNSCharSet - name in dns char set
 func NameInDNSCharSet(name string) bool {
 

+ 2 - 2
functions/helpers_test.go

@@ -26,7 +26,7 @@ func TestNetworkExists(t *testing.T) {
 	}
 	database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
 	defer database.CloseDB()
-	exists, err := NetworkExists(testNetwork.NetID)
+	exists, err := logic.NetworkExists(testNetwork.NetID)
 	if err == nil {
 		t.Fatalf("expected error, received nil")
 	}
@@ -38,7 +38,7 @@ func TestNetworkExists(t *testing.T) {
 	if err != nil {
 		t.Fatalf("failed to save test network in databse: %s", err)
 	}
-	exists, err = NetworkExists(testNetwork.NetID)
+	exists, err = logic.NetworkExists(testNetwork.NetID)
 	if err != nil {
 		t.Fatalf("expected nil, received err: %s", err)
 	}

+ 1 - 1
logic/auth.go

@@ -99,7 +99,7 @@ func CreateUser(user models.User) (models.User, error) {
 
 	tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin)
 	if tokenString == "" {
-		// returnErrorResponse(w, r, errorResponse)
+		// logic.ReturnErrorResponse(w, r, errorResponse)
 		return user, err
 	}
 

+ 7 - 4
controllers/response.go → logic/errors.go

@@ -1,4 +1,4 @@
-package controller
+package logic
 
 import (
 	"encoding/json"
@@ -8,7 +8,8 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-func formatError(err error, errType string) models.ErrorResponse {
+// FormatError - takes ErrorResponse and uses correct code
+func FormatError(err error, errType string) models.ErrorResponse {
 
 	var status = http.StatusInternalServerError
 	switch errType {
@@ -33,7 +34,8 @@ func formatError(err error, errType string) models.ErrorResponse {
 	return response
 }
 
-func returnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) {
+// ReturnSuccessResponse - processes message and adds header
+func ReturnSuccessResponse(response http.ResponseWriter, request *http.Request, message string) {
 	var httpResponse models.SuccessResponse
 	httpResponse.Code = http.StatusOK
 	httpResponse.Message = message
@@ -42,7 +44,8 @@ func returnSuccessResponse(response http.ResponseWriter, request *http.Request,
 	json.NewEncoder(response).Encode(httpResponse)
 }
 
-func returnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) {
+// ReturnErrorResponse - processes error and adds header
+func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, errorMessage models.ErrorResponse) {
 	httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message}
 	jsonResponse, err := json.Marshal(httpResponse)
 	if err != nil {

+ 0 - 0
logic/pro/metrics/metrics.go → logic/metrics/metrics.go


+ 12 - 1
logic/networks.go

@@ -96,7 +96,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 
 	err := ValidateNetwork(&network, false)
 	if err != nil {
-		//returnErrorResponse(w, r, formatError(err, "badrequest"))
+		//logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return models.Network{}, err
 	}
 
@@ -656,6 +656,17 @@ func SaveNetwork(network *models.Network) error {
 	return nil
 }
 
+// NetworkExists - check if network exists
+func NetworkExists(name string) (bool, error) {
+
+	var network string
+	var err error
+	if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
+		return false, err
+	}
+	return len(network) > 0, nil
+}
+
 // == Private ==
 
 func networkNodesUpdateAction(networkName string, action string) error {

+ 1 - 1
logic/nodes.go

@@ -311,7 +311,7 @@ func CreateNode(node *models.Node) error {
 	//Create a JWT for the node
 	tokenString, _ := CreateJWT(node.ID, node.MacAddress, node.Network)
 	if tokenString == "" {
-		//returnErrorResponse(w, r, errorResponse)
+		//logic.ReturnErrorResponse(w, r, errorResponse)
 		return err
 	}
 	err = ValidateNode(node, false)

+ 0 - 66
logic/pro/license.go

@@ -1,66 +0,0 @@
-package pro
-
-import (
-	"crypto/rand"
-	"encoding/json"
-
-	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/netclient/ncutils"
-	"golang.org/x/crypto/nacl/box"
-)
-
-const (
-	db_license_key = "netmaker-id-key-pair"
-)
-
-type apiServerConf struct {
-	PrivateKey []byte `json:"private_key" binding:"required"`
-	PublicKey  []byte `json:"public_key" binding:"required"`
-}
-
-// FetchApiServerKeys - fetches netmaker license keys for identification
-// as well as secure communication with API
-// if none present, it generates a new pair
-func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
-	var returnData = apiServerConf{}
-	currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key)
-	if err != nil && !database.IsEmptyRecord(err) {
-		return nil, nil, err
-	} else if database.IsEmptyRecord(err) { // need to generate a new identifier pair
-		pub, priv, err = box.GenerateKey(rand.Reader)
-		if err != nil {
-			return nil, nil, err
-		}
-		pubBytes, err := ncutils.ConvertKeyToBytes(pub)
-		if err != nil {
-			return nil, nil, err
-		}
-		privBytes, err := ncutils.ConvertKeyToBytes(priv)
-		if err != nil {
-			return nil, nil, err
-		}
-		returnData.PrivateKey = privBytes
-		returnData.PublicKey = pubBytes
-		record, err := json.Marshal(&returnData)
-		if err != nil {
-			return nil, nil, err
-		}
-		if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil {
-			return nil, nil, err
-		}
-	} else {
-		if err = json.Unmarshal([]byte(currentData), &returnData); err != nil {
-			return nil, nil, err
-		}
-		priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey)
-		if err != nil {
-			return nil, nil, err
-		}
-		pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey)
-		if err != nil {
-			return nil, nil, err
-		}
-	}
-
-	return pub, priv, nil
-}

+ 2 - 2
logic/pro/networks_test.go

@@ -58,7 +58,7 @@ func TestNetworkProSettings(t *testing.T) {
 		}
 		AddProNetDefaults(&network)
 		assert.NotNil(t, network.ProSettings)
-		assert.Nil(t, network.ProSettings.AllowedGroups)
-		assert.Nil(t, network.ProSettings.AllowedUsers)
+		assert.Equal(t, len(network.ProSettings.AllowedGroups), 1)
+		assert.Equal(t, len(network.ProSettings.AllowedUsers), 0)
 	})
 }

+ 33 - 30
controllers/security.go → logic/security.go

@@ -1,4 +1,4 @@
-package controller
+package logic
 
 import (
 	"encoding/json"
@@ -7,8 +7,6 @@ import (
 
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/functions"
-	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/pro"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models/promodels"
@@ -16,16 +14,20 @@ import (
 )
 
 const (
+	// ALL_NETWORK_ACCESS - represents all networks
+	ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
+
 	master_uname     = "masteradministrator"
-	unauthorized_msg = "unauthorized"
-	unauthorized_err = models.Error(unauthorized_msg)
+	Unauthorized_Msg = "unauthorized"
+	Unauthorized_Err = models.Error(Unauthorized_Msg)
 )
 
-func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
+// SecurityCheck - Check if user has appropriate permissions
+func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
 
 	return func(w http.ResponseWriter, r *http.Request) {
 		var errorResponse = models.ErrorResponse{
-			Code: http.StatusUnauthorized, Message: unauthorized_msg,
+			Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
 		}
 
 		var params = mux.Vars(r)
@@ -44,14 +46,14 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
 		if len(networkName) == 0 {
 			networkName = params["network"]
 		}
-		networks, username, err := SecurityCheck(reqAdmin, networkName, bearerToken)
+		networks, username, err := UserPermissions(reqAdmin, networkName, bearerToken)
 		if err != nil {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		networksJson, err := json.Marshal(&networks)
 		if err != nil {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		r.Header.Set("user", username)
@@ -60,7 +62,8 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
 	}
 }
 
-func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc {
+// NetUserSecurityCheck - Check if network user has appropriate permissions
+func NetUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		var errorResponse = models.ErrorResponse{
 			Code: http.StatusUnauthorized, Message: "unauthorized",
@@ -77,7 +80,7 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
 		var authToken = ""
 
 		if len(tokenSplit) < 2 {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		} else {
 			authToken = tokenSplit[1]
@@ -91,9 +94,9 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
 			return
 		}
 
-		userName, _, isadmin, err := logic.VerifyUserToken(authToken)
+		userName, _, isadmin, err := VerifyUserToken(authToken)
 		if err != nil {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		r.Header.Set("user", userName)
@@ -113,15 +116,15 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
 			}
 			u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName))
 			if err != nil {
-				returnErrorResponse(w, r, errorResponse)
+				ReturnErrorResponse(w, r, errorResponse)
 				return
 			}
 			if u.AccessLevel > necessaryAccess {
-				returnErrorResponse(w, r, errorResponse)
+				ReturnErrorResponse(w, r, errorResponse)
 				return
 			}
 		} else if netUserName != userName {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 
@@ -129,14 +132,14 @@ func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.Handl
 	}
 }
 
-// SecurityCheck - checks token stuff
-func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) {
+// UserPermissions - checks token stuff
+func UserPermissions(reqAdmin bool, netname string, token string) ([]string, string, error) {
 	var tokenSplit = strings.Split(token, " ")
 	var authToken = ""
 	userNetworks := []string{}
 
 	if len(tokenSplit) < 2 {
-		return userNetworks, "", unauthorized_err
+		return userNetworks, "", Unauthorized_Err
 	} else {
 		authToken = tokenSplit[1]
 	}
@@ -144,12 +147,12 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin
 	if authenticateMaster(authToken) {
 		return []string{ALL_NETWORK_ACCESS}, master_uname, nil
 	}
-	username, networks, isadmin, err := logic.VerifyUserToken(authToken)
+	username, networks, isadmin, err := VerifyUserToken(authToken)
 	if err != nil {
-		return nil, username, unauthorized_err
+		return nil, username, Unauthorized_Err
 	}
 	if !isadmin && reqAdmin {
-		return nil, username, unauthorized_err
+		return nil, username, Unauthorized_Err
 	}
 	userNetworks = networks
 	if isadmin {
@@ -157,10 +160,10 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin
 	}
 	// check network admin access
 	if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) {
-		return nil, username, unauthorized_err
+		return nil, username, Unauthorized_Err
 	}
 	if !pro.IsUserNetAdmin(netname, username) {
-		return nil, "", unauthorized_err
+		return nil, "", Unauthorized_Err
 	}
 	return userNetworks, username, nil
 }
@@ -171,11 +174,11 @@ func authenticateMaster(tokenString string) bool {
 }
 
 func authenticateNetworkUser(network string, userNetworks []string) bool {
-	networkexists, err := functions.NetworkExists(network)
+	networkexists, err := NetworkExists(network)
 	if (err != nil && !database.IsEmptyRecord(err)) || !networkexists {
 		return false
 	}
-	return logic.StringSliceContains(userNetworks, network)
+	return StringSliceContains(userNetworks, network)
 }
 
 //Consider a more secure way of setting master key
@@ -187,15 +190,15 @@ func authenticateDNSToken(tokenString string) bool {
 	return tokens[1] == servercfg.GetDNSKey()
 }
 
-func continueIfUserMatch(next http.Handler) http.HandlerFunc {
+func ContinueIfUserMatch(next http.Handler) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		var errorResponse = models.ErrorResponse{
-			Code: http.StatusUnauthorized, Message: unauthorized_msg,
+			Code: http.StatusUnauthorized, Message: Unauthorized_Msg,
 		}
 		var params = mux.Vars(r)
 		var requestedUser = params["username"]
 		if requestedUser != r.Header.Get("user") {
-			returnErrorResponse(w, r, errorResponse)
+			ReturnErrorResponse(w, r, errorResponse)
 			return
 		}
 		next.ServeHTTP(w, r)

+ 9 - 0
logic/server.go

@@ -18,6 +18,8 @@ import (
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
+var EnterpriseCheckFuncs []interface{}
+
 // == Join, Checkin, and Leave for Server ==
 
 // KUBERNETES_LISTEN_PORT - starting port for Kubernetes in order to use NodePort range
@@ -164,6 +166,13 @@ func ServerJoin(networkSettings *models.Network) (models.Node, error) {
 	return *node, nil
 }
 
+// EnterpriseCheck - Runs enterprise functions if presented
+func EnterpriseCheck() {
+	for _, check := range EnterpriseCheckFuncs {
+		check.(func())()
+	}
+}
+
 // ServerUpdate - updates the server
 // replaces legacy Checkin code
 func ServerUpdate(serverNode *models.Node, ifaceDelta bool) error {

+ 15 - 0
logic/serverconf.go

@@ -6,6 +6,21 @@ import (
 	"github.com/gravitl/netmaker/database"
 )
 
+var (
+	// Node_Limit - dummy var for community
+	Node_Limit = 1000000000
+	// Networks_Limit - dummy var for community
+	Networks_Limit = 1000000000
+	// Users_Limit - dummy var for community
+	Users_Limit = 1000000000
+	// Clients_Limit - dummy var for community
+	Clients_Limit = 1000000000
+	// Free_Tier - specifies if free tier
+	Free_Tier = false
+	// Is_EE - specifies if enterprise
+	Is_EE = false
+)
+
 // constant for database key for storing server ids
 const server_id_key = "nm-server-id"
 

+ 2 - 14
main.go

@@ -20,7 +20,6 @@ import (
 	"github.com/gravitl/netmaker/config"
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/ee"
 	"github.com/gravitl/netmaker/functions"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
@@ -76,7 +75,7 @@ func initialize() { // Client Mode Prereq Check
 		logger.FatalLog("Error connecting to database")
 	}
 	logger.Log(0, "database successfully connected")
-	if err = ee.AddServerIDIfNotPresent(); err != nil {
+	if err = logic.AddServerIDIfNotPresent(); err != nil {
 		logger.Log(1, "failed to save server ID")
 	}
 
@@ -91,18 +90,7 @@ func initialize() { // Client Mode Prereq Check
 		logger.Log(1, "Timer error occurred: ", err.Error())
 	}
 
-	if ee.IsEnterprise() {
-		// == License Handling ==
-		ee.ValidateLicense()
-		if ee.Limits.FreeTier {
-			logger.Log(0, "proceeding with Free Tier license")
-		} else {
-			logger.Log(0, "proceeding with Paid Tier license")
-		}
-		// == End License Handling ==
-
-		ee.AddLicenseHooks()
-	}
+	logic.EnterpriseCheck()
 
 	var authProvider = auth.InitializeAuthProvider()
 	if authProvider != "" {

+ 1 - 19
main_ee.go

@@ -5,26 +5,8 @@ package main
 
 import (
 	"github.com/gravitl/netmaker/ee"
-	"github.com/gravitl/netmaker/models"
 )
 
 func init() {
-	ee.SetIsEnterprise()
-	models.SetLogo(retrieveEELogo())
-}
-
-func retrieveEELogo() string {
-	return `              
- __   __     ______     ______   __    __     ______     __  __     ______     ______    
-/\ "-.\ \   /\  ___\   /\__  _\ /\ "-./  \   /\  __ \   /\ \/ /    /\  ___\   /\  == \   
-\ \ \-.  \  \ \  __\   \/_/\ \/ \ \ \-./\ \  \ \  __ \  \ \  _"-.  \ \  __\   \ \  __<   
- \ \_\\"\_\  \ \_____\    \ \_\  \ \_\ \ \_\  \ \_\ \_\  \ \_\ \_\  \ \_____\  \ \_\ \_\ 
-  \/_/ \/_/   \/_____/     \/_/   \/_/  \/_/   \/_/\/_/   \/_/\/_/   \/_____/   \/_/ /_/ 
-                                                                                         																							 
-                                   ___    ___   ____                        
-           ____  ____  ____       / _ \  / _ \ / __ \       ____  ____  ____
-          /___/ /___/ /___/      / ___/ / , _// /_/ /      /___/ /___/ /___/
-         /___/ /___/ /___/      /_/    /_/|_| \____/      /___/ /___/ /___/ 
-                                                                            
-`
+	ee.InitEE()
 }

+ 1 - 2
mq/handlers.go

@@ -7,7 +7,6 @@ import (
 
 	mqtt "github.com/eclipse/paho.mqtt.golang"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/ee"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
@@ -99,7 +98,7 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
 
 // UpdateMetrics  message Handler -- handles updates from client nodes for metrics
 func UpdateMetrics(client mqtt.Client, msg mqtt.Message) {
-	if ee.IsEnterprise() {
+	if logic.Is_EE {
 		go func() {
 			id, err := getID(msg.Topic())
 			if err != nil {

+ 2 - 3
mq/publishers.go

@@ -6,10 +6,9 @@ import (
 	"fmt"
 	"time"
 
-	"github.com/gravitl/netmaker/ee"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
-	"github.com/gravitl/netmaker/logic/pro/metrics"
+	"github.com/gravitl/netmaker/logic/metrics"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/serverctl"
@@ -185,7 +184,7 @@ func ServerStartNotify() error {
 
 // function to collect and store metrics for server nodes
 func collectServerMetrics(networks []models.Network) {
-	if !ee.IsEnterprise() {
+	if !logic.Is_EE {
 		return
 	}
 	if len(networks) > 0 {

+ 1 - 1
netclient/functions/mqpublish.go

@@ -15,7 +15,7 @@ import (
 
 	"github.com/cloverstd/tcping/ping"
 	"github.com/gravitl/netmaker/logger"
-	"github.com/gravitl/netmaker/logic/pro/metrics"
+	"github.com/gravitl/netmaker/logic/metrics"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/auth"
 	"github.com/gravitl/netmaker/netclient/config"