accesskeys.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package logic
import (
	cryptorand "crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"math/rand"
	"time"

	"github.com/go-playground/validator/v10"
	"github.com/gravitl/netmaker/database"
	"github.com/gravitl/netmaker/logger"
	"github.com/gravitl/netmaker/models"
	"github.com/gravitl/netmaker/servercfg"
)
  14. const (
  15. charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  16. )
  17. // CreateAccessKey - create access key
  18. func CreateAccessKey(accesskey models.AccessKey, network models.Network) (models.AccessKey, error) {
  19. if accesskey.Name == "" {
  20. accesskey.Name = genKeyName()
  21. }
  22. if accesskey.Value == "" {
  23. accesskey.Value = genKey()
  24. }
  25. if accesskey.Uses == 0 {
  26. accesskey.Uses = 1
  27. }
  28. checkkeys, err := GetKeys(network.NetID)
  29. if err != nil {
  30. return models.AccessKey{}, errors.New("could not retrieve network keys")
  31. }
  32. for _, key := range checkkeys {
  33. if key.Name == accesskey.Name {
  34. return models.AccessKey{}, errors.New("duplicate AccessKey Name")
  35. }
  36. }
  37. privAddr := ""
  38. if network.IsLocal != "" {
  39. privAddr = network.LocalRange
  40. }
  41. netID := network.NetID
  42. commsNetID, err := FetchCommsNetID()
  43. if err != nil {
  44. return models.AccessKey{}, errors.New("could not retrieve comms netid")
  45. }
  46. var accessToken models.AccessToken
  47. s := servercfg.GetServerConfig()
  48. servervals := models.ServerConfig{
  49. GRPCConnString: s.GRPCConnString,
  50. GRPCSSL: s.GRPCSSL,
  51. CommsNetwork: commsNetID,
  52. }
  53. accessToken.ServerConfig = servervals
  54. accessToken.ClientConfig.Network = netID
  55. accessToken.ClientConfig.Key = accesskey.Value
  56. accessToken.ClientConfig.LocalRange = privAddr
  57. tokenjson, err := json.Marshal(accessToken)
  58. if err != nil {
  59. return accesskey, err
  60. }
  61. accesskey.AccessString = base64.StdEncoding.EncodeToString([]byte(tokenjson))
  62. //validate accesskey
  63. v := validator.New()
  64. err = v.Struct(accesskey)
  65. if err != nil {
  66. for _, e := range err.(validator.ValidationErrors) {
  67. logger.Log(1, "validator", e.Error())
  68. }
  69. return models.AccessKey{}, err
  70. }
  71. network.AccessKeys = append(network.AccessKeys, accesskey)
  72. data, err := json.Marshal(&network)
  73. if err != nil {
  74. return models.AccessKey{}, err
  75. }
  76. if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
  77. return models.AccessKey{}, err
  78. }
  79. return accesskey, nil
  80. }
  81. // DeleteKey - deletes a key
  82. func DeleteKey(keyname, netname string) error {
  83. network, err := GetParentNetwork(netname)
  84. if err != nil {
  85. return err
  86. }
  87. //basically, turn the list of access keys into the list of access keys before and after the item
  88. //have not done any error handling for if there's like...1 item. I think it works? need to test.
  89. found := false
  90. var updatedKeys []models.AccessKey
  91. for _, currentkey := range network.AccessKeys {
  92. if currentkey.Name == keyname {
  93. found = true
  94. } else {
  95. updatedKeys = append(updatedKeys, currentkey)
  96. }
  97. }
  98. if !found {
  99. return errors.New("key " + keyname + " does not exist")
  100. }
  101. network.AccessKeys = updatedKeys
  102. data, err := json.Marshal(&network)
  103. if err != nil {
  104. return err
  105. }
  106. if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
  107. return err
  108. }
  109. return nil
  110. }
  111. // GetKeys - fetches keys for network
  112. func GetKeys(net string) ([]models.AccessKey, error) {
  113. record, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, net)
  114. if err != nil {
  115. return []models.AccessKey{}, err
  116. }
  117. network, err := ParseNetwork(record)
  118. if err != nil {
  119. return []models.AccessKey{}, err
  120. }
  121. return network.AccessKeys, nil
  122. }
  123. // DecrimentKey - decriments key uses
  124. func DecrimentKey(networkName string, keyvalue string) {
  125. var network models.Network
  126. network, err := GetParentNetwork(networkName)
  127. if err != nil || network.IsComms == "yes" {
  128. return
  129. }
  130. for i := len(network.AccessKeys) - 1; i >= 0; i-- {
  131. currentkey := network.AccessKeys[i]
  132. if currentkey.Value == keyvalue {
  133. network.AccessKeys[i].Uses--
  134. if network.AccessKeys[i].Uses < 1 {
  135. network.AccessKeys = append(network.AccessKeys[:i],
  136. network.AccessKeys[i+1:]...)
  137. break
  138. }
  139. }
  140. }
  141. if newNetworkData, err := json.Marshal(&network); err != nil {
  142. logger.Log(2, "failed to decrement key")
  143. return
  144. } else {
  145. database.Insert(network.NetID, string(newNetworkData), database.NETWORKS_TABLE_NAME)
  146. }
  147. }
  148. // IsKeyValid - check if key is valid
  149. func IsKeyValid(networkname string, keyvalue string) bool {
  150. network, err := GetParentNetwork(networkname)
  151. if err != nil {
  152. return false
  153. }
  154. accesskeys := network.AccessKeys
  155. if network.IsComms == "yes" {
  156. accesskeys = getAllAccessKeys()
  157. }
  158. var key models.AccessKey
  159. foundkey := false
  160. isvalid := false
  161. for i := len(accesskeys) - 1; i >= 0; i-- {
  162. currentkey := accesskeys[i]
  163. if currentkey.Value == keyvalue {
  164. key = currentkey
  165. foundkey = true
  166. }
  167. }
  168. if foundkey {
  169. if key.Uses > 0 {
  170. isvalid = true
  171. }
  172. }
  173. return isvalid
  174. }
  175. // RemoveKeySensitiveInfo - remove sensitive key info
  176. func RemoveKeySensitiveInfo(keys []models.AccessKey) []models.AccessKey {
  177. var returnKeys []models.AccessKey
  178. for _, key := range keys {
  179. key.Value = models.PLACEHOLDER_KEY_TEXT
  180. key.AccessString = models.PLACEHOLDER_TOKEN_TEXT
  181. returnKeys = append(returnKeys, key)
  182. }
  183. return returnKeys
  184. }
  185. // == private methods ==
  186. func genKeyName() string {
  187. var seededRand *rand.Rand = rand.New(
  188. rand.NewSource(time.Now().UnixNano()))
  189. length := 5
  190. b := make([]byte, length)
  191. for i := range b {
  192. b[i] = charset[seededRand.Intn(len(charset))]
  193. }
  194. return "key" + string(b)
  195. }
  196. func genKey() string {
  197. var seededRand *rand.Rand = rand.New(
  198. rand.NewSource(time.Now().UnixNano()))
  199. length := 16
  200. b := make([]byte, length)
  201. for i := range b {
  202. b[i] = charset[seededRand.Intn(len(charset))]
  203. }
  204. return string(b)
  205. }
  206. func getAllAccessKeys() []models.AccessKey {
  207. var accesskeys = make([]models.AccessKey, 0)
  208. networks, err := GetNetworks()
  209. if err != nil {
  210. return accesskeys
  211. }
  212. for i := range networks {
  213. accesskeys = append(accesskeys, networks[i].AccessKeys...)
  214. }
  215. return accesskeys
  216. }