소스 검색

fix failover middleware mgmt

abhishek9686 1 년 전
부모
커밋
2c6403543a
1개의 변경된 파일11개의 추가작업 그리고 9개의 파일을 삭제
  1. 11 9
      controllers/middleware.go

+ 11 - 9
controllers/middleware.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 )
 
@@ -20,14 +21,11 @@ func userMiddleWare(handler http.Handler) http.Handler {
 		r.Header.Set("NET_ID", params["network"])
 		if strings.Contains(r.URL.Path, "hosts") || strings.Contains(r.URL.Path, "nodes") {
 			r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
-			r.Header.Set("RSRC_TYPE", models.HostRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "dns") {
-			r.Header.Set("RSRC_TYPE", models.DnsRsrc.String())
 			r.Header.Set("TARGET_RSRC", models.DnsRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "users") {
-			r.Header.Set("RSRC_TYPE", models.UserRsrc.String())
 			r.Header.Set("TARGET_RSRC", models.UserRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "ingress") {
@@ -36,27 +34,23 @@ func userMiddleWare(handler http.Handler) http.Handler {
 		if strings.Contains(r.URL.Path, "createrelay") || strings.Contains(r.URL.Path, "deleterelay") {
 			r.Header.Set("TARGET_RSRC", models.RelayRsrc.String())
 		}
+
 		if strings.Contains(r.URL.Path, "gateway") {
 			r.Header.Set("TARGET_RSRC", models.EgressGwRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "networks") {
 			r.Header.Set("TARGET_RSRC", models.NetworkRsrc.String())
-			r.Header.Set("RSRC_TYPE", models.NetworkRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "acls") {
 			r.Header.Set("TARGET_RSRC", models.AclRsrc.String())
-			r.Header.Set("RSRC_TYPE", models.NetworkRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "extclients") {
 			r.Header.Set("TARGET_RSRC", models.ExtClientsRsrc.String())
-			r.Header.Set("RSRC_TYPE", models.ExtClientsRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "enrollment-keys") {
 			r.Header.Set("TARGET_RSRC", models.EnrollmentKeysRsrc.String())
-			r.Header.Set("RSRC_TYPE", models.EnrollmentKeysRsrc.String())
 		}
 		if strings.Contains(r.URL.Path, "metrics") {
-			r.Header.Set("RSRC_TYPE", models.MetricRsrc.String())
 			r.Header.Set("TARGET_RSRC", models.MetricRsrc.String())
 		}
 		if keyID, ok := params["keyID"]; ok {
@@ -65,6 +59,13 @@ func userMiddleWare(handler http.Handler) http.Handler {
 		if nodeID, ok := params["nodeid"]; ok && r.Header.Get("TARGET_RSRC") != models.ExtClientsRsrc.String() {
 			r.Header.Set("TARGET_RSRC_ID", nodeID)
 		}
+		if strings.Contains(r.URL.Path, "failover") {
+			r.Header.Set("TARGET_RSRC", models.FailOverRsrc.String())
+			nodeID := r.Header.Get("TARGET_RSRC_ID")
+			node, _ := logic.GetNodeByID(nodeID)
+			r.Header.Set("NET_ID", node.Network)
+
+		}
 		if hostID, ok := params["hostid"]; ok {
 			r.Header.Set("TARGET_RSRC_ID", hostID)
 		}
@@ -86,12 +87,13 @@ func userMiddleWare(handler http.Handler) http.Handler {
 				r.Header.Set("TARGET_RSRC_ID", username)
 			}
 		}
-
 		if r.Header.Get("NET_ID") == "" && (r.Header.Get("TARGET_RSRC_ID") == "" ||
 			r.Header.Get("TARGET_RSRC") == models.EnrollmentKeysRsrc.String() ||
 			r.Header.Get("TARGET_RSRC") == models.UserRsrc.String()) {
 			r.Header.Set("IS_GLOBAL_ACCESS", "yes")
 		}
+
+		r.Header.Set("RSRC_TYPE", r.Header.Get("TARGET_RSRC"))
 		logger.Log(0, "URL ------> ", r.URL.String())
 		handler.ServeHTTP(w, r)
 	})