enrollmentkey.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package logic
  2. import (
  3. b64 "encoding/base64"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "time"
  9. "github.com/google/uuid"
  10. "github.com/gravitl/netmaker/database"
  11. "github.com/gravitl/netmaker/models"
  12. "github.com/gravitl/netmaker/servercfg"
  13. "golang.org/x/exp/slices"
  14. )
  15. // EnrollmentErrors - struct for holding EnrollmentKey error messages
  16. var EnrollmentErrors = struct {
  17. InvalidCreate error
  18. NoKeyFound error
  19. InvalidKey error
  20. NoUsesRemaining error
  21. FailedToTokenize error
  22. FailedToDeTokenize error
  23. }{
  24. InvalidCreate: fmt.Errorf("failed to create enrollment key. paramters invalid"),
  25. NoKeyFound: fmt.Errorf("no enrollmentkey found"),
  26. InvalidKey: fmt.Errorf("invalid key provided"),
  27. NoUsesRemaining: fmt.Errorf("no uses remaining"),
  28. FailedToTokenize: fmt.Errorf("failed to tokenize"),
  29. FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
  30. }
  31. var (
  32. enrollmentkeyCacheMutex = &sync.RWMutex{}
  33. enrollmentkeyCacheMap = make(map[string]models.EnrollmentKey)
  34. )
  35. // CreateEnrollmentKey - creates a new enrollment key in db
  36. func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) {
  37. newKeyID, err := getUniqueEnrollmentID()
  38. if err != nil {
  39. return nil, err
  40. }
  41. k := &models.EnrollmentKey{
  42. Value: newKeyID,
  43. Expiration: time.Time{},
  44. UsesRemaining: 0,
  45. Unlimited: unlimited,
  46. Networks: []string{},
  47. Tags: []string{},
  48. Type: models.Undefined,
  49. Relay: relay,
  50. Groups: groups,
  51. }
  52. if uses > 0 {
  53. k.UsesRemaining = uses
  54. k.Type = models.Uses
  55. } else if !expiration.IsZero() {
  56. k.Expiration = expiration
  57. k.Type = models.TimeExpiration
  58. } else if k.Unlimited {
  59. k.Type = models.Unlimited
  60. }
  61. if len(networks) > 0 {
  62. k.Networks = networks
  63. }
  64. if len(tags) > 0 {
  65. k.Tags = tags
  66. }
  67. if err := k.Validate(); err != nil {
  68. return nil, err
  69. }
  70. if relay != uuid.Nil {
  71. relayNode, err := GetNodeByID(relay.String())
  72. if err != nil {
  73. return nil, err
  74. }
  75. if !slices.Contains(k.Networks, relayNode.Network) {
  76. return nil, errors.New("relay node not in key's networks")
  77. }
  78. if !relayNode.IsRelay {
  79. return nil, errors.New("relay node is not a relay")
  80. }
  81. }
  82. if err = upsertEnrollmentKey(k); err != nil {
  83. return nil, err
  84. }
  85. return k, nil
  86. }
  87. // UpdateEnrollmentKey - updates an existing enrollment key's associated relay
  88. func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID) (*models.EnrollmentKey, error) {
  89. key, err := GetEnrollmentKey(keyId)
  90. if err != nil {
  91. return nil, err
  92. }
  93. if relayId != uuid.Nil {
  94. relayNode, err := GetNodeByID(relayId.String())
  95. if err != nil {
  96. return nil, err
  97. }
  98. if !slices.Contains(key.Networks, relayNode.Network) {
  99. return nil, errors.New("relay node not in key's networks")
  100. }
  101. if !relayNode.IsRelay {
  102. return nil, errors.New("relay node is not a relay")
  103. }
  104. }
  105. key.Relay = relayId
  106. key.Groups = groups
  107. if err = upsertEnrollmentKey(&key); err != nil {
  108. return nil, err
  109. }
  110. return &key, nil
  111. }
  112. // GetAllEnrollmentKeys - fetches all enrollment keys from DB
  113. // TODO drop double pointer
  114. func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) {
  115. currentKeys, err := getEnrollmentKeysMap()
  116. if err != nil {
  117. return nil, err
  118. }
  119. var currentKeysList = []models.EnrollmentKey{}
  120. for k := range currentKeys {
  121. currentKeysList = append(currentKeysList, currentKeys[k])
  122. }
  123. return currentKeysList, nil
  124. }
  125. // GetEnrollmentKey - fetches a single enrollment key
  126. // returns nil and error if not found
  127. func GetEnrollmentKey(value string) (key models.EnrollmentKey, err error) {
  128. currentKeys, err := getEnrollmentKeysMap()
  129. if err != nil {
  130. return key, err
  131. }
  132. if key, ok := currentKeys[value]; ok {
  133. return key, nil
  134. }
  135. return key, EnrollmentErrors.NoKeyFound
  136. }
  137. func deleteEnrollmentkeyFromCache(key string) {
  138. enrollmentkeyCacheMutex.Lock()
  139. delete(enrollmentkeyCacheMap, key)
  140. enrollmentkeyCacheMutex.Unlock()
  141. }
  142. // DeleteEnrollmentKey - delete's a given enrollment key by value
  143. func DeleteEnrollmentKey(value string, force bool) error {
  144. key, err := GetEnrollmentKey(value)
  145. if err != nil {
  146. return err
  147. }
  148. if key.Default && !force {
  149. return errors.New("cannot delete default network key")
  150. }
  151. err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
  152. if err == nil {
  153. if servercfg.CacheEnabled() {
  154. deleteEnrollmentkeyFromCache(value)
  155. }
  156. }
  157. return err
  158. }
  159. // TryToUseEnrollmentKey - checks first if key can be decremented
  160. // returns true if it is decremented or isvalid
  161. func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
  162. key, err := decrementEnrollmentKey(k.Value)
  163. if err != nil {
  164. if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
  165. return k.IsValid()
  166. }
  167. } else {
  168. k.UsesRemaining = key.UsesRemaining
  169. return true
  170. }
  171. return false
  172. }
  173. // Tokenize - tokenizes an enrollment key to be used via registration
  174. // and attaches it to the Token field on the struct
  175. func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
  176. if len(serverAddr) == 0 || k == nil {
  177. return EnrollmentErrors.FailedToTokenize
  178. }
  179. newToken := models.EnrollmentToken{
  180. Server: serverAddr,
  181. Value: k.Value,
  182. }
  183. data, err := json.Marshal(&newToken)
  184. if err != nil {
  185. return err
  186. }
  187. k.Token = b64.StdEncoding.EncodeToString(data)
  188. return nil
  189. }
  190. // DeTokenize - detokenizes a base64 encoded string
  191. // and finds the associated enrollment key
  192. func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
  193. if len(b64Token) == 0 {
  194. return nil, EnrollmentErrors.FailedToDeTokenize
  195. }
  196. tokenData, err := b64.StdEncoding.DecodeString(b64Token)
  197. if err != nil {
  198. return nil, err
  199. }
  200. var newToken models.EnrollmentToken
  201. err = json.Unmarshal(tokenData, &newToken)
  202. if err != nil {
  203. return nil, err
  204. }
  205. k, err := GetEnrollmentKey(newToken.Value)
  206. if err != nil {
  207. return nil, err
  208. }
  209. return &k, nil
  210. }
  211. // == private ==
  212. // decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
  213. func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
  214. k, err := GetEnrollmentKey(value)
  215. if err != nil {
  216. return nil, err
  217. }
  218. if k.UsesRemaining == 0 {
  219. return nil, EnrollmentErrors.NoUsesRemaining
  220. }
  221. k.UsesRemaining = k.UsesRemaining - 1
  222. if err = upsertEnrollmentKey(&k); err != nil {
  223. return nil, err
  224. }
  225. return &k, nil
  226. }
  227. func upsertEnrollmentKey(k *models.EnrollmentKey) error {
  228. if k == nil {
  229. return EnrollmentErrors.InvalidKey
  230. }
  231. data, err := json.Marshal(k)
  232. if err != nil {
  233. return err
  234. }
  235. err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
  236. if err == nil {
  237. if servercfg.CacheEnabled() {
  238. storeEnrollmentkeyInCache(k.Value, *k)
  239. }
  240. }
  241. return nil
  242. }
  243. func getUniqueEnrollmentID() (string, error) {
  244. currentKeys, err := getEnrollmentKeysMap()
  245. if err != nil {
  246. return "", err
  247. }
  248. newID := RandomString(models.EnrollmentKeyLength)
  249. for _, ok := currentKeys[newID]; ok; {
  250. newID = RandomString(models.EnrollmentKeyLength)
  251. }
  252. return newID, nil
  253. }
  254. func getEnrollmentkeysFromCache() map[string]models.EnrollmentKey {
  255. return enrollmentkeyCacheMap
  256. }
  257. func storeEnrollmentkeyInCache(key string, enrollmentkey models.EnrollmentKey) {
  258. enrollmentkeyCacheMutex.Lock()
  259. enrollmentkeyCacheMap[key] = enrollmentkey
  260. enrollmentkeyCacheMutex.Unlock()
  261. }
  262. func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) {
  263. if servercfg.CacheEnabled() {
  264. keys := getEnrollmentkeysFromCache()
  265. if len(keys) != 0 {
  266. return keys, nil
  267. }
  268. }
  269. records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
  270. if err != nil {
  271. if !database.IsEmptyRecord(err) {
  272. return nil, err
  273. }
  274. }
  275. if records == nil {
  276. records = make(map[string]string)
  277. }
  278. currentKeys := make(map[string]models.EnrollmentKey, 0)
  279. if len(records) > 0 {
  280. for k := range records {
  281. var currentKey models.EnrollmentKey
  282. if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
  283. continue
  284. }
  285. currentKeys[k] = currentKey
  286. if servercfg.CacheEnabled() {
  287. storeEnrollmentkeyInCache(currentKey.Value, currentKey)
  288. }
  289. }
  290. }
  291. return currentKeys, nil
  292. }
  293. func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) {
  294. keys, _ := GetAllEnrollmentKeys()
  295. for _, key := range keys {
  296. newTags := []models.TagID{}
  297. update := false
  298. for _, tagID := range key.Groups {
  299. if tagID == deletedTagID {
  300. update = true
  301. continue
  302. }
  303. newTags = append(newTags, tagID)
  304. }
  305. if update {
  306. key.Groups = newTags
  307. upsertEnrollmentKey(&key)
  308. }
  309. }
  310. }