acls.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. package logic
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "sort"
  7. "time"
  8. "github.com/gravitl/netmaker/database"
  9. "github.com/gravitl/netmaker/models"
  10. )
  11. // CreateDefaultAclNetworkPolicies - create default acl network policies
  12. func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
  13. if netID.String() == "" {
  14. return
  15. }
  16. if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-nodes")) {
  17. defaultDeviceAcl := models.Acl{
  18. ID: fmt.Sprintf("%s.%s", netID, "all-nodes"),
  19. Name: "All Nodes",
  20. MetaData: "This Policy allows all nodes in the network to communicate with each other",
  21. Default: true,
  22. NetworkID: netID,
  23. RuleType: models.DevicePolicy,
  24. Src: []models.AclPolicyTag{
  25. {
  26. ID: models.DeviceAclID,
  27. Value: "*",
  28. }},
  29. Dst: []models.AclPolicyTag{
  30. {
  31. ID: models.DeviceAclID,
  32. Value: "*",
  33. }},
  34. AllowedDirection: models.TrafficDirectionBi,
  35. Enabled: true,
  36. CreatedBy: "auto",
  37. CreatedAt: time.Now().UTC(),
  38. }
  39. InsertAcl(defaultDeviceAcl)
  40. }
  41. if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-users")) {
  42. defaultUserAcl := models.Acl{
  43. ID: fmt.Sprintf("%s.%s", netID, "all-users"),
  44. Default: true,
  45. Name: "All Users",
  46. MetaData: "This policy gives access to everything in the network for an user",
  47. NetworkID: netID,
  48. RuleType: models.UserPolicy,
  49. Src: []models.AclPolicyTag{
  50. {
  51. ID: models.UserAclID,
  52. Value: "*",
  53. },
  54. {
  55. ID: models.UserGroupAclID,
  56. Value: "*",
  57. },
  58. },
  59. Dst: []models.AclPolicyTag{{
  60. ID: models.DeviceAclID,
  61. Value: "*",
  62. }},
  63. AllowedDirection: models.TrafficDirectionUni,
  64. Enabled: true,
  65. CreatedBy: "auto",
  66. CreatedAt: time.Now().UTC(),
  67. }
  68. InsertAcl(defaultUserAcl)
  69. }
  70. if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-remote-access-gws")) {
  71. defaultUserAcl := models.Acl{
  72. ID: fmt.Sprintf("%s.%s", netID, "all-remote-access-gws"),
  73. Default: true,
  74. Name: "All Remote Access Gateways",
  75. NetworkID: netID,
  76. RuleType: models.DevicePolicy,
  77. Src: []models.AclPolicyTag{
  78. {
  79. ID: models.DeviceAclID,
  80. Value: fmt.Sprintf("%s.%s", netID, models.RemoteAccessTagName),
  81. },
  82. },
  83. Dst: []models.AclPolicyTag{
  84. {
  85. ID: models.DeviceAclID,
  86. Value: "*",
  87. },
  88. },
  89. AllowedDirection: models.TrafficDirectionBi,
  90. Enabled: true,
  91. CreatedBy: "auto",
  92. CreatedAt: time.Now().UTC(),
  93. }
  94. InsertAcl(defaultUserAcl)
  95. }
  96. CreateDefaultUserPolicies(netID)
  97. }
  98. // DeleteDefaultNetworkPolicies - deletes all default network acl policies
  99. func DeleteDefaultNetworkPolicies(netId models.NetworkID) {
  100. acls, _ := ListAcls(netId)
  101. for _, acl := range acls {
  102. if acl.NetworkID == netId && acl.Default {
  103. DeleteAcl(acl)
  104. }
  105. }
  106. }
  107. // ValidateCreateAclReq - validates create req for acl
  108. func ValidateCreateAclReq(req models.Acl) error {
  109. // check if acl network exists
  110. _, err := GetNetwork(req.NetworkID.String())
  111. if err != nil {
  112. return errors.New("failed to get network details for " + req.NetworkID.String())
  113. }
  114. // err = CheckIDSyntax(req.Name)
  115. // if err != nil {
  116. // return err
  117. // }
  118. return nil
  119. }
  120. // InsertAcl - creates acl policy
  121. func InsertAcl(a models.Acl) error {
  122. d, err := json.Marshal(a)
  123. if err != nil {
  124. return err
  125. }
  126. return database.Insert(a.ID, string(d), database.ACLS_TABLE_NAME)
  127. }
  128. // GetAcl - gets acl info by id
  129. func GetAcl(aID string) (models.Acl, error) {
  130. a := models.Acl{}
  131. d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID)
  132. if err != nil {
  133. return a, err
  134. }
  135. err = json.Unmarshal([]byte(d), &a)
  136. if err != nil {
  137. return a, err
  138. }
  139. return a, nil
  140. }
  141. // IsAclExists - checks if acl exists
  142. func IsAclExists(aclID string) bool {
  143. _, err := GetAcl(aclID)
  144. return err == nil
  145. }
  146. // IsAclPolicyValid - validates if acl policy is valid
  147. func IsAclPolicyValid(acl models.Acl) bool {
  148. //check if src and dst are valid
  149. switch acl.RuleType {
  150. case models.UserPolicy:
  151. // src list should only contain users
  152. for _, srcI := range acl.Src {
  153. if srcI.ID == "" || srcI.Value == "" {
  154. return false
  155. }
  156. if srcI.Value == "*" {
  157. continue
  158. }
  159. if srcI.ID != models.UserAclID && srcI.ID != models.UserGroupAclID {
  160. return false
  161. }
  162. // check if user group is valid
  163. if srcI.ID == models.UserAclID {
  164. _, err := GetUser(srcI.Value)
  165. if err != nil {
  166. return false
  167. }
  168. } else if srcI.ID == models.UserGroupAclID {
  169. err := IsGroupValid(models.UserGroupID(srcI.Value))
  170. if err != nil {
  171. return false
  172. }
  173. // check if group belongs to this network
  174. netGrps := GetUserGroupsInNetwork(acl.NetworkID)
  175. if _, ok := netGrps[models.UserGroupID(srcI.Value)]; !ok {
  176. return false
  177. }
  178. }
  179. }
  180. for _, dstI := range acl.Dst {
  181. if dstI.ID == "" || dstI.Value == "" {
  182. return false
  183. }
  184. if dstI.ID != models.DeviceAclID {
  185. return false
  186. }
  187. if dstI.Value == "*" {
  188. continue
  189. }
  190. // check if tag is valid
  191. _, err := GetTag(models.TagID(dstI.Value))
  192. if err != nil {
  193. return false
  194. }
  195. }
  196. case models.DevicePolicy:
  197. for _, srcI := range acl.Src {
  198. if srcI.ID == "" || srcI.Value == "" {
  199. return false
  200. }
  201. if srcI.ID != models.DeviceAclID {
  202. return false
  203. }
  204. if srcI.Value == "*" {
  205. continue
  206. }
  207. // check if tag is valid
  208. _, err := GetTag(models.TagID(srcI.Value))
  209. if err != nil {
  210. return false
  211. }
  212. }
  213. for _, dstI := range acl.Dst {
  214. if dstI.ID == "" || dstI.Value == "" {
  215. return false
  216. }
  217. if dstI.ID != models.DeviceAclID {
  218. return false
  219. }
  220. if dstI.Value == "*" {
  221. continue
  222. }
  223. // check if tag is valid
  224. _, err := GetTag(models.TagID(dstI.Value))
  225. if err != nil {
  226. return false
  227. }
  228. }
  229. }
  230. return true
  231. }
  232. // UpdateAcl - updates allowed fields on acls and commits to DB
  233. func UpdateAcl(newAcl, acl models.Acl) error {
  234. if !acl.Default {
  235. acl.Name = newAcl.Name
  236. acl.Src = newAcl.Src
  237. acl.Dst = newAcl.Dst
  238. }
  239. acl.Enabled = newAcl.Enabled
  240. d, err := json.Marshal(acl)
  241. if err != nil {
  242. return err
  243. }
  244. return database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
  245. }
  246. // UpsertAcl - upserts acl
  247. func UpsertAcl(acl models.Acl) error {
  248. d, err := json.Marshal(acl)
  249. if err != nil {
  250. return err
  251. }
  252. return database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
  253. }
  254. // DeleteAcl - deletes acl policy
  255. func DeleteAcl(a models.Acl) error {
  256. return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID)
  257. }
  258. // GetDefaultPolicy - fetches default policy in the network by ruleType
  259. func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
  260. aclID := "all-users"
  261. if ruleType == models.DevicePolicy {
  262. aclID = "all-nodes"
  263. }
  264. acl, err := GetAcl(fmt.Sprintf("%s.%s", netID, aclID))
  265. if err != nil {
  266. return models.Acl{}, errors.New("default rule not found")
  267. }
  268. return acl, nil
  269. }
  270. // ListUserPolicies - lists all acl policies enforced on an user
  271. func ListUserPolicies(u models.User) []models.Acl {
  272. data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
  273. if err != nil && !database.IsEmptyRecord(err) {
  274. return []models.Acl{}
  275. }
  276. acls := []models.Acl{}
  277. for _, dataI := range data {
  278. acl := models.Acl{}
  279. err := json.Unmarshal([]byte(dataI), &acl)
  280. if err != nil {
  281. continue
  282. }
  283. if acl.RuleType == models.UserPolicy {
  284. srcMap := convAclTagToValueMap(acl.Src)
  285. if _, ok := srcMap[u.UserName]; ok {
  286. acls = append(acls, acl)
  287. } else {
  288. // check for user groups
  289. for gID := range u.UserGroups {
  290. if _, ok := srcMap[gID.String()]; ok {
  291. acls = append(acls, acl)
  292. break
  293. }
  294. }
  295. }
  296. }
  297. }
  298. return acls
  299. }
  300. // listPoliciesOfUser - lists all user acl policies applied to user in an network
  301. func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl {
  302. data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
  303. if err != nil && !database.IsEmptyRecord(err) {
  304. return []models.Acl{}
  305. }
  306. acls := []models.Acl{}
  307. for _, dataI := range data {
  308. acl := models.Acl{}
  309. err := json.Unmarshal([]byte(dataI), &acl)
  310. if err != nil {
  311. continue
  312. }
  313. if acl.NetworkID == netID && acl.RuleType == models.UserPolicy {
  314. srcMap := convAclTagToValueMap(acl.Src)
  315. if _, ok := srcMap[user.UserName]; ok {
  316. acls = append(acls, acl)
  317. continue
  318. }
  319. for netRole := range user.NetworkRoles {
  320. if _, ok := srcMap[netRole.String()]; ok {
  321. acls = append(acls, acl)
  322. continue
  323. }
  324. }
  325. for userG := range user.UserGroups {
  326. if _, ok := srcMap[userG.String()]; ok {
  327. acls = append(acls, acl)
  328. continue
  329. }
  330. }
  331. }
  332. }
  333. return acls
  334. }
  335. // listDevicePolicies - lists all device policies in a network
  336. func listDevicePolicies(netID models.NetworkID) []models.Acl {
  337. data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
  338. if err != nil && !database.IsEmptyRecord(err) {
  339. return []models.Acl{}
  340. }
  341. acls := []models.Acl{}
  342. for _, dataI := range data {
  343. acl := models.Acl{}
  344. err := json.Unmarshal([]byte(dataI), &acl)
  345. if err != nil {
  346. continue
  347. }
  348. if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
  349. acls = append(acls, acl)
  350. }
  351. }
  352. return acls
  353. }
  354. // ListAcls - lists all acl policies
  355. func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
  356. data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
  357. if err != nil && !database.IsEmptyRecord(err) {
  358. return []models.Acl{}, err
  359. }
  360. acls := []models.Acl{}
  361. for _, dataI := range data {
  362. acl := models.Acl{}
  363. err := json.Unmarshal([]byte(dataI), &acl)
  364. if err != nil {
  365. continue
  366. }
  367. if acl.NetworkID == netID {
  368. acls = append(acls, acl)
  369. }
  370. }
  371. return acls, nil
  372. }
  373. func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
  374. aclValueMap := make(map[string]struct{})
  375. for _, aclTagI := range acltags {
  376. aclValueMap[aclTagI.Value] = struct{}{}
  377. }
  378. return aclValueMap
  379. }
  380. // IsUserAllowedToCommunicate - check if user is allowed to communicate with peer
  381. func IsUserAllowedToCommunicate(userName string, peer models.Node) bool {
  382. if peer.IsStatic {
  383. peer = peer.StaticNode.ConvertToStaticNode()
  384. }
  385. acl, _ := GetDefaultPolicy(models.NetworkID(peer.Network), models.UserPolicy)
  386. if acl.Enabled {
  387. return true
  388. }
  389. user, err := GetUser(userName)
  390. if err != nil {
  391. return false
  392. }
  393. policies := listPoliciesOfUser(*user, models.NetworkID(peer.Network))
  394. for _, policy := range policies {
  395. if !policy.Enabled {
  396. continue
  397. }
  398. dstMap := convAclTagToValueMap(policy.Dst)
  399. if _, ok := dstMap["*"]; ok {
  400. return true
  401. }
  402. for tagID := range peer.Tags {
  403. if _, ok := dstMap[tagID.String()]; ok {
  404. return true
  405. }
  406. }
  407. }
  408. return false
  409. }
  410. // IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer
  411. func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
  412. if node.IsStatic {
  413. node = node.StaticNode.ConvertToStaticNode()
  414. }
  415. if peer.IsStatic {
  416. peer = peer.StaticNode.ConvertToStaticNode()
  417. }
  418. // check default policy if all allowed return true
  419. defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
  420. if err == nil {
  421. if defaultPolicy.Enabled {
  422. return true
  423. }
  424. }
  425. // list device policies
  426. policies := listDevicePolicies(models.NetworkID(peer.Network))
  427. for _, policy := range policies {
  428. if !policy.Enabled {
  429. continue
  430. }
  431. srcMap := convAclTagToValueMap(policy.Src)
  432. dstMap := convAclTagToValueMap(policy.Dst)
  433. // fmt.Printf("\n======> SRCMAP: %+v\n", srcMap)
  434. // fmt.Printf("\n======> DSTMAP: %+v\n", dstMap)
  435. // fmt.Printf("\n======> node Tags: %+v\n", node.Tags)
  436. // fmt.Printf("\n======> peer Tags: %+v\n", peer.Tags)
  437. if _, ok := srcMap["*"]; ok {
  438. if _, ok := dstMap["*"]; ok {
  439. return true
  440. }
  441. }
  442. for tagID := range node.Tags {
  443. if _, ok := dstMap[tagID.String()]; ok {
  444. if _, ok := srcMap["*"]; ok {
  445. return true
  446. }
  447. for tagID := range peer.Tags {
  448. if _, ok := srcMap[tagID.String()]; ok {
  449. return true
  450. }
  451. }
  452. }
  453. if _, ok := srcMap[tagID.String()]; ok {
  454. if _, ok := dstMap["*"]; ok {
  455. return true
  456. }
  457. for tagID := range peer.Tags {
  458. if _, ok := dstMap[tagID.String()]; ok {
  459. return true
  460. }
  461. }
  462. }
  463. }
  464. for tagID := range peer.Tags {
  465. if _, ok := dstMap[tagID.String()]; ok {
  466. if _, ok := srcMap["*"]; ok {
  467. return true
  468. }
  469. for tagID := range node.Tags {
  470. if _, ok := srcMap[tagID.String()]; ok {
  471. return true
  472. }
  473. }
  474. }
  475. if _, ok := srcMap[tagID.String()]; ok {
  476. if _, ok := dstMap["*"]; ok {
  477. return true
  478. }
  479. for tagID := range node.Tags {
  480. if _, ok := dstMap[tagID.String()]; ok {
  481. return true
  482. }
  483. }
  484. }
  485. }
  486. }
  487. return false
  488. }
  489. // SortTagEntrys - Sorts slice of Tag entries by their id
  490. func SortAclEntrys(acls []models.Acl) {
  491. sort.Slice(acls, func(i, j int) bool {
  492. return acls[i].Name < acls[j].Name
  493. })
  494. }
  495. // UpdateDeviceTag - updates device tag on acl policies
  496. func UpdateDeviceTag(OldID, newID models.TagID, netID models.NetworkID) {
  497. acls := listDevicePolicies(netID)
  498. update := false
  499. for _, acl := range acls {
  500. for i, srcTagI := range acl.Src {
  501. if srcTagI.ID == models.DeviceAclID {
  502. if OldID.String() == srcTagI.Value {
  503. acl.Src[i].Value = newID.String()
  504. update = true
  505. }
  506. }
  507. }
  508. for i, dstTagI := range acl.Dst {
  509. if dstTagI.ID == models.DeviceAclID {
  510. if OldID.String() == dstTagI.Value {
  511. acl.Dst[i].Value = newID.String()
  512. update = true
  513. }
  514. }
  515. }
  516. if update {
  517. UpsertAcl(acl)
  518. }
  519. }
  520. }
  521. // RemoveDeviceTagFromAclPolicies - remove device tag from acl policies
  522. func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID models.NetworkID) error {
  523. acls := listDevicePolicies(netID)
  524. update := false
  525. for _, acl := range acls {
  526. for i, srcTagI := range acl.Src {
  527. if srcTagI.ID == models.DeviceAclID {
  528. if tagID.String() == srcTagI.Value {
  529. acl.Src = append(acl.Src[:i], acl.Src[i+1:]...)
  530. update = true
  531. }
  532. }
  533. }
  534. for i, dstTagI := range acl.Dst {
  535. if dstTagI.ID == models.DeviceAclID {
  536. if tagID.String() == dstTagI.Value {
  537. acl.Dst = append(acl.Dst[:i], acl.Dst[i+1:]...)
  538. update = true
  539. }
  540. }
  541. }
  542. if update {
  543. UpsertAcl(acl)
  544. }
  545. }
  546. return nil
  547. }