enrollmentkey.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package logic
  2. import (
  3. b64 "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "golang.org/x/exp/slices"
  8. "time"
  9. "github.com/gravitl/netmaker/database"
  10. "github.com/gravitl/netmaker/models"
  11. )
  12. // EnrollmentErrors - struct for holding EnrollmentKey error messages
  13. var EnrollmentErrors = struct {
  14. InvalidCreate error
  15. NoKeyFound error
  16. InvalidKey error
  17. NoUsesRemaining error
  18. FailedToTokenize error
  19. FailedToDeTokenize error
  20. }{
  21. InvalidCreate: fmt.Errorf("invalid enrollment key created"),
  22. NoKeyFound: fmt.Errorf("no enrollmentkey found"),
  23. InvalidKey: fmt.Errorf("invalid key provided"),
  24. NoUsesRemaining: fmt.Errorf("no uses remaining"),
  25. FailedToTokenize: fmt.Errorf("failed to tokenize"),
  26. FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
  27. }
  28. // CreateEnrollmentKey - creates a new enrollment key in db
  29. func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
  30. newKeyID, err := getUniqueEnrollmentID()
  31. if err != nil {
  32. return nil, err
  33. }
  34. k = &models.EnrollmentKey{
  35. Value: newKeyID,
  36. Expiration: time.Time{},
  37. UsesRemaining: 0,
  38. Unlimited: unlimited,
  39. Networks: []string{},
  40. Tags: []string{},
  41. Type: models.Undefined,
  42. }
  43. if uses > 0 {
  44. k.UsesRemaining = uses
  45. k.Type = models.Uses
  46. } else if !expiration.IsZero() {
  47. k.Expiration = expiration
  48. k.Type = models.TimeExpiration
  49. } else if k.Unlimited {
  50. k.Type = models.Unlimited
  51. }
  52. if len(networks) > 0 {
  53. k.Networks = networks
  54. }
  55. if len(tags) > 0 {
  56. k.Tags = tags
  57. }
  58. if ok := k.Validate(); !ok {
  59. return nil, EnrollmentErrors.InvalidCreate
  60. }
  61. if err = upsertEnrollmentKey(k); err != nil {
  62. return nil, err
  63. }
  64. return
  65. }
  66. // GetAllEnrollmentKeys - fetches all enrollment keys from DB
  67. // TODO drop double pointer
  68. func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
  69. currentKeys, err := getEnrollmentKeysMap()
  70. if err != nil {
  71. return nil, err
  72. }
  73. var currentKeysList = []*models.EnrollmentKey{}
  74. for k := range currentKeys {
  75. currentKeysList = append(currentKeysList, currentKeys[k])
  76. }
  77. return currentKeysList, nil
  78. }
  79. // GetEnrollmentKey - fetches a single enrollment key
  80. // returns nil and error if not found
  81. func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  82. currentKeys, err := getEnrollmentKeysMap()
  83. if err != nil {
  84. return nil, err
  85. }
  86. if key, ok := currentKeys[value]; ok {
  87. return key, nil
  88. }
  89. return nil, EnrollmentErrors.NoKeyFound
  90. }
  91. // DeleteEnrollmentKey - delete's a given enrollment key by value
  92. func DeleteEnrollmentKey(value string) error {
  93. _, err := GetEnrollmentKey(value)
  94. if err != nil {
  95. return err
  96. }
  97. return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
  98. }
  99. // TryToUseEnrollmentKey - checks first if key can be decremented
  100. // returns true if it is decremented or isvalid
  101. func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
  102. key, err := decrementEnrollmentKey(k.Value)
  103. if err != nil {
  104. if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
  105. return k.IsValid()
  106. }
  107. } else {
  108. k.UsesRemaining = key.UsesRemaining
  109. return true
  110. }
  111. return false
  112. }
  113. // Tokenize - tokenizes an enrollment key to be used via registration
  114. // and attaches it to the Token field on the struct
  115. func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
  116. if len(serverAddr) == 0 || k == nil {
  117. return EnrollmentErrors.FailedToTokenize
  118. }
  119. newToken := models.EnrollmentToken{
  120. Server: serverAddr,
  121. Value: k.Value,
  122. }
  123. data, err := json.Marshal(&newToken)
  124. if err != nil {
  125. return err
  126. }
  127. k.Token = b64.StdEncoding.EncodeToString(data)
  128. return nil
  129. }
  130. // DeTokenize - detokenizes a base64 encoded string
  131. // and finds the associated enrollment key
  132. func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
  133. if len(b64Token) == 0 {
  134. return nil, EnrollmentErrors.FailedToDeTokenize
  135. }
  136. tokenData, err := b64.StdEncoding.DecodeString(b64Token)
  137. if err != nil {
  138. return nil, err
  139. }
  140. var newToken models.EnrollmentToken
  141. err = json.Unmarshal(tokenData, &newToken)
  142. if err != nil {
  143. return nil, err
  144. }
  145. k, err := GetEnrollmentKey(newToken.Value)
  146. if err != nil {
  147. return nil, err
  148. }
  149. return k, nil
  150. }
  151. // == private ==
  152. // decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
  153. func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  154. k, err := GetEnrollmentKey(value)
  155. if err != nil {
  156. return nil, err
  157. }
  158. if k.UsesRemaining == 0 {
  159. return nil, EnrollmentErrors.NoUsesRemaining
  160. }
  161. k.UsesRemaining = k.UsesRemaining - 1
  162. if err = upsertEnrollmentKey(k); err != nil {
  163. return nil, err
  164. }
  165. return k, nil
  166. }
  167. func upsertEnrollmentKey(k *models.EnrollmentKey) error {
  168. if k == nil {
  169. return EnrollmentErrors.InvalidKey
  170. }
  171. data, err := json.Marshal(k)
  172. if err != nil {
  173. return err
  174. }
  175. return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
  176. }
  177. func getUniqueEnrollmentID() (string, error) {
  178. currentKeys, err := getEnrollmentKeysMap()
  179. if err != nil {
  180. return "", err
  181. }
  182. newID := RandomString(models.EnrollmentKeyLength)
  183. for _, ok := currentKeys[newID]; ok; {
  184. newID = RandomString(models.EnrollmentKeyLength)
  185. }
  186. return newID, nil
  187. }
  188. func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
  189. records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
  190. if err != nil {
  191. if !database.IsEmptyRecord(err) {
  192. return nil, err
  193. }
  194. }
  195. if records == nil {
  196. records = make(map[string]string)
  197. }
  198. currentKeys := make(map[string]*models.EnrollmentKey, 0)
  199. if len(records) > 0 {
  200. for k := range records {
  201. var currentKey models.EnrollmentKey
  202. if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
  203. continue
  204. }
  205. currentKeys[k] = &currentKey
  206. }
  207. }
  208. return currentKeys, nil
  209. }
  210. // UserHasNetworksAccess - checks if a user `u` has access to all `networks`
  211. func UserHasNetworksAccess(networks []string, u *models.User) bool {
  212. if u.IsAdmin {
  213. return true
  214. }
  215. for _, n := range networks {
  216. if !slices.Contains(u.Networks, n) {
  217. return false
  218. }
  219. }
  220. return true
  221. }