Sfoglia il codice sorgente

Release v1.2.0 (#3764)

* add safe nil check for legacy acls

* fix import cycle issue

* fix user sync by groups in azure

* remove debug logs

* remove autoenabling rac configs

* preserve extclient DNS field
Abhishek Kondur 1 settimana fa
parent
commit
333da2f053
6 ha cambiato i file con 76 aggiunte e 50 eliminazioni
  1. 0 36
      controllers/user.go
  2. 7 0
      logic/clients.go
  3. 3 0
      logic/dns.go
  4. 0 1
      pro/auth/azure-ad.go
  5. 51 9
      pro/idp/azure/azure.go
  6. 15 4
      pro/logic/ext_acls.go

+ 0 - 36
controllers/user.go

@@ -415,42 +415,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	response.Header().Set("Content-Type", "application/json")
 	response.Write(successJSONResponse)
 
-	go func() {
-		if servercfg.IsPro {
-			// enable all associeated clients for the user
-			clients, err := logic.GetAllExtClients()
-			if err != nil {
-				slog.Error("error getting clients: ", "error", err)
-				return
-			}
-			for _, client := range clients {
-				if client.OwnerID == username && !client.Enabled {
-					slog.Info(
-						fmt.Sprintf(
-							"enabling ext client %s for user %s due to RAC autodisabling feature",
-							client.ClientID,
-							client.OwnerID,
-						),
-					)
-					if newClient, err := logic.ToggleExtClientConnectivity(&client, true); err != nil {
-						slog.Error(
-							"error enabling ext client in RAC autodisable hook",
-							"error",
-							err,
-						)
-						continue // dont return but try for other clients
-					} else {
-						// publish peer update to ingress gateway
-						if ingressNode, err := logic.GetNodeByID(newClient.IngressGatewayID); err == nil {
-							if err = mq.PublishPeerUpdate(false); err != nil {
-								slog.Error("error updating ext clients on", "ingress", ingressNode.ID.String(), "err", err.Error())
-							}
-						}
-					}
-				}
-			}
-		}
-	}()
 }
 
 // @Summary     Validates a user's identity against it's token. This is used by UI before a user performing a critical operation to validate the user's identity.

+ 7 - 0
logic/clients.go

