瀏覽代碼

restrict user access to resources on server

Abhishek Kondur 2 年之前
父節點
當前提交
595911bccb
共有 9 個文件被更改,包括 91 次插入66 次删除
  1. 6 6
      controllers/dns.go
  2. 1 1
      controllers/enrollmentkeys.go
  3. 52 22
      controllers/ext_client.go
  4. 1 1
      controllers/hosts.go
  5. 9 24
      controllers/network.go
  6. 3 3
      controllers/node.go
  7. 5 6
      controllers/user.go
  8. 14 0
      logic/gateway.go
  9. 0 3
      logic/security.go

+ 6 - 6
controllers/dns.go

@@ -17,12 +17,12 @@ import (
 func dnsHandlers(r *mux.Router) {
 
 	r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))).Methods(http.MethodGet)
-	r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(false, http.HandlerFunc(getNodeDNS))).Methods(http.MethodGet)
-	r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(false, http.HandlerFunc(getCustomDNS))).Methods(http.MethodGet)
-	r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(false, http.HandlerFunc(getDNS))).Methods(http.MethodGet)
-	r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(false, http.HandlerFunc(createDNS))).Methods(http.MethodPost)
-	r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(false, http.HandlerFunc(pushDNS))).Methods(http.MethodPost)
-	r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(false, http.HandlerFunc(deleteDNS))).Methods(http.MethodDelete)
+	r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(true, http.HandlerFunc(getNodeDNS))).Methods(http.MethodGet)
+	r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(true, http.HandlerFunc(getCustomDNS))).Methods(http.MethodGet)
+	r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(true, http.HandlerFunc(getDNS))).Methods(http.MethodGet)
+	r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(true, http.HandlerFunc(createDNS))).Methods(http.MethodPost)
+	r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(true, http.HandlerFunc(pushDNS))).Methods(http.MethodPost)
+	r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(true, http.HandlerFunc(deleteDNS))).Methods(http.MethodDelete)
 }
 
 // swagger:route GET /api/dns/adm/{network}/nodes dns getNodeDNS

+ 1 - 1
controllers/enrollmentkeys.go

@@ -17,7 +17,7 @@ import (
 
 func enrollmentKeyHandlers(r *mux.Router) {
 	r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost)
-	r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(false, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost)
 }

+ 52 - 22
controllers/ext_client.go

@@ -23,13 +23,13 @@ import (
 
 func extClientHandlers(r *mux.Router) {
 
-	r.HandleFunc("/api/extclients", logic.SecurityCheck(false, http.HandlerFunc(getAllExtClients))).Methods(http.MethodGet)
-	r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods(http.MethodGet)
+	r.HandleFunc("/api/extclients", logic.SecurityCheck(true, http.HandlerFunc(getAllExtClients))).Methods(http.MethodGet)
+	r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtClients))).Methods(http.MethodGet)
 	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))).Methods(http.MethodGet)
-	r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", getExtClientConf).Methods(http.MethodGet)
-	r.HandleFunc("/api/extclients/{network}/{clientid}", updateExtClient).Methods(http.MethodPut)
-	r.HandleFunc("/api/extclients/{network}/{clientid}", deleteExtClient).Methods(http.MethodDelete)
-	r.HandleFunc("/api/extclients/{network}/{nodeid}", createExtClient).Methods(http.MethodPost)
+	r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientConf))).Methods(http.MethodGet)
+	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(updateExtClient))).Methods(http.MethodPut)
+	r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient))).Methods(http.MethodDelete)
+	r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, http.HandlerFunc(createExtClient))).Methods(http.MethodPost)
 }
 
 func checkIngressExists(nodeID string) bool {
@@ -99,24 +99,14 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "internal"))
 		return
 	}
-	clients := []models.ExtClient{}
+
 	var err error
