enrollmentkey.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package logic
  2. import (
  3. b64 "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "time"
  8. "github.com/gravitl/netmaker/database"
  9. "github.com/gravitl/netmaker/models"
  10. )
  11. // EnrollmentErrors - struct for holding EnrollmentKey error messages
  12. var EnrollmentErrors = struct {
  13. InvalidCreate error
  14. NoKeyFound error
  15. InvalidKey error
  16. NoUsesRemaining error
  17. FailedToTokenize error
  18. FailedToDeTokenize error
  19. }{
  20. InvalidCreate: fmt.Errorf("invalid enrollment key created"),
  21. NoKeyFound: fmt.Errorf("no enrollmentkey found"),
  22. InvalidKey: fmt.Errorf("invalid key provided"),
  23. NoUsesRemaining: fmt.Errorf("no uses remaining"),
  24. FailedToTokenize: fmt.Errorf("failed to tokenize"),
  25. FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
  26. }
  27. // CreateEnrollmentKey - creates a new enrollment key in db
  28. func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
  29. newKeyID, err := getUniqueEnrollmentID()
  30. if err != nil {
  31. return nil, err
  32. }
  33. k = &models.EnrollmentKey{
  34. Value: newKeyID,
  35. Expiration: time.Time{},
  36. UsesRemaining: 0,
  37. Unlimited: unlimited,
  38. Networks: []string{},
  39. Tags: []string{},
  40. Type: models.Undefined,
  41. }
  42. if uses > 0 {
  43. k.UsesRemaining = uses
  44. k.Type = models.Uses
  45. } else if !expiration.IsZero() {
  46. k.Expiration = expiration
  47. k.Type = models.TimeExpiration
  48. } else if k.Unlimited {
  49. k.Type = models.Unlimited
  50. }
  51. if len(networks) > 0 {
  52. k.Networks = networks
  53. }
  54. if len(tags) > 0 {
  55. k.Tags = tags
  56. }
  57. if ok := k.Validate(); !ok {
  58. return nil, EnrollmentErrors.InvalidCreate
  59. }
  60. if err = upsertEnrollmentKey(k); err != nil {
  61. return nil, err
  62. }
  63. return
  64. }
  65. // GetAllEnrollmentKeys - fetches all enrollment keys from DB
  66. // TODO drop double pointer
  67. func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
  68. currentKeys, err := getEnrollmentKeysMap()
  69. if err != nil {
  70. return nil, err
  71. }
  72. var currentKeysList = []*models.EnrollmentKey{}
  73. for k := range currentKeys {
  74. currentKeysList = append(currentKeysList, currentKeys[k])
  75. }
  76. return currentKeysList, nil
  77. }
  78. // GetEnrollmentKey - fetches a single enrollment key
  79. // returns nil and error if not found
  80. func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  81. currentKeys, err := getEnrollmentKeysMap()
  82. if err != nil {
  83. return nil, err
  84. }
  85. if key, ok := currentKeys[value]; ok {
  86. return key, nil
  87. }
  88. return nil, EnrollmentErrors.NoKeyFound
  89. }
  90. // DeleteEnrollmentKey - delete's a given enrollment key by value
  91. func DeleteEnrollmentKey(value string) error {
  92. _, err := GetEnrollmentKey(value)
  93. if err != nil {
  94. return err
  95. }
  96. return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
  97. }
  98. // TryToUseEnrollmentKey - checks first if key can be decremented
  99. // returns true if it is decremented or isvalid
  100. func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
  101. key, err := decrementEnrollmentKey(k.Value)
  102. if err != nil {
  103. if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
  104. return k.IsValid()
  105. }
  106. } else {
  107. k.UsesRemaining = key.UsesRemaining
  108. return true
  109. }
  110. return false
  111. }
  112. // Tokenize - tokenizes an enrollment key to be used via registration
  113. // and attaches it to the Token field on the struct
  114. func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
  115. if len(serverAddr) == 0 || k == nil {
  116. return EnrollmentErrors.FailedToTokenize
  117. }
  118. newToken := models.EnrollmentToken{
  119. Server: serverAddr,
  120. Value: k.Value,
  121. }
  122. data, err := json.Marshal(&newToken)
  123. if err != nil {
  124. return err
  125. }
  126. k.Token = b64.StdEncoding.EncodeToString(data)
  127. return nil
  128. }
  129. // DeTokenize - detokenizes a base64 encoded string
  130. // and finds the associated enrollment key
  131. func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
  132. if len(b64Token) == 0 {
  133. return nil, EnrollmentErrors.FailedToDeTokenize
  134. }
  135. tokenData, err := b64.StdEncoding.DecodeString(b64Token)
  136. if err != nil {
  137. return nil, err
  138. }
  139. var newToken models.EnrollmentToken
  140. err = json.Unmarshal(tokenData, &newToken)
  141. if err != nil {
  142. return nil, err
  143. }
  144. k, err := GetEnrollmentKey(newToken.Value)
  145. if err != nil {
  146. return nil, err
  147. }
  148. return k, nil
  149. }
  150. // == private ==
  151. // decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
  152. func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  153. k, err := GetEnrollmentKey(value)
  154. if err != nil {
  155. return nil, err
  156. }
  157. if k.UsesRemaining == 0 {
  158. return nil, EnrollmentErrors.NoUsesRemaining
  159. }
  160. k.UsesRemaining = k.UsesRemaining - 1
  161. if err = upsertEnrollmentKey(k); err != nil {
  162. return nil, err
  163. }
  164. return k, nil
  165. }
  166. func upsertEnrollmentKey(k *models.EnrollmentKey) error {
  167. if k == nil {
  168. return EnrollmentErrors.InvalidKey
  169. }
  170. data, err := json.Marshal(k)
  171. if err != nil {
  172. return err
  173. }
  174. return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
  175. }
  176. func getUniqueEnrollmentID() (string, error) {
  177. currentKeys, err := getEnrollmentKeysMap()
  178. if err != nil {
  179. return "", err
  180. }
  181. newID := RandomString(models.EnrollmentKeyLength)
  182. for _, ok := currentKeys[newID]; ok; {
  183. newID = RandomString(models.EnrollmentKeyLength)
  184. }
  185. return newID, nil
  186. }
  187. func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
  188. records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
  189. if err != nil {
  190. if !database.IsEmptyRecord(err) {
  191. return nil, err
  192. }
  193. }
  194. if records == nil {
  195. records = make(map[string]string)
  196. }
  197. currentKeys := make(map[string]*models.EnrollmentKey, 0)
  198. if len(records) > 0 {
  199. for k := range records {
  200. var currentKey models.EnrollmentKey
  201. if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
  202. continue
  203. }
  204. currentKeys[k] = &currentKey
  205. }
  206. }
  207. return currentKeys, nil
  208. }