@@ -26,6 +26,10 @@ var (
 	}
 	SetClientDefaultACLs = func(ec *models.ExtClient) error {
 		// allow all on CE
+		if !GetServerSettings().OldAClsSupport {
+			ec.DeniedACLs = make(map[string]struct{})
+			return nil
+		}
 		networkAcls := acls.ACLContainer{}
 		networkAcls, err := networkAcls.Get(acls.ContainerID(ec.Network))
 		if err != nil {
@@ -34,6 +38,9 @@ var (
 		}
 		networkAcls[acls.AclID(ec.ClientID)] = make(acls.ACL)
 		for objId := range networkAcls {
+			if networkAcls[objId] == nil {
+				networkAcls[objId] = make(acls.ACL)
+			}
 			networkAcls[objId][acls.AclID(ec.ClientID)] = acls.Allowed
 			networkAcls[acls.AclID(ec.ClientID)][objId] = acls.Allowed
 		}

+ 3 - 0
logic/dns.go

@@ -226,6 +226,9 @@ func GetGwDNS(node *models.Node) string {
 }
 
 func SetDNSOnWgConfig(gwNode *models.Node, extclient *models.ExtClient) {
+	if extclient.DNS != "" {
+		return
+	}
 	extclient.DNS = GetGwDNS(gwNode)
 }
 

+ 0 - 1
pro/auth/azure-ad.go

@@ -62,7 +62,6 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
 
 func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 	var rState, rCode = getStateAndCode(r)
-
 	state, err := logic.GetState(rState)
 	if err != nil {
 		handleOauthNotValid(w)

+ 51 - 9
pro/idp/azure/azure.go

@@ -149,7 +149,7 @@ func (a *Client) GetGroups(filters []string) ([]idp.Group, error) {
 	}
 
 	client := &http.Client{}
-	getGroupsURL := "https://graph.microsoft.com/v1.0/groups?$select=id,displayName&$expand=members($select=id)"
+	getGroupsURL := "https://graph.microsoft.com/v1.0/groups?$select=id,displayName"
 	if len(filters) > 0 {
 		getGroupsURL += "&" + buildPrefixFilter("displayName", filters)
 	}
@@ -176,16 +176,19 @@ func (a *Client) GetGroups(filters []string) ([]idp.Group, error) {
 			return nil, err
 		}
 
+		// Fetch members for each group separately to handle pagination
 		for _, group := range groups.Value {
-			retvalMembers := make([]string, len(group.Members))
-			for j, member := range group.Members {
-				retvalMembers[j] = member.Id
+			members, err := a.getGroupMembers(accessToken, group.Id)
+			if err != nil {
+				// Continue with empty members list if error occurs
+				// This allows sync to continue for other groups
+				members = []string{}
 			}
 
 			retval = append(retval, idp.Group{
 				ID:      group.Id,
 				Name:    group.DisplayName,
-				Members: retvalMembers,
+				Members: members,
 			})
 		}
 
@@ -195,6 +198,49 @@ func (a *Client) GetGroups(filters []string) ([]idp.Group, error) {
 	return retval, nil
 }
 
+// getGroupMembers fetches all members of a group with pagination support
+func (a *Client) getGroupMembers(accessToken, groupID string) ([]string, error) {
+	client := &http.Client{}
+	getMembersURL := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/members?$select=id", groupID)
+
+	var allMembers []string
+	for getMembersURL != "" {
+		req, err := http.NewRequest("GET", getMembersURL, nil)
+		if err != nil {
+			return nil, err
+		}
+
+		req.Header.Add("Authorization", "Bearer "+accessToken)
+		req.Header.Add("Accept", "application/json")
+
+		resp, err := client.Do(req)
+		if err != nil {
+			return nil, err
+		}
+
+		var membersResponse struct {
+			Value []struct {
+				Id string `json:"id"`
+			} `json:"value"`
+			NextLink string `json:"@odata.nextLink"`
+		}
+
+		err = json.NewDecoder(resp.Body).Decode(&membersResponse)
+		_ = resp.Body.Close()
+		if err != nil {
+			return nil, err
+		}
+
+		for _, member := range membersResponse.Value {
+			allMembers = append(allMembers, member.Id)
+		}
+
+		getMembersURL = membersResponse.NextLink
+	}
+
+	return allMembers, nil
+}
+
 func (a *Client) getAccessToken() (string, error) {
 	tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", a.tenantID)
 
@@ -259,10 +305,6 @@ type getGroupsResponse struct {
 	Value        []struct {
 		Id          string `json:"id"`
 		DisplayName string `json:"displayName"`
-		Members     []struct {
-			OdataType string `json:"@odata.type"`
-			Id        string `json:"id"`
-		} `json:"members"`
 	} `json:"value"`
 	NextLink string `json:"@odata.nextLink"`
 }

+ 15 - 4
pro/logic/ext_acls.go

@@ -48,6 +48,10 @@ func RemoveDeniedNodeFromClient(ec *models.ExtClient, clientOrNodeID string) boo
 
 // SetClientDefaultACLs - set's a client's default ACLs based on network and nodes in network
 func SetClientDefaultACLs(ec *models.ExtClient) error {
+	if !logic.GetServerSettings().OldAClsSupport {
+		ec.DeniedACLs = make(map[string]struct{})
+		return nil
+	}
 	networkNodes, err := logic.GetNetworkNodes(ec.Network)
 	if err != nil {
 		return err
@@ -65,14 +69,18 @@ func SetClientDefaultACLs(ec *models.ExtClient) error {
 	networkAcls[acls.AclID(ec.ClientID)] = make(acls.ACL)
 	for i := range networkNodes {
 		currNode := networkNodes[i]
+		nodeID := acls.AclID(currNode.ID.String())
+		if networkAcls[nodeID] == nil {
+			networkAcls[nodeID] = make(acls.ACL)
+		}
 		if network.DefaultACL == "no" || currNode.DefaultACL == "no" {
 			DenyClientNode(ec, currNode.ID.String())
-			networkAcls[acls.AclID(ec.ClientID)][acls.AclID(currNode.ID.String())] = acls.NotAllowed
-			networkAcls[acls.AclID(currNode.ID.String())][acls.AclID(ec.ClientID)] = acls.NotAllowed
+			networkAcls[acls.AclID(ec.ClientID)][nodeID] = acls.NotAllowed
+			networkAcls[nodeID][acls.AclID(ec.ClientID)] = acls.NotAllowed
 		} else {
 			RemoveDeniedNodeFromClient(ec, currNode.ID.String())
-			networkAcls[acls.AclID(ec.ClientID)][acls.AclID(currNode.ID.String())] = acls.Allowed
-			networkAcls[acls.AclID(currNode.ID.String())][acls.AclID(ec.ClientID)] = acls.Allowed
+			networkAcls[acls.AclID(ec.ClientID)][nodeID] = acls.Allowed
+			networkAcls[nodeID][acls.AclID(ec.ClientID)] = acls.Allowed
 		}
 	}
 	networkClients, err := logic.GetNetworkExtClients(ec.Network)
@@ -82,6 +90,9 @@ func SetClientDefaultACLs(ec *models.ExtClient) error {
 	}
 	for _, client := range networkClients {
 		// TODO: revisit when client-client acls are supported
+		if networkAcls[acls.AclID(client.ClientID)] == nil {
+			networkAcls[acls.AclID(client.ClientID)] = make(acls.ACL)
+		}
 		networkAcls[acls.AclID(ec.ClientID)][acls.AclID(client.ClientID)] = acls.Allowed
 		networkAcls[acls.AclID(client.ClientID)][acls.AclID(ec.ClientID)] = acls.Allowed
 	}