Browse Source

Merge branch 'develop' of https://github.com/gravitl/netmaker into NET-1932-acl-ce

abhishek9686 3 months ago
parent
commit
749849e185
9 changed files with 65 additions and 37 deletions
  1. 1 1
      auth/host_session.go
  2. 8 3
      controllers/enrollmentkeys.go
  3. 1 0
      controllers/user.go
  4. 2 2
      go.mod
  5. 4 4
      go.sum
  6. 41 23
      pro/auth/sync.go
  7. 2 1
      pro/controllers/users.go
  8. 1 1
      pro/initialize.go
  9. 5 2
      pro/logic/user_mgmt.go

+ 1 - 1
auth/host_session.go

@@ -271,7 +271,7 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui
 						slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err)
 					}
 				}
-				if strings.Contains(err.Error(), "host already part of network") {
+				if err != nil && strings.Contains(err.Error(), "host already part of network") {
 					continue
 				}
 			} else {

+ 8 - 3
controllers/enrollmentkeys.go

@@ -402,12 +402,17 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
+	host, err := logic.GetHost(newHost.ID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
 	// ready the response
 	server := logic.GetServerInfo()
 	server.TrafficKey = key
 	response := models.RegisterResponse{
 		ServerConf:    server,
-		RequestedHost: newHost,
+		RequestedHost: *host,
 	}
 	for _, netID := range enrollmentKey.Networks {
 		logic.LogEvent(&models.Event{
@@ -428,9 +433,9 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 		})
 	}
 
-	logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker")
+	logger.Log(0, host.Name, host.ID.String(), "registered with Netmaker")
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(&response)
 	// notify host of changes, peer and node updates
-	go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay, enrollmentKey.Groups)
+	go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, host, enrollmentKey.Relay, enrollmentKey.Groups)
 }

+ 1 - 0
controllers/user.go

@@ -759,6 +759,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 			ID:   user.UserName,
 			Name: user.UserName,
 			Type: models.UserSub,
+			Info: user,
 		},
 		Origin: models.Dashboard,
 	})

+ 2 - 2
go.mod

