浏览代码

redefine access token api routes, add auto egress option to enrollment keys

abhishek9686 5 月之前
父节点
当前提交
84ba04b6e1
共有 9 个文件被更改,包括 58 次插入38 次删除
  1. 1 0
      controllers/enrollmentkeys.go
  2. 23 15
      controllers/user.go
  3. 14 4
      logic/auth.go
  4. 2 1
      logic/enrollmentkey.go
  5. 13 13
      logic/enrollmentkey_test.go
  6. 1 0
      logic/networks.go
  7. 1 0
      migrate/migrate.go
  8. 1 5
      models/accessToken.go
  9. 2 0
      models/enrollment_key.go

+ 1 - 0
controllers/enrollmentkeys.go

@@ -160,6 +160,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 		enrollmentKeyBody.Unlimited,
 		relayId,
 		false,
+		enrollmentKeyBody.AutoEgress,
 	)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())

+ 23 - 15
controllers/user.go

@@ -37,9 +37,9 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet)
-	r.HandleFunc("/api/v1/users/{username}/access_token", logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken))).Methods(http.MethodPost)
-	r.HandleFunc("/api/v1/users/{username}/access_token", logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens))).Methods(http.MethodGet)
-	r.HandleFunc("/api/v1/users/{username}/access_token", logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens))).Methods(http.MethodDelete)
+	r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken))).Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens))).Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens))).Methods(http.MethodDelete)
 }
 
 // @Summary     Authenticate a user to retrieve an authorization token
@@ -64,6 +64,14 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+	if req.Name == "" {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("name is required"), "badrequest"))
+		return
+	}
+	if req.UserName == "" {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
+		return
+	}
 
 	user, err := logic.GetUser(req.UserName)
 	if err != nil {
@@ -110,8 +118,11 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 // @Failure     401 {object} models.ErrorResponse
 // @Failure     500 {object} models.ErrorResponse
 func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
-	var params = mux.Vars(r)
-	username := params["username"]
+	username := r.URL.Query().Get("username")
+	if username == "" {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
+		return
+	}
 	_, err := logic.GetUser(username)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
@@ -131,25 +142,22 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 // @Failure     401 {object} models.ErrorResponse
 // @Failure     500 {object} models.ErrorResponse
 func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
-	var req models.AccessToken
-
-	err := json.NewDecoder(r.Body).Decode(&req)
-	if err != nil {
-		logger.Log(0, "error decoding request body: ",
-			err.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+	id := r.URL.Query().Get("id")
+	if id == "" {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
 		return
 	}
-	err = logic.RevokeAccessToken(req)
+
+	err := logic.RevokeAccessToken(models.AccessToken{ID: id})
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
 			r,
-			logic.FormatError(errors.New("error creating access token "+err.Error()), "internal"),
+			logic.FormatError(errors.New("error deleting access token "+err.Error()), "internal"),
 		)
 		return
 	}
-	logic.ReturnSuccessResponseWithJson(w, r, nil, "revoked access token "+req.Name)
+	logic.ReturnSuccessResponseWithJson(w, r, nil, "revoked access token")
 }
 
 // @Summary     Authenticate a user to retrieve an authorization token

+ 14 - 4
logic/auth.go

@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/go-playground/validator/v10"
+	"github.com/google/uuid"
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/exp/slog"
 
@@ -139,7 +140,7 @@ func FetchPassValue(newValue string) (string, error) {
 }
 
 func RevokeAccessToken(a models.AccessToken) error {
-	err := database.DeleteRecord(database.ACCESS_TOKENS_TABLE_NAME, a.GetDBKey())
+	err := database.DeleteRecord(database.ACCESS_TOKENS_TABLE_NAME, a.ID)
 	if err != nil {
 		return err
 	}
@@ -164,8 +165,17 @@ func RevokeAllUserTokens(username string) {
 	}
 }
 
