ソースを参照

merge from develop

abhishek9686 3 ヶ月 前
コミット
b36cdd7ade

+ 1 - 1
logic/user_mgmt.go

@@ -50,7 +50,7 @@ var MigrateUserRoleAndGroups = func(u models.User) {
 
 }
 
-var MigrateGroups = func() {}
+var MigrateToUUIDs = func() {}
 
 var UpdateUserGwAccess = func(currentUser, changeUser models.User) {}
 

+ 1 - 0
logic/users.go

@@ -199,6 +199,7 @@ func ListUserInvites() ([]models.UserInvite, error) {
 func DeleteUserInvite(email string) error {
 	return database.DeleteRecord(database.USER_INVITES_TABLE_NAME, email)
 }
+
 func ValidateAndApproveUserInvite(email, code string) error {
 	in, err := GetUserInvite(email)
 	if err != nil {

+ 3 - 3
migrate/migrate.go

@@ -29,7 +29,7 @@ func Run() {
 	assignSuperAdmin()
 	createDefaultTagsAndPolicies()
 	removeOldUserGrps()
-	syncGroups()
+	migrateToUUIDs()
 	syncUsers()
 	updateHosts()
 	updateNodes()
@@ -401,8 +401,8 @@ func MigrateEmqx() {
 
 }
 
-func syncGroups() {
-	logic.MigrateGroups()
+func migrateToUUIDs() {
+	logic.MigrateToUUIDs()
 }
 
 func syncUsers() {

+ 34 - 6
pro/auth/auth.go

@@ -1,9 +1,11 @@
 package auth
 
 import (
+	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
+	"strconv"
 	"strings"
 	"time"
 
@@ -34,12 +36,38 @@ const (
 
 // OAuthUser - generic OAuth strategy user
 type OAuthUser struct {
-	ID                string `json:"id" bson:"id"`
-	Name              string `json:"name" bson:"name"`
-	Email             string `json:"email" bson:"email"`
-	Login             string `json:"login" bson:"login"`
-	UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
-	AccessToken       string `json:"accesstoken" bson:"accesstoken"`
+	ID                StringOrInt `json:"id" bson:"id"`
+	Name              string      `json:"name" bson:"name"`
+	Email             string      `json:"email" bson:"email"`
+	Login             string      `json:"login" bson:"login"`
+	UserPrincipalName string      `json:"userPrincipalName" bson:"userPrincipalName"`
+	AccessToken       string      `json:"accesstoken" bson:"accesstoken"`
+}
+
+// TODO: this is a very poor solution.
+// We should not return the same OAuthUser for different
+// IdPs. They should have the user that their APIs return.
+// But that's a very big change. So, making do with this
+// for now.
+
+type StringOrInt string
+
+func (s *StringOrInt) UnmarshalJSON(data []byte) error {
+	// Try to unmarshal as string directly
+	var strVal string
+	if err := json.Unmarshal(data, &strVal); err == nil {
+		*s = StringOrInt(strVal)
+		return nil
+	}
+
+	// Try to unmarshal as int and convert to string
+	var intVal int
+	if err := json.Unmarshal(data, &intVal); err == nil {
+		*s = StringOrInt(strconv.Itoa(intVal))
+		return nil
+	}
+
+	return fmt.Errorf("cannot unmarshal %s into StringOrInt", string(data))
 }
 
 var (

+ 2 - 3
pro/auth/azure-ad.go

@@ -111,7 +111,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 					logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 					return
 				}
-				user.ExternalIdentityProviderID = content.ID
+				user.ExternalIdentityProviderID = string(content.ID)
 				if err = logic.CreateUser(&user); err != nil {
 					handleSomethingWentWrong(w)
 					return
@@ -125,7 +125,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 				}
 				err = logic.InsertPendingUser(&models.User{
 					UserName:                   content.Email,
-					ExternalIdentityProviderID: content.ID,
+					ExternalIdentityProviderID: string(content.ID),
 					AuthType:                   models.OAuth,
 				})
 				if err != nil {
@@ -243,7 +243,6 @@ func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
 	}
 	if userInfo.Email == "" && userInfo.UserPrincipalName != "" {
 		userInfo.Email = userInfo.UserPrincipalName
-
 	}
 	if userInfo.Email == "" {
 		err = errors.New("failed to fetch user email from SSO state")

+ 2 - 2
pro/auth/github.go

@@ -111,7 +111,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 					logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 					return
 				}
-				user.ExternalIdentityProviderID = content.ID
+				user.ExternalIdentityProviderID = string(content.ID)
 				if err = logic.CreateUser(&user); err != nil {
 					handleSomethingWentWrong(w)
 					return
@@ -125,7 +125,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 				}
 				err = logic.InsertPendingUser(&models.User{
 					UserName:                   content.Email,
-					ExternalIdentityProviderID: content.ID,
+					ExternalIdentityProviderID: string(content.ID),
 					AuthType:                   models.OAuth,
 				})
 				if err != nil {

+ 1 - 1
pro/auth/google.go

@@ -106,7 +106,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 				}
 				err = logic.InsertPendingUser(&models.User{
 					UserName:                   content.Email,
-					ExternalIdentityProviderID: content.ID,
+					ExternalIdentityProviderID: string(content.ID),
 					AuthType:                   models.OAuth,
 				})
 				if err != nil {

+ 1 - 1
pro/auth/headless_callback.go

@@ -65,7 +65,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
 		if database.IsEmptyRecord(err) { // user must not exist, so try to make one
 			err = logic.InsertPendingUser(&models.User{
 				UserName:                   userClaims.getUserName(),
-				ExternalIdentityProviderID: userClaims.ID,
+				ExternalIdentityProviderID: string(userClaims.ID),
 				AuthType:                   models.OAuth,
 			})
 			if err != nil {

+ 3 - 3
pro/auth/oidc.go

@@ -102,7 +102,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 					logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 					return
 				}
-				user.ExternalIdentityProviderID = content.ID
+				user.ExternalIdentityProviderID = string(content.ID)
 				if err = logic.CreateUser(&user); err != nil {
 					handleSomethingWentWrong(w)
 					return
@@ -116,7 +116,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 				}
 				err = logic.InsertPendingUser(&models.User{
 					UserName:                   content.Email,
-					ExternalIdentityProviderID: content.ID,
+					ExternalIdentityProviderID: string(content.ID),
 					AuthType:                   models.OAuth,
 				})
 				if err != nil {
@@ -232,7 +232,7 @@ func getOIDCUserInfo(state string, code string) (u *OAuthUser, e error) {
 		e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error())
 	}
 
-	u.ID = idToken.Subject
+	u.ID = StringOrInt(idToken.Subject)
 
 	return
 }

+ 1 - 1
pro/initialize.go

@@ -131,7 +131,7 @@ func InitPro() {
 	logic.UpdateUserGwAccess = proLogic.UpdateUserGwAccess
 	logic.CreateDefaultUserPolicies = proLogic.CreateDefaultUserPolicies
 	logic.MigrateUserRoleAndGroups = proLogic.MigrateUserRoleAndGroups
-	logic.MigrateGroups = proLogic.MigrateGroups
+	logic.MigrateToUUIDs = proLogic.MigrateToUUIDs
 	logic.IntialiseGroups = proLogic.UserGroupsInit
 	logic.AddGlobalNetRolesToAdmins = proLogic.AddGlobalNetRolesToAdmins
 	logic.GetUserGroupsInNetwork = proLogic.GetUserGroupsInNetwork

+ 137 - 5
pro/logic/migrate.go

@@ -11,13 +11,51 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-func MigrateGroups() {
+func MigrateToUUIDs() {
+	roles, err := ListNetworkRoles()
+	if err != nil {
+		return
+	}
+
+	rolesMapping := make(map[models.UserRoleID]models.UserRoleID)
+
+	for _, role := range roles {
+		if role.Default {
+			continue
+		}
+
+		_, err := uuid.Parse(string(role.ID))
+		if err == nil {
+			// role id is already an uuid, so no need to update
+			continue
+		}
+
+		oldRoleID := role.ID
+		role.ID = models.UserRoleID(uuid.NewString())
+		rolesMapping[oldRoleID] = role.ID
+
+		roleBytes, err := json.Marshal(role)
+		if err != nil {
+			continue
+		}
+
+		err = database.Insert(role.ID.String(), string(roleBytes), database.USER_PERMISSIONS_TABLE_NAME)
+		if err != nil {
+			continue
+		}
+
+		err = database.DeleteRecord(database.USER_PERMISSIONS_TABLE_NAME, oldRoleID.String())
+		if err != nil {
+			continue
+		}
+	}
+
 	groups, err := ListUserGroups()
 	if err != nil {
 		return
 	}
 
-	groupMapping := make(map[models.UserGroupID]models.UserGroupID)
+	groupsMapping := make(map[models.UserGroupID]models.UserGroupID)
 
 	for _, group := range groups {
 		if group.Default {
@@ -32,7 +70,22 @@ func MigrateGroups() {
 
 		oldGroupID := group.ID
 		group.ID = models.UserGroupID(uuid.NewString())
-		groupMapping[oldGroupID] = group.ID
+		groupsMapping[oldGroupID] = group.ID
+
+		var groupPermissions = make(map[models.NetworkID]map[models.UserRoleID]struct{})
+		for networkID, networkRoles := range group.NetworkRoles {
+			groupPermissions[networkID] = make(map[models.UserRoleID]struct{})
+			for roleID := range networkRoles {
+				newRoleID, ok := rolesMapping[roleID]
+				if !ok {
+					groupPermissions[networkID][roleID] = struct{}{}
+				} else {
+					groupPermissions[networkID][newRoleID] = struct{}{}
+				}
+			}
+		}
+
+		group.NetworkRoles = groupPermissions
 
 		groupBytes, err := json.Marshal(group)
 		if err != nil {
@@ -50,6 +103,11 @@ func MigrateGroups() {
 		}
 	}
 
+	// if no changes were made, there are no references to be updated.
+	if len(rolesMapping) == 0 && len(groupsMapping) == 0 {
+		return
+	}
+
 	users, err := logic.GetUsersDB()
 	if err != nil {
 		return
@@ -58,7 +116,7 @@ func MigrateGroups() {
 	for _, user := range users {
 		userGroups := make(map[models.UserGroupID]struct{})
 		for groupID := range user.UserGroups {
-			newGroupID, ok := groupMapping[groupID]
+			newGroupID, ok := groupsMapping[groupID]
 			if !ok {
 				userGroups[groupID] = struct{}{}
 			} else {
@@ -67,7 +125,81 @@ func MigrateGroups() {
 		}
 
 		user.UserGroups = userGroups
-		logic.UpsertUser(user)
+		err = logic.UpsertUser(user)
+		if err != nil {
+			continue
+		}
+	}
+
+	for _, acl := range logic.ListAcls() {
+		srcList := make([]models.AclPolicyTag, len(acl.Src))
+		for i, src := range acl.Src {
+			if src.ID == models.UserGroupAclID {
+				newGroupID, ok := groupsMapping[models.UserGroupID(src.Value)]
+				if ok {
+					src.Value = newGroupID.String()
+				}
+			}
+
+			srcList[i] = src
+		}
+
+		dstList := make([]models.AclPolicyTag, len(acl.Dst))
+		for i, dst := range acl.Dst {
+			if dst.ID == models.UserGroupAclID {
+				newGroupID, ok := groupsMapping[models.UserGroupID(dst.Value)]
+				if ok {
+					dst.Value = newGroupID.String()
+				}
+			}
+
+			dstList[i] = dst
+		}
+
+		err = logic.UpsertAcl(acl)
+		if err != nil {
+			continue
+		}
+	}
+
+	invites, err := logic.ListUserInvites()
+	if err != nil {
+		return
+	}
+
+	for _, invite := range invites {
+		userGroups := make(map[models.UserGroupID]struct{})
+		for groupID := range invite.UserGroups {
+			newGroupID, ok := groupsMapping[groupID]
+			if !ok {
+				invite.UserGroups[groupID] = struct{}{}
+			} else {
+				invite.UserGroups[newGroupID] = struct{}{}
+			}
+		}
+
+		invite.UserGroups = userGroups
+
+		userPermissions := make(map[models.NetworkID]map[models.UserRoleID]struct{})
+
+		for networkID, networkRoles := range invite.NetworkRoles {
+			userPermissions[networkID] = make(map[models.UserRoleID]struct{})
+			for roleID := range networkRoles {
+				newRoleID, ok := rolesMapping[roleID]
+				if !ok {
+					userPermissions[networkID][roleID] = struct{}{}
+				} else {
+					userPermissions[networkID][newRoleID] = struct{}{}
+				}
+			}
+		}
+
+		invite.NetworkRoles = userPermissions
+
+		err = logic.InsertUserInvite(invite)
+		if err != nil {
+			continue
+		}
 	}
 }
 

+ 42 - 14
pro/logic/user_mgmt.go

@@ -19,6 +19,8 @@ import (
 var (
 	globalNetworksAdminGroupID = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin))
 	globalNetworksUserGroupID  = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkUser))
+	globalNetworksAdminRoleID  = models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin))
+	globalNetworksUserRoleID   = models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser))
 )
 
 var ServiceUserPermissionTemplate = models.UserRolePermissionTemplate{
@@ -35,7 +37,7 @@ var PlatformUserUserPermissionTemplate = models.UserRolePermissionTemplate{
 }
 
 var NetworkAdminAllPermissionTemplate = models.UserRolePermissionTemplate{
-	ID:         models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin)),
+	ID:         globalNetworksAdminRoleID,
 	Name:       "Network Admins",
 	MetaData:   "can manage configuration of all networks",
 	Default:    true,
@@ -44,7 +46,7 @@ var NetworkAdminAllPermissionTemplate = models.UserRolePermissionTemplate{
 }
 
 var NetworkUserAllPermissionTemplate = models.UserRolePermissionTemplate{
-	ID:         models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)),
+	ID:         globalNetworksUserRoleID,
 	Name:       "Network Users",
 	MetaData:   "Can connect to nodes in your networks via Netmaker Desktop App.",
 	Default:    true,
@@ -124,7 +126,7 @@ func UserGroupsInit() {
 		MetaData: "can manage configuration of all networks",
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			models.AllNetworks: {
-				models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin)): {},
+				globalNetworksAdminRoleID: {},
 			},
 		},
 	}
@@ -134,7 +136,7 @@ func UserGroupsInit() {
 		Default: true,
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			models.AllNetworks: {
-				models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)): {},
+				globalNetworksUserRoleID: {},
 			},
 		},
 		MetaData: "Provides read-only dashboard access to platform users and allows connection to network nodes via the Netmaker Desktop App.",
@@ -150,7 +152,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 		return
 	}
 	var NetworkAdminPermissionTemplate = models.UserRolePermissionTemplate{
-		ID:                 models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkAdmin)),
+		ID:                 GetDefaultNetworkAdminRoleID(netID),
 		Name:               fmt.Sprintf("%s Admin", netID),
 		MetaData:           fmt.Sprintf("can manage your network `%s` configuration.", netID),
 		Default:            true,
@@ -160,7 +162,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 	}
 
 	var NetworkUserPermissionTemplate = models.UserRolePermissionTemplate{
-		ID:                  models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkUser)),
+		ID:                  GetDefaultNetworkUserRoleID(netID),
 		Name:                fmt.Sprintf("%s User", netID),
 		MetaData:            fmt.Sprintf("Can connect to nodes in your network `%s` via Netmaker Desktop App.", netID),
 		Default:             true,
@@ -227,7 +229,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 		Default: true,
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			netID: {
-				models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkAdmin)): {},
+				GetDefaultNetworkAdminRoleID(netID): {},
 			},
 		},
 		MetaData: fmt.Sprintf("can manage your network `%s` configuration including adding and removing devices.", netID),
@@ -238,7 +240,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 		Default: true,
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			netID: {
-				models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkUser)): {},
+				GetDefaultNetworkUserRoleID(netID): {},
 			},
 		},
 		MetaData: fmt.Sprintf("Can connect to nodes in your network `%s` via Netmaker Desktop App. Platform users will have read-only access to the the dashboard.", netID),
@@ -403,14 +405,32 @@ func ValidateUpdateRoleReq(userRole *models.UserRolePermissionTemplate) error {
 
 // CreateRole - inserts new role into DB
 func CreateRole(r models.UserRolePermissionTemplate) error {
-	// check if role already exists
-	if r.ID.String() == "" {
-		return errors.New("role id cannot be empty")
+	// default roles are currently created directly in the db.
+	// this check is only to prevent future errors.
+	if r.Default && r.ID == "" {
+		return errors.New("role id cannot be empty for default role")
 	}
-	_, err := database.FetchRecord(database.USER_PERMISSIONS_TABLE_NAME, r.ID.String())
-	if err == nil {
-		return errors.New("role already exists")
+
+	if !r.Default {
+		r.ID = models.UserRoleID(uuid.NewString())
 	}
+
+	// check if the role already exists
+	if r.Name == "" {
+		return errors.New("role name cannot be empty")
+	}
+
+	roles, err := ListNetworkRoles()
+	if err != nil {
+		return err
+	}
+
+	for _, role := range roles {
+		if role.Name == r.Name {
+			return errors.New("role already exists")
+		}
+	}
+
 	d, err := json.Marshal(r)
 	if err != nil {
 		return err
@@ -586,6 +606,14 @@ func GetDefaultNetworkUserGroupID(networkID models.NetworkID) models.UserGroupID
 	return models.UserGroupID(fmt.Sprintf("%s-%s-grp", networkID, models.NetworkUser))
 }
 
+func GetDefaultNetworkAdminRoleID(networkID models.NetworkID) models.UserRoleID {
+	return models.UserRoleID(fmt.Sprintf("%s-%s", networkID, models.NetworkAdmin))
+}
+
+func GetDefaultNetworkUserRoleID(networkID models.NetworkID) models.UserRoleID {
+	return models.UserRoleID(fmt.Sprintf("%s-%s", networkID, models.NetworkUser))
+}
+
 // ListUserGroups - lists user groups
 func ListUserGroups() ([]models.UserGroup, error) {
 	data, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)