@@ -22,7 +22,7 @@ require (
 	go.uber.org/automaxprocs v1.6.0
 	golang.org/x/crypto v0.38.0
 	golang.org/x/net v0.39.0 // indirect
-	golang.org/x/oauth2 v0.29.0
+	golang.org/x/oauth2 v0.30.0
 	golang.org/x/sys v0.33.0 // indirect
 	golang.org/x/text v0.25.0 // indirect
 	golang.zx2c4.com/wireguard/wgctrl v0.0.0-20221104135756-97bc4ad4a1cb
@@ -53,7 +53,7 @@ require (
 	gorm.io/datatypes v1.2.5
 	gorm.io/driver/postgres v1.5.11
 	gorm.io/driver/sqlite v1.5.7
-	gorm.io/gorm v1.26.1
+	gorm.io/gorm v1.30.0
 )
 
 require (

+ 4 - 4
go.sum

@@ -161,8 +161,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs
 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
 golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
-golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
-golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
 golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
 golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
@@ -204,5 +204,5 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
 gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g=
 gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g=
 gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
-gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw=
-gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
+gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
+gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=

+ 41 - 23
pro/auth/sync.go

@@ -1,6 +1,7 @@
 package auth
 
 import (
+	"context"
 	"fmt"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
@@ -11,29 +12,46 @@ import (
 	"github.com/gravitl/netmaker/pro/idp/google"
 	proLogic "github.com/gravitl/netmaker/pro/logic"
 	"strings"
+	"sync"
 	"time"
 )
 
-var syncTicker *time.Ticker
+var (
+	cancelSyncHook context.CancelFunc
+	hookStopWg     sync.WaitGroup
+)
 
-func StartSyncHook() {
-	syncTicker = time.NewTicker(logic.GetIDPSyncInterval())
+func ResetIDPSyncHook() {
+	if cancelSyncHook != nil {
+		cancelSyncHook()
+		hookStopWg.Wait()
+		cancelSyncHook = nil
+	}
 
-	for range syncTicker.C {
-		err := SyncFromIDP()
-		if err != nil {
-			logger.Log(0, "failed to sync from idp: ", err.Error())
-		} else {
-			logger.Log(0, "sync from idp complete")
-		}
+	if logic.IsSyncEnabled() {
+		ctx, cancel := context.WithCancel(context.Background())
+		cancelSyncHook = cancel
+		hookStopWg.Add(1)
+		go runIDPSyncHook(ctx)
 	}
 }
 
-func ResetIDPSyncHook() {
-	if syncTicker != nil {
-		syncTicker.Stop()
-		if logic.IsSyncEnabled() {
-			go StartSyncHook()
+func runIDPSyncHook(ctx context.Context) {
+	defer hookStopWg.Done()
+	ticker := time.NewTicker(logic.GetIDPSyncInterval())
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			logger.Log(0, "idp sync hook stopped")
+			return
+		case <-ticker.C:
+			if err := SyncFromIDP(); err != nil {
+				logger.Log(0, "failed to sync from idp: ", err.Error())
+			} else {
+				logger.Log(0, "sync from idp complete")
+			}
 		}
 	}
 }
@@ -218,11 +236,11 @@ func syncGroups(idpGroups []idp.Group) error {
 
 		dbGroup, ok := dbGroupsMap[group.ID]
 		if !ok {
-			err := proLogic.CreateUserGroup(models.UserGroup{
-				ExternalIdentityProviderID: group.ID,
-				Default:                    false,
-				Name:                       group.Name,
-			})
+			dbGroup.ExternalIdentityProviderID = group.ID
+			dbGroup.Name = group.Name
+			dbGroup.Default = false
+			dbGroup.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
+			err := proLogic.CreateUserGroup(&dbGroup)
 			if err != nil {
 				return err
 			}
@@ -241,18 +259,18 @@ func syncGroups(idpGroups []idp.Group) error {
 
 		for _, user := range dbUsers {
 			// use dbGroup.Name because the group name may have been changed on idp.
-			_, inNetmakerGroup := user.UserGroups[models.UserGroupID(dbGroup.Name)]
+			_, inNetmakerGroup := user.UserGroups[dbGroup.ID]
 			_, inIDPGroup := groupMembersMap[user.ExternalIdentityProviderID]
 
 			if inNetmakerGroup && !inIDPGroup {
 				// use dbGroup.Name because the group name may have been changed on idp.
-				delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups, models.UserGroupID(dbGroup.Name))
+				delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups, dbGroup.ID)
 				modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
 			}
 
 			if !inNetmakerGroup && inIDPGroup {
 				// use dbGroup.Name because the group name may have been changed on idp.
-				dbUsersMap[user.ExternalIdentityProviderID].UserGroups[models.UserGroupID(dbGroup.Name)] = struct{}{}
+				dbUsersMap[user.ExternalIdentityProviderID].UserGroups[dbGroup.ID] = struct{}{}
 				modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
 			}
 		}

+ 2 - 1
pro/controllers/users.go

@@ -257,6 +257,7 @@ func inviteUsers(w http.ResponseWriter, r *http.Request) {
 				ID:   callerUserName,
 				Name: callerUserName,
 				Type: models.UserSub,
+				Info: invite,
 			},
 			TriggeredBy: callerUserName,
 			Target: models.Subject{
@@ -458,7 +459,7 @@ func createUserGroup(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	err = proLogic.CreateUserGroup(userGroupReq.Group)
+	err = proLogic.CreateUserGroup(&userGroupReq.Group)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return

+ 1 - 1
pro/initialize.go

@@ -94,7 +94,7 @@ func InitPro() {
 		}
 		proLogic.LoadNodeMetricsToCache()
 		proLogic.InitFailOverCache()
-		go auth.StartSyncHook()
+		auth.ResetIDPSyncHook()
 		email.Init()
 		go proLogic.EventWatcher()
 	})

+ 5 - 2
pro/logic/user_mgmt.go

@@ -531,7 +531,7 @@ func ValidateUpdateGroupReq(g models.UserGroup) error {
 }
 
 // CreateUserGroup - creates new user group
-func CreateUserGroup(g models.UserGroup) error {
+func CreateUserGroup(g *models.UserGroup) error {
 	// default groups are currently created directly in the db.
 	// this check is only to prevent future errors.
 	if g.Default && g.ID == "" {
@@ -1287,7 +1287,10 @@ func AddGlobalNetRolesToAdmins(u *models.User) {
 	if u.PlatformRoleID != models.SuperAdminRole && u.PlatformRoleID != models.AdminRole {
 		return
 	}
-	u.UserGroups = make(map[models.UserGroupID]struct{})
+
+	if len(u.UserGroups) == 0 {
+		u.UserGroups = make(map[models.UserGroupID]struct{})
+	}
 
 	u.UserGroups[globalNetworksAdminGroupID] = struct{}{}
 }