-	if len(networksSlice) > 0 && networksSlice[0] == logic.ALL_NETWORK_ACCESS {
-		clients, err = logic.GetAllExtClients()
-		if err != nil && !database.IsEmptyRecord(err) {
-			logger.Log(0, "failed to get all extclients: ", err.Error())
-			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-			return
-		}
-	} else {
-		for _, network := range networksSlice {
-			extclients, err := logic.GetNetworkExtClients(network)
-			if err == nil {
-				clients = append(clients, extclients...)
-			}
-		}
+	clients, err := logic.GetAllExtClients()
+	if err != nil && !database.IsEmptyRecord(err) {
+		logger.Log(0, "failed to get all extclients: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
 	}
-
 	//Return all the extclients in JSON format
 	logic.SortExtClient(clients[:])
 	w.WriteHeader(http.StatusOK)
@@ -149,6 +139,14 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	if !logic.IsUserAllowedAccessToExtClient(r.Header.Get("user"), client) {
+		// check if user has access to extclient
+		slog.Error("failed to get extclient", "network", network, "clientID",
+			clientid, "error", errors.New("access is denied"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("access is denied"), "forbidden"))
+		return
+
+	}
 
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(client)
@@ -179,6 +177,12 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	if !logic.IsUserAllowedAccessToExtClient(r.Header.Get("user"), client) {
+		slog.Error("failed to get extclient", "network", networkid, "clientID",
+			clientid, "error", errors.New("access is denied"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("access is denied"), "forbidden"))
+		return
+	}
 
 	gwnode, err := logic.GetNodeByID(client.IngressGatewayID)
 	if err != nil {
@@ -323,6 +327,17 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+	caller, err := logic.GetUser(r.Header.Get("user"))
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+	}
+	if !caller.IsAdmin || !caller.IsSuperAdmin {
+		if _, ok := caller.RemoteGwIDs[nodeid]; !ok {
+			slog.Error("failed to create extclient", "error", errors.New("no permission"))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no permission"), "forbidden"))
+			return
+		}
+	}
 
 	var customExtClient models.CustomExtClient
 
@@ -411,12 +426,21 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	clientid := params["clientid"]
+	network := params["network"]
 	oldExtClient, err := logic.GetExtClientByName(clientid)
 	if err != nil {
 		slog.Error("failed to retrieve extclient", "user", r.Header.Get("user"), "id", clientid, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	if !logic.IsUserAllowedAccessToExtClient(r.Header.Get("user"), oldExtClient) {
+		// check if user has access to extclient
+		slog.Error("failed to get extclient", "network", network, "clientID",
+			clientid, "error", errors.New("access is denied"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("access is denied"), "forbidden"))
+		return
+
+	}
 	if oldExtClient.ClientID == update.ClientID {
 		if err := validateCustomExtClient(&update, false); err != nil {
 			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
@@ -496,6 +520,12 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	if !logic.IsUserAllowedAccessToExtClient(r.Header.Get("user"), extclient) {
+		slog.Error("failed to get extclient", "network", network, "clientID",
+			clientid, "error", errors.New("access is denied"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("access is denied"), "forbidden"))
+		return
+	}
 	ingressnode, err := logic.GetNodeByID(extclient.IngressGatewayID)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),

+ 1 - 1
controllers/hosts.go

@@ -17,7 +17,7 @@ import (
 )
 
 func hostHandlers(r *mux.Router) {
-	r.HandleFunc("/api/hosts", logic.SecurityCheck(false, http.HandlerFunc(getHosts))).Methods(http.MethodGet)
+	r.HandleFunc("/api/hosts", logic.SecurityCheck(true, http.HandlerFunc(getHosts))).Methods(http.MethodGet)
 	r.HandleFunc("/api/hosts/keys", logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys))).Methods(http.MethodPut)
 	r.HandleFunc("/api/hosts/{hostid}/keys", logic.SecurityCheck(true, http.HandlerFunc(updateKeys))).Methods(http.MethodPut)
 	r.HandleFunc("/api/hosts/{hostid}/sync", logic.SecurityCheck(true, http.HandlerFunc(syncHost))).Methods(http.MethodPost)

+ 9 - 24
controllers/network.go

@@ -20,9 +20,9 @@ import (
 )
 
 func networkHandlers(r *mux.Router) {
-	r.HandleFunc("/api/networks", logic.SecurityCheck(false, http.HandlerFunc(getNetworks))).Methods(http.MethodGet)
+	r.HandleFunc("/api/networks", logic.SecurityCheck(true, http.HandlerFunc(getNetworks))).Methods(http.MethodGet)
 	r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceNetworks, http.HandlerFunc(createNetwork)))).Methods(http.MethodPost)
-	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(false, http.HandlerFunc(getNetwork))).Methods(http.MethodGet)
+	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(getNetwork))).Methods(http.MethodGet)
 	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(updateNetwork))).Methods(http.MethodPut)
 	// ACLs
@@ -42,29 +42,14 @@ func networkHandlers(r *mux.Router) {
 //			Responses:
 //				200: getNetworksSliceResponse
 func getNetworks(w http.ResponseWriter, r *http.Request) {
-	networksSlice, marshalErr := getHeaderNetworks(r)
-	if marshalErr != nil {
-		logger.Log(0, r.Header.Get("user"), "error unmarshalling networks: ",
-			marshalErr.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(marshalErr, "badrequest"))
-		return
-	}
-	allnetworks := []models.Network{}
+
 	var err error
-	if len(networksSlice) > 0 && networksSlice[0] == logic.ALL_NETWORK_ACCESS {
-		allnetworks, err = logic.GetNetworks()
-		if err != nil && !database.IsEmptyRecord(err) {
-			logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error())
-			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-			return
-		}
-	} else {
-		for _, network := range networksSlice {
-			netObject, parentErr := logic.GetParentNetwork(network)
-			if parentErr == nil {
-				allnetworks = append(allnetworks, netObject)
-			}
-		}
+
+	allnetworks, err := logic.GetNetworks()
+	if err != nil && !database.IsEmptyRecord(err) {
+		logger.Log(0, r.Header.Get("user"), "failed to fetch networks: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
 	}
 
 	logger.Log(2, r.Header.Get("user"), "fetched networks.")

+ 3 - 3
controllers/node.go

@@ -28,9 +28,9 @@ func nodeHandlers(r *mux.Router) {
 	r.HandleFunc("/api/nodes/{network}/{nodeid}", Authorize(true, true, "node", http.HandlerFunc(deleteNode))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", Authorize(false, true, "user", checkFreeTierLimits(limitChoiceEgress, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", Authorize(false, true, "user", http.HandlerFunc(deleteEgressGateway))).Methods(http.MethodDelete)
-	r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(false, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createIngressGateway)))).Methods(http.MethodPost)
-	r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods(http.MethodDelete)
-	r.HandleFunc("/api/nodes/{network}/{nodeid}/ingress/users", logic.SecurityCheck(false, http.HandlerFunc(IngressGatewayUsers))).Methods(http.MethodGet)
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createIngressGateway)))).Methods(http.MethodPost)
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(true, http.HandlerFunc(deleteIngressGateway))).Methods(http.MethodDelete)
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/ingress/users", logic.SecurityCheck(true, http.HandlerFunc(IngressGatewayUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/nodes/{network}/{nodeid}", Authorize(true, true, "node", http.HandlerFunc(updateNode))).Methods(http.MethodPost)
 	r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/nodes/migrate", migrate).Methods(http.MethodPost)

+ 5 - 6
controllers/user.go

@@ -544,22 +544,21 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	if !caller.IsSuperAdmin && user.IsAdmin {
-		slog.Error("error creating new user: ", "user", user.UserName, "error",
-			errors.New("only superadmin can create admin users"))
+		err = errors.New("only superadmin can create admin users")
+		slog.Error("error creating new user: ", "user", user.UserName, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
 		return
 	}
 	if user.IsSuperAdmin {
-		slog.Error("error creating new user: ", "user", user.UserName, "error",
-			errors.New("additional superadmins cannot be created"))
+		err = errors.New("additional superadmins cannot be created")
+		slog.Error("error creating new user: ", "user", user.UserName, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
 		return
 	}
 
 	err = logic.CreateUser(&user)
 	if err != nil {
-		logger.Log(0, user.UserName, "error creating new user: ",
-			err.Error())
+		logger.Log(0, user.UserName, "error creating new user: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}

+ 14 - 0
logic/gateway.go

@@ -229,3 +229,17 @@ func DeleteGatewayExtClients(gatewayID string, networkName string) error {
 	}
 	return nil
 }
+
+// IsUserAllowedAccessToExtClient - checks if user has permission to access extclient
+func IsUserAllowedAccessToExtClient(username string, client models.ExtClient) bool {
+	user, err := GetUser(username)
+	if err != nil {
+		return false
+	}
+	if !user.IsAdmin && !user.IsSuperAdmin {
+		if user.UserName != client.ClientID {
+			return false
+		}
+	}
+	return true
+}

+ 0 - 3
logic/security.go

@@ -10,9 +10,6 @@ import (
 )
 
 const (
-	// ALL_NETWORK_ACCESS - represents all networks
-	ALL_NETWORK_ACCESS = "THIS_USER_HAS_ALL"
-
 	master_uname     = "masteradministrator"
 	Forbidden_Msg    = "forbidden"
 	Forbidden_Err    = models.Error(Forbidden_Msg)