+func GetAccessToken(k string) (a models.AccessToken, err error) {
+	value, err := database.FetchRecord(k, database.ACCESS_TOKENS_TABLE_NAME)
+	if err != nil {
+		return
+	}
+	err = json.Unmarshal([]byte(value), &a)
+	return
+}
+
 func ListAccessTokens(username string) (tokens []models.AccessToken) {
-	collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
+	collection, err := database.FetchRecords(database.ACCESS_TOKENS_TABLE_NAME)
 	if err != nil {
 		return
 	}
@@ -186,13 +196,13 @@ func ListAccessTokens(username string) (tokens []models.AccessToken) {
 }
 
 func CreateAccessToken(a models.AccessToken) error {
-	// connect db
+	a.ID = uuid.New().String()
 	data, err := json.Marshal(a)
 	if err != nil {
 		logger.Log(0, "failed to marshal", err.Error())
 		return err
 	}
-	err = database.Insert(a.GetDBKey(), string(data), database.ACCESS_TOKENS_TABLE_NAME)
+	err = database.Insert(a.ID, string(data), database.ACCESS_TOKENS_TABLE_NAME)
 	if err != nil {
 		logger.Log(0, "failed to insert user", err.Error())
 		return err

+ 2 - 1
logic/enrollmentkey.go

@@ -38,7 +38,7 @@ var (
 )
 
 // CreateEnrollmentKey - creates a new enrollment key in db
-func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) {
+func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress bool) (*models.EnrollmentKey, error) {
 	newKeyID, err := getUniqueEnrollmentID()
 	if err != nil {
 		return nil, err
@@ -54,6 +54,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 		Relay:         relay,
 		Groups:        groups,
 		Default:       defaultKey,
+		AutoEgress:    autoEgress,
 	}
 	if uses > 0 {
 		k.UsesRemaining = uses

+ 13 - 13
logic/enrollmentkey_test.go

@@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	t.Run("Can_Not_Create_Key", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
 		assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
 		assert.Nil(t, err)
 		assert.Equal(t, 1, newKey.UsesRemaining)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Time", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Networks) == 2)
 	})
 	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Tags) == 2)
@@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
 func TestDelete_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
 	t.Run("Can_Delete_Key", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		err := DeleteEnrollmentKey(newKey.Value, false)
@@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 func TestDecrement_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
 	t.Run("Check_initial_uses", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		assert.Equal(t, newKey.UsesRemaining, 1)
@@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
 func TestUsability_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false)
-	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false)
-	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false)
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
 	t.Run("Check if valid use key can be used", func(t *testing.T) {
 		assert.Equal(t, key1.UsesRemaining, 1)
 		ok := TryToUseEnrollmentKey(key1)
@@ -145,7 +145,7 @@ func removeAllEnrollments() {
 func TestTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
 	const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
 	const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
@@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
 func TestDeTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
 	const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
 

+ 1 - 0
logic/networks.go

@@ -302,6 +302,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 		true,
 		uuid.Nil,
 		true,
+		false,
 	)
 
 	return network, nil

+ 1 - 0
migrate/migrate.go

@@ -151,6 +151,7 @@ func updateEnrollmentKeys() {
 			true,
 			uuid.Nil,
 			true,
+			false,
 		)
 
 	}

+ 1 - 5
models/accessToken.go

@@ -1,12 +1,12 @@
 package models
 
 import (
-	"fmt"
 	"time"
 )
 
 // AccessToken - token used to access netmaker
 type AccessToken struct {
+	ID        string    `json:"id"`
 	Name      string    `json:"name"`
 	UserName  string    `json:"user_name"`
 	ExpiresAt time.Time `json:"expires_at"`
@@ -14,7 +14,3 @@ type AccessToken struct {
 	CreatedBy time.Time `json:"created_by"`
 	CreatedAt time.Time `json:"created_at"`
 }
-
-func (a AccessToken) GetDBKey() string {
-	return fmt.Sprintf("%s#%s", a.UserName, a.Name)
-}

+ 2 - 0
models/enrollment_key.go

@@ -54,6 +54,7 @@ type EnrollmentKey struct {
 	Relay         uuid.UUID `json:"relay"`
 	Groups        []TagID   `json:"groups"`
 	Default       bool      `json:"default"`
+	AutoEgress    bool      `json:"auto_egress"`
 }
 
 // APIEnrollmentKey - used to create enrollment keys via API
@@ -66,6 +67,7 @@ type APIEnrollmentKey struct {
 	Type          KeyType  `json:"type"`
 	Relay         string   `json:"relay"`
 	Groups        []TagID  `json:"groups"`
+	AutoEgress    bool     `json:"auto_egress"`
 }
 
 // RegisterResponse - the response to a successful enrollment register