enrollmentkey.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. package logic
  2. import (
  3. b64 "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/gravitl/netmaker/database"
  10. "github.com/gravitl/netmaker/models"
  11. "golang.org/x/exp/slices"
  12. )
  // EnrollmentErrors groups the sentinel errors returned by the enrollment key
// helpers in this package. Each field is a distinct error value, so callers
// can compare against it with errors.Is.
var EnrollmentErrors = struct {
	InvalidCreate      error
	NoKeyFound         error
	InvalidKey         error
	NoUsesRemaining    error
	FailedToTokenize   error
	FailedToDeTokenize error
}{
	InvalidCreate:      fmt.Errorf("failed to create enrollment key. parameters invalid"),
	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),
	InvalidKey:         fmt.Errorf("invalid key provided"),
	NoUsesRemaining:    fmt.Errorf("no uses remaining"),
	FailedToTokenize:   fmt.Errorf("failed to tokenize"),
	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
}
  29. // CreateEnrollmentKey - creates a new enrollment key in db
  30. func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
  31. newKeyID, err := getUniqueEnrollmentID()
  32. if err != nil {
  33. return nil, err
  34. }
  35. k := &models.EnrollmentKey{
  36. Value: newKeyID,
  37. Expiration: time.Time{},
  38. UsesRemaining: 0,
  39. Unlimited: unlimited,
  40. Networks: []string{},
  41. Tags: []string{},
  42. Type: models.Undefined,
  43. Relay: relay,
  44. }
  45. if uses > 0 {
  46. k.UsesRemaining = uses
  47. k.Type = models.Uses
  48. } else if !expiration.IsZero() {
  49. k.Expiration = expiration
  50. k.Type = models.TimeExpiration
  51. } else if k.Unlimited {
  52. k.Type = models.Unlimited
  53. }
  54. if len(networks) > 0 {
  55. k.Networks = networks
  56. }
  57. if len(tags) > 0 {
  58. k.Tags = tags
  59. }
  60. if err := k.Validate(); err != nil {
  61. return nil, err
  62. }
  63. if relay != uuid.Nil {
  64. relayNode, err := GetNodeByID(relay.String())
  65. if err != nil {
  66. return nil, err
  67. }
  68. if !slices.Contains(k.Networks, relayNode.Network) {
  69. return nil, errors.New("relay node not in key's networks")
  70. }
  71. if !relayNode.IsRelay {
  72. return nil, errors.New("relay node is not a relay")
  73. }
  74. }
  75. if err = upsertEnrollmentKey(k); err != nil {
  76. return nil, err
  77. }
  78. return k, nil
  79. }
  80. // UpdateEnrollmentKey - updates an existing enrollment key's associated relay
  81. func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) {
  82. key, err := GetEnrollmentKey(keyId)
  83. if err != nil {
  84. return nil, err
  85. }
  86. if relayId != uuid.Nil {
  87. relayNode, err := GetNodeByID(relayId.String())
  88. if err != nil {
  89. return nil, err
  90. }
  91. if !slices.Contains(key.Networks, relayNode.Network) {
  92. return nil, errors.New("relay node not in key's networks")
  93. }
  94. if !relayNode.IsRelay {
  95. return nil, errors.New("relay node is not a relay")
  96. }
  97. }
  98. key.Relay = relayId
  99. if err = upsertEnrollmentKey(key); err != nil {
  100. return nil, err
  101. }
  102. return key, nil
  103. }
  104. // GetAllEnrollmentKeys - fetches all enrollment keys from DB
  105. // TODO drop double pointer
  106. func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
  107. currentKeys, err := getEnrollmentKeysMap()
  108. if err != nil {
  109. return nil, err
  110. }
  111. var currentKeysList = []*models.EnrollmentKey{}
  112. for k := range currentKeys {
  113. currentKeysList = append(currentKeysList, currentKeys[k])
  114. }
  115. return currentKeysList, nil
  116. }
  117. // GetEnrollmentKey - fetches a single enrollment key
  118. // returns nil and error if not found
  119. func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  120. currentKeys, err := getEnrollmentKeysMap()
  121. if err != nil {
  122. return nil, err
  123. }
  124. if key, ok := currentKeys[value]; ok {
  125. return key, nil
  126. }
  127. return nil, EnrollmentErrors.NoKeyFound
  128. }
  129. // DeleteEnrollmentKey - delete's a given enrollment key by value
  130. func DeleteEnrollmentKey(value string) error {
  131. _, err := GetEnrollmentKey(value)
  132. if err != nil {
  133. return err
  134. }
  135. return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
  136. }
  137. // TryToUseEnrollmentKey - checks first if key can be decremented
  138. // returns true if it is decremented or isvalid
  139. func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
  140. key, err := decrementEnrollmentKey(k.Value)
  141. if err != nil {
  142. if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
  143. return k.IsValid()
  144. }
  145. } else {
  146. k.UsesRemaining = key.UsesRemaining
  147. return true
  148. }
  149. return false
  150. }
  151. // Tokenize - tokenizes an enrollment key to be used via registration
  152. // and attaches it to the Token field on the struct
  153. func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
  154. if len(serverAddr) == 0 || k == nil {
  155. return EnrollmentErrors.FailedToTokenize
  156. }
  157. newToken := models.EnrollmentToken{
  158. Server: serverAddr,
  159. Value: k.Value,
  160. }
  161. data, err := json.Marshal(&newToken)
  162. if err != nil {
  163. return err
  164. }
  165. k.Token = b64.StdEncoding.EncodeToString(data)
  166. return nil
  167. }
  168. // DeTokenize - detokenizes a base64 encoded string
  169. // and finds the associated enrollment key
  170. func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
  171. if len(b64Token) == 0 {
  172. return nil, EnrollmentErrors.FailedToDeTokenize
  173. }
  174. tokenData, err := b64.StdEncoding.DecodeString(b64Token)
  175. if err != nil {
  176. return nil, err
  177. }
  178. var newToken models.EnrollmentToken
  179. err = json.Unmarshal(tokenData, &newToken)
  180. if err != nil {
  181. return nil, err
  182. }
  183. k, err := GetEnrollmentKey(newToken.Value)
  184. if err != nil {
  185. return nil, err
  186. }
  187. return k, nil
  188. }
  189. // == private ==
  190. // decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
  191. func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  192. k, err := GetEnrollmentKey(value)
  193. if err != nil {
  194. return nil, err
  195. }
  196. if k.UsesRemaining == 0 {
  197. return nil, EnrollmentErrors.NoUsesRemaining
  198. }
  199. k.UsesRemaining = k.UsesRemaining - 1
  200. if err = upsertEnrollmentKey(k); err != nil {
  201. return nil, err
  202. }
  203. return k, nil
  204. }
  205. func upsertEnrollmentKey(k *models.EnrollmentKey) error {
  206. if k == nil {
  207. return EnrollmentErrors.InvalidKey
  208. }
  209. data, err := json.Marshal(k)
  210. if err != nil {
  211. return err
  212. }
  213. return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
  214. }
  215. func getUniqueEnrollmentID() (string, error) {
  216. currentKeys, err := getEnrollmentKeysMap()
  217. if err != nil {
  218. return "", err
  219. }
  220. newID := RandomString(models.EnrollmentKeyLength)
  221. for _, ok := currentKeys[newID]; ok; {
  222. newID = RandomString(models.EnrollmentKeyLength)
  223. }
  224. return newID, nil
  225. }
  226. func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
  227. records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
  228. if err != nil {
  229. if !database.IsEmptyRecord(err) {
  230. return nil, err
  231. }
  232. }
  233. if records == nil {
  234. records = make(map[string]string)
  235. }
  236. currentKeys := make(map[string]*models.EnrollmentKey, 0)
  237. if len(records) > 0 {
  238. for k := range records {
  239. var currentKey models.EnrollmentKey
  240. if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
  241. continue
  242. }
  243. currentKeys[k] = &currentKey
  244. }
  245. }
  246. return currentKeys, nil
  247. }