acls.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. package logic
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "sort"
  6. "github.com/gravitl/netmaker/database"
  7. "github.com/gravitl/netmaker/models"
  8. )
  9. // InsertAcl - creates acl policy
  10. func InsertAcl(a models.Acl) error {
  11. d, err := json.Marshal(a)
  12. if err != nil {
  13. return err
  14. }
  15. return database.Insert(a.ID.String(), string(d), database.ACLS_TABLE_NAME)
  16. }
  17. func GetAcl(aID string) (models.Acl, error) {
  18. a := models.Acl{}
  19. d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID)
  20. if err != nil {
  21. return a, err
  22. }
  23. err = json.Unmarshal([]byte(d), &a)
  24. if err != nil {
  25. return a, err
  26. }
  27. return a, nil
  28. }
  29. func IsAclPolicyValid(acl models.Acl) bool {
  30. //check if src and dst are valid
  31. isValid := false
  32. switch acl.RuleType {
  33. case models.UserPolicy:
  34. // src list should only contain users
  35. for _, srcI := range acl.Src {
  36. if srcI.ID == "" || srcI.Value == "" {
  37. break
  38. }
  39. if srcI.ID != models.UserAclID &&
  40. srcI.ID != models.UserGroupAclID {
  41. break
  42. }
  43. // check if user group is valid
  44. if srcI.ID == models.UserAclID {
  45. _, err := GetUser(srcI.Value)
  46. if err != nil {
  47. break
  48. }
  49. } else if srcI.ID == models.UserGroupAclID {
  50. err := IsGroupValid(models.UserGroupID(srcI.Value))
  51. if err != nil {
  52. break
  53. }
  54. }
  55. }
  56. for _, dstI := range acl.Dst {
  57. if dstI.ID == "" || dstI.Value == "" {
  58. break
  59. }
  60. if dstI.ID == models.UserAclID ||
  61. dstI.ID == models.UserGroupAclID {
  62. break
  63. }
  64. if dstI.ID != models.DeviceAclID {
  65. break
  66. }
  67. // check if tag is valid
  68. _, err := GetTag(models.TagID(dstI.Value))
  69. if err != nil {
  70. break
  71. }
  72. }
  73. isValid = true
  74. case models.DevicePolicy:
  75. for _, srcI := range acl.Src {
  76. if srcI.ID == "" || srcI.Value == "" {
  77. break
  78. }
  79. if srcI.ID != models.DeviceAclID {
  80. break
  81. }
  82. // check if tag is valid
  83. _, err := GetTag(models.TagID(srcI.Value))
  84. if err != nil {
  85. break
  86. }
  87. }
  88. for _, dstI := range acl.Dst {
  89. if dstI.ID == "" || dstI.Value == "" {
  90. break
  91. }
  92. if dstI.ID != models.DeviceAclID {
  93. break
  94. }
  95. // check if tag is valid
  96. _, err := GetTag(models.TagID(dstI.Value))
  97. if err != nil {
  98. break
  99. }
  100. }
  101. isValid = true
  102. }
  103. return isValid
  104. }
  105. // UpdateAcl - updates allowed fields on acls and commits to DB
  106. func UpdateAcl(newAcl, acl models.Acl) error {
  107. if newAcl.Name != "" {
  108. acl.Name = newAcl.Name
  109. }
  110. acl.Src = newAcl.Src
  111. acl.Dst = newAcl.Dst
  112. acl.AllowedDirection = newAcl.AllowedDirection
  113. acl.Enabled = newAcl.Enabled
  114. d, err := json.Marshal(acl)
  115. if err != nil {
  116. return err
  117. }
  118. return database.Insert(acl.ID.String(), string(d), database.ACLS_TABLE_NAME)
  119. }
  120. // DeleteAcl - deletes acl policy
  121. func DeleteAcl(a models.Acl) error {
  122. return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID.String())
  123. }
  124. func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
  125. acls, _ := ListAcls(netID)
  126. for _, acl := range acls {
  127. if acl.Default && acl.RuleType == ruleType {
  128. return acl, nil
  129. }
  130. }
  131. return models.Acl{}, errors.New("default rule not found")
  132. }
  133. // listDevicePolicies - lists all device policies in a network
  134. func listDevicePolicies(netID models.NetworkID) []models.Acl {
  135. data, err := database.FetchRecords(database.TAG_TABLE_NAME)
  136. if err != nil && !database.IsEmptyRecord(err) {
  137. return []models.Acl{}
  138. }
  139. acls := []models.Acl{}
  140. for _, dataI := range data {
  141. acl := models.Acl{}
  142. err := json.Unmarshal([]byte(dataI), &acl)
  143. if err != nil {
  144. continue
  145. }
  146. if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
  147. acls = append(acls, acl)
  148. }
  149. }
  150. return acls
  151. }
  152. // ListAcls - lists all acl policies
  153. func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
  154. data, err := database.FetchRecords(database.TAG_TABLE_NAME)
  155. if err != nil && !database.IsEmptyRecord(err) {
  156. return []models.Acl{}, err
  157. }
  158. acls := []models.Acl{}
  159. for _, dataI := range data {
  160. acl := models.Acl{}
  161. err := json.Unmarshal([]byte(dataI), &acl)
  162. if err != nil {
  163. continue
  164. }
  165. if acl.NetworkID == netID {
  166. acls = append(acls, acl)
  167. }
  168. }
  169. return acls, nil
  170. }
  171. func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
  172. aclValueMap := make(map[string]struct{})
  173. for _, aclTagI := range acltags {
  174. aclValueMap[aclTagI.ID.String()] = struct{}{}
  175. }
  176. return aclValueMap
  177. }
  178. func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
  179. // check default policy if all allowed return true
  180. defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
  181. if err == nil {
  182. if defaultPolicy.Enabled {
  183. return true
  184. }
  185. }
  186. // list device policies
  187. policies := listDevicePolicies(models.NetworkID(peer.Network))
  188. for _, policy := range policies {
  189. srcMap := convAclTagToValueMap(policy.Src)
  190. dstMap := convAclTagToValueMap(policy.Dst)
  191. for tagID := range peer.Tags {
  192. if _, ok := dstMap[tagID.String()]; ok {
  193. for tagID := range node.Tags {
  194. if _, ok := srcMap[tagID.String()]; ok {
  195. return true
  196. }
  197. }
  198. }
  199. if _, ok := srcMap[tagID.String()]; ok {
  200. for tagID := range node.Tags {
  201. if _, ok := dstMap[tagID.String()]; ok {
  202. return true
  203. }
  204. }
  205. }
  206. }
  207. }
  208. return false
  209. }
  210. // SortTagEntrys - Sorts slice of Tag entries by their id
  211. func SortAclEntrys(acls []models.Acl) {
  212. sort.Slice(acls, func(i, j int) bool {
  213. return acls[i].Name < acls[j].Name
  214. })
  215. }