posture_check.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. package logic
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "slices"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/biter777/countries"
  12. "github.com/gravitl/netmaker/db"
  13. "github.com/gravitl/netmaker/logic"
  14. "github.com/gravitl/netmaker/models"
  15. "github.com/gravitl/netmaker/schema"
  16. "gorm.io/datatypes"
  17. )
  // postureCheckMutex is held for the whole of RunPostureChecks, so two
  // evaluation runs in this file never write node/ext-client results at the
  // same time.
  var postureCheckMutex = new(sync.Mutex)
  19. func AddPostureCheckHook() {
  20. settings := logic.GetServerSettings()
  21. interval := time.Hour
  22. i, err := strconv.Atoi(settings.PostureCheckInterval)
  23. if err == nil {
  24. interval = time.Minute * time.Duration(i)
  25. }
  26. logic.HookManagerCh <- models.HookDetails{
  27. Hook: logic.WrapHook(RunPostureChecks),
  28. Interval: interval,
  29. }
  30. }
  31. func RemoveTagFromPostureChecks(tagID models.TagID, netID models.NetworkID) {
  32. pcLi, err := (&schema.PostureCheck{NetworkID: netID}).ListByNetwork(db.WithContext(context.TODO()))
  33. if err != nil || len(pcLi) == 0 {
  34. return
  35. }
  36. for _, pcI := range pcLi {
  37. if _, ok := pcI.Tags[tagID.String()]; ok {
  38. delete(pcI.Tags, tagID.String())
  39. pcI.Update(db.WithContext(context.TODO()))
  40. }
  41. }
  42. }
  43. func RemoveUserGroupFromPostureChecks(grpID models.UserGroupID, netID models.NetworkID) {
  44. pcLi, err := (&schema.PostureCheck{NetworkID: netID}).ListByNetwork(db.WithContext(context.TODO()))
  45. if err != nil || len(pcLi) == 0 {
  46. return
  47. }
  48. for _, pcI := range pcLi {
  49. if _, ok := pcI.UserGroups[grpID.String()]; ok {
  50. delete(pcI.UserGroups, grpID.String())
  51. pcI.Update(db.WithContext(context.TODO()))
  52. }
  53. }
  54. }
  55. func RunPostureChecks() error {
  56. if !GetFeatureFlags().EnablePostureChecks {
  57. return nil
  58. }
  59. postureCheckMutex.Lock()
  60. defer postureCheckMutex.Unlock()
  61. nets, err := logic.GetNetworks()
  62. if err != nil {
  63. return err
  64. }
  65. nodes, err := logic.GetAllNodes()
  66. if err != nil {
  67. return err
  68. }
  69. for _, netI := range nets {
  70. networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
  71. if len(networkNodes) == 0 {
  72. continue
  73. }
  74. networkNodes = logic.AddStaticNodestoList(networkNodes)
  75. pcLi, err := (&schema.PostureCheck{NetworkID: models.NetworkID(netI.NetID)}).ListByNetwork(db.WithContext(context.TODO()))
  76. if err != nil || len(pcLi) == 0 {
  77. continue
  78. }
  79. for _, nodeI := range networkNodes {
  80. if nodeI.IsStatic && !nodeI.IsUserNode {
  81. continue
  82. }
  83. postureChecksViolations, postureCheckVolationSeverityLevel := GetPostureCheckViolations(pcLi, logic.GetPostureCheckDeviceInfoByNode(&nodeI))
  84. if nodeI.IsUserNode {
  85. extclient, err := logic.GetExtClient(nodeI.StaticNode.ClientID, nodeI.StaticNode.Network)
  86. if err == nil {
  87. extclient.PostureChecksViolations = postureChecksViolations
  88. extclient.PostureCheckVolationSeverityLevel = postureCheckVolationSeverityLevel
  89. extclient.LastEvaluatedAt = time.Now().UTC()
  90. logic.SaveExtClient(&extclient)
  91. }
  92. } else {
  93. nodeI.PostureChecksViolations, nodeI.PostureCheckVolationSeverityLevel = postureChecksViolations,
  94. postureCheckVolationSeverityLevel
  95. nodeI.LastEvaluatedAt = time.Now().UTC()
  96. logic.UpsertNode(&nodeI)
  97. }
  98. }
  99. }
  100. return nil
  101. }
  102. func CheckPostureViolations(d models.PostureCheckDeviceInfo, network models.NetworkID) ([]models.Violation, models.Severity) {
  103. if !GetFeatureFlags().EnablePostureChecks {
  104. return []models.Violation{}, models.SeverityUnknown
  105. }
  106. pcLi, err := (&schema.PostureCheck{NetworkID: network}).ListByNetwork(db.WithContext(context.TODO()))
  107. if err != nil || len(pcLi) == 0 {
  108. return []models.Violation{}, models.SeverityUnknown
  109. }
  110. violations, level := GetPostureCheckViolations(pcLi, d)
  111. return violations, level
  112. }
  113. func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]models.Violation, models.Severity) {
  114. if !GetFeatureFlags().EnablePostureChecks {
  115. return []models.Violation{}, models.SeverityUnknown
  116. }
  117. var violations []models.Violation
  118. highest := models.SeverityUnknown
  119. // Group checks by attribute
  120. checksByAttribute := make(map[schema.Attribute][]schema.PostureCheck)
  121. for _, c := range checks {
  122. // skip disabled checks
  123. if !c.Status {
  124. continue
  125. }
  126. if d.IsUser && c.Attribute == schema.AutoUpdate {
  127. continue
  128. }
  129. // Check if tags match
  130. if !d.IsUser {
  131. // Check if posture check has wildcard tag - applies to all devices
  132. if _, hasWildcard := c.Tags["*"]; hasWildcard {
  133. // Wildcard tag matches all devices, continue to evaluate the check
  134. } else if len(c.Tags) > 0 {
  135. // Check has specific tags - device must have at least one matching tag
  136. if len(d.Tags) == 0 {
  137. // Device has no tags and check doesn't have wildcard, skip
  138. continue
  139. }
  140. exists := false
  141. for tagID := range c.Tags {
  142. if _, ok := d.Tags[models.TagID(tagID)]; ok {
  143. exists = true
  144. break
  145. }
  146. }
  147. if !exists {
  148. continue
  149. }
  150. } else {
  151. // Check has no tags configured, skip
  152. continue
  153. }
  154. } else if d.IsUser {
  155. // Check if posture check has wildcard user group - applies to all users
  156. if _, hasWildcard := c.UserGroups["*"]; hasWildcard {
  157. // Wildcard user group matches all users, continue to evaluate the check
  158. } else if len(c.UserGroups) > 0 {
  159. // Check has specific user groups - user must have at least one matching group
  160. if len(d.UserGroups) == 0 {
  161. // User has no groups and check doesn't have wildcard, skip
  162. continue
  163. }
  164. exists := false
  165. for userG := range c.UserGroups {
  166. if _, ok := d.UserGroups[models.UserGroupID(userG)]; ok {
  167. exists = true
  168. break
  169. }
  170. }
  171. if !exists {
  172. continue
  173. }
  174. } else {
  175. // Check has no user groups configured, skip
  176. continue
  177. }
  178. }
  179. checksByAttribute[c.Attribute] = append(checksByAttribute[c.Attribute], c)
  180. }
  181. // Handle OS and OSFamily together with OR logic since they are related
  182. osChecks := checksByAttribute[schema.OS]
  183. osFamilyChecks := checksByAttribute[schema.OSFamily]
  184. if len(osChecks) > 0 || len(osFamilyChecks) > 0 {
  185. osAllowed := evaluateAttributeChecks(osChecks, d)
  186. osFamilyAllowed := evaluateAttributeChecks(osFamilyChecks, d)
  187. // OR condition: if either OS or OSFamily passes, both are considered passed
  188. if !osAllowed && !osFamilyAllowed {
  189. // Both failed, add violations for both
  190. osDenied := getDeniedChecks(osChecks, d)
  191. osFamilyDenied := getDeniedChecks(osFamilyChecks, d)
  192. for _, denied := range osDenied {
  193. sev := denied.check.Severity
  194. if sev > highest {
  195. highest = sev
  196. }
  197. v := models.Violation{
  198. CheckID: denied.check.ID,
  199. Name: denied.check.Name,
  200. Attribute: string(denied.check.Attribute),
  201. Message: denied.reason,
  202. Severity: sev,
  203. }
  204. violations = append(violations, v)
  205. }
  206. for _, denied := range osFamilyDenied {
  207. sev := denied.check.Severity
  208. if sev > highest {
  209. highest = sev
  210. }
  211. v := models.Violation{
  212. CheckID: denied.check.ID,
  213. Name: denied.check.Name,
  214. Attribute: string(denied.check.Attribute),
  215. Message: denied.reason,
  216. Severity: sev,
  217. }
  218. violations = append(violations, v)
  219. }
  220. }
  221. }
  222. // For all other attributes, check if ANY check allows it
  223. for attr, attrChecks := range checksByAttribute {
  224. // Skip OS and OSFamily as they are handled above
  225. if attr == schema.OS || attr == schema.OSFamily {
  226. continue
  227. }
  228. // Check if any check for this attribute allows the device
  229. allowed := false
  230. var deniedChecks []struct {
  231. check schema.PostureCheck
  232. reason string
  233. }
  234. for _, c := range attrChecks {
  235. violated, reason := evaluatePostureCheck(&c, d)
  236. if !violated {
  237. // At least one check allows it
  238. allowed = true
  239. break
  240. }
  241. // Track denied checks with their reasons for violation reporting
  242. deniedChecks = append(deniedChecks, struct {
  243. check schema.PostureCheck
  244. reason string
  245. }{check: c, reason: reason})
  246. }
  247. // If no check allows it, add violations for all denied checks
  248. if !allowed {
  249. for _, denied := range deniedChecks {
  250. sev := denied.check.Severity
  251. if sev > highest {
  252. highest = sev
  253. }
  254. v := models.Violation{
  255. CheckID: denied.check.ID,
  256. Name: denied.check.Name,
  257. Attribute: string(denied.check.Attribute),
  258. Message: denied.reason,
  259. Severity: sev,
  260. }
  261. violations = append(violations, v)
  262. }
  263. }
  264. }
  265. return violations, highest
  266. }
  267. // GetPostureCheckDeviceInfoByNode retrieves PostureCheckDeviceInfo for a given node
  268. func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDeviceInfo {
  269. var deviceInfo models.PostureCheckDeviceInfo
  270. if !node.IsStatic {
  271. h, err := logic.GetHost(node.HostID.String())
  272. if err != nil {
  273. return deviceInfo
  274. }
  275. deviceInfo = models.PostureCheckDeviceInfo{
  276. ClientLocation: h.CountryCode,
  277. ClientVersion: h.Version,
  278. OS: h.OS,
  279. OSVersion: h.OSVersion,
  280. OSFamily: h.OSFamily,
  281. KernelVersion: h.KernelVersion,
  282. AutoUpdate: h.AutoUpdate,
  283. Tags: node.Tags,
  284. }
  285. } else if node.IsUserNode {
  286. deviceInfo = models.PostureCheckDeviceInfo{
  287. ClientLocation: node.StaticNode.Country,
  288. ClientVersion: node.StaticNode.ClientVersion,
  289. OS: node.StaticNode.OS,
  290. OSVersion: node.StaticNode.OSVersion,
  291. OSFamily: node.StaticNode.OSFamily,
  292. KernelVersion: node.StaticNode.KernelVersion,
  293. Tags: make(map[models.TagID]struct{}),
  294. IsUser: true,
  295. UserGroups: make(map[models.UserGroupID]struct{}),
  296. }
  297. // get user groups
  298. if node.StaticNode.OwnerID != "" {
  299. user, err := logic.GetUser(node.StaticNode.OwnerID)
  300. if err == nil && len(user.UserGroups) > 0 {
  301. deviceInfo.UserGroups = user.UserGroups
  302. if user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole {
  303. deviceInfo.UserGroups[GetDefaultNetworkAdminGroupID(models.NetworkID(node.Network))] = struct{}{}
  304. deviceInfo.UserGroups[GetDefaultGlobalAdminGroupID()] = struct{}{}
  305. } else if _, ok := user.UserGroups[GetDefaultGlobalAdminGroupID()]; ok {
  306. deviceInfo.UserGroups[GetDefaultNetworkAdminGroupID(models.NetworkID(node.Network))] = struct{}{}
  307. } else if _, ok := user.UserGroups[GetDefaultGlobalUserGroupID()]; ok {
  308. deviceInfo.UserGroups[GetDefaultNetworkUserGroupID(models.NetworkID(node.Network))] = struct{}{}
  309. }
  310. }
  311. }
  312. }
  313. return deviceInfo
  314. }
  315. // evaluateAttributeChecks evaluates checks for a specific attribute and returns true if any check allows the device
  316. func evaluateAttributeChecks(attrChecks []schema.PostureCheck, d models.PostureCheckDeviceInfo) bool {
  317. for _, c := range attrChecks {
  318. violated, _ := evaluatePostureCheck(&c, d)
  319. if !violated {
  320. // At least one check allows it
  321. return true
  322. }
  323. }
  324. return false
  325. }
  326. // getDeniedChecks returns all checks that denied the device for a specific attribute
  327. func getDeniedChecks(attrChecks []schema.PostureCheck, d models.PostureCheckDeviceInfo) []struct {
  328. check schema.PostureCheck
  329. reason string
  330. } {
  331. var deniedChecks []struct {
  332. check schema.PostureCheck
  333. reason string
  334. }
  335. for _, c := range attrChecks {
  336. violated, reason := evaluatePostureCheck(&c, d)
  337. if violated {
  338. deniedChecks = append(deniedChecks, struct {
  339. check schema.PostureCheck
  340. reason string
  341. }{check: c, reason: reason})
  342. }
  343. }
  344. return deniedChecks
  345. }
  346. func evaluatePostureCheck(check *schema.PostureCheck, d models.PostureCheckDeviceInfo) (violated bool, reason string) {
  347. switch check.Attribute {
  348. // ------------------------
  349. // 1. Geographic check
  350. // ------------------------
  351. case schema.ClientLocation:
  352. if !slices.Contains(check.Values, strings.ToUpper(d.ClientLocation)) {
  353. return true, fmt.Sprintf("client location '%s' not allowed", CountryNameFromISO(d.ClientLocation))
  354. }
  355. // ------------------------
  356. // 2. Client version check
  357. // Single value representing minimum required version
  358. // ------------------------
  359. case schema.ClientVersion:
  360. if len(check.Values) == 0 {
  361. return false, ""
  362. }
  363. minVersion := check.Values[0]
  364. cmp := compareVersions(cleanVersion(d.ClientVersion), cleanVersion(minVersion))
  365. if cmp < 0 {
  366. return true, fmt.Sprintf("client version '%s' is below minimum required version '%s'", d.ClientVersion, minVersion)
  367. }
  368. // ------------------------
  369. // 3. OS check
  370. // ("windows", "mac", "linux", etc.)
  371. // ------------------------
  372. case schema.OS:
  373. if !slices.Contains(check.Values, d.OS) {
  374. return true, fmt.Sprintf("client os '%s' not allowed", d.OS)
  375. }
  376. case schema.OSFamily:
  377. if !slices.Contains(check.Values, d.OSFamily) {
  378. return true, fmt.Sprintf("os family '%s' not allowed", d.OSFamily)
  379. }
  380. // ------------------------
  381. // 4. OS version check
  382. // Single value representing minimum required version
  383. // ------------------------
  384. case schema.OSVersion:
  385. if len(check.Values) == 0 {
  386. return false, ""
  387. }
  388. minVersion := check.Values[0]
  389. cmp := compareVersions(cleanVersion(d.OSVersion), cleanVersion(minVersion))
  390. if cmp < 0 {
  391. return true, fmt.Sprintf("os version '%s' is below minimum required version '%s'", d.OSVersion, minVersion)
  392. }
  393. case schema.KernelVersion:
  394. if len(check.Values) == 0 {
  395. return false, ""
  396. }
  397. minVersion := check.Values[0]
  398. cmp := compareVersions(cleanVersion(d.KernelVersion), cleanVersion(minVersion))
  399. if cmp < 0 {
  400. return true, fmt.Sprintf("kernel version '%s' is below minimum required version '%s'", d.KernelVersion, minVersion)
  401. }
  402. // ------------------------
  403. // 5. Auto-update check
  404. // Values: ["true"] or ["false"]
  405. // ------------------------
  406. case schema.AutoUpdate:
  407. required := len(check.Values) > 0 && strings.ToLower(check.Values[0]) == "true"
  408. if required && !d.AutoUpdate {
  409. return true, "auto update must be enabled"
  410. }
  411. if !required && d.AutoUpdate {
  412. return true, "auto update must be disabled"
  413. }
  414. }
  415. return false, ""
  416. }
  // cleanVersion normalises a version string for compareVersions. It applies,
  // in order:
  //   1. trim surrounding whitespace;
  //   2. drop one leading "v", then one leading "V";
  //   3. drop a single trailing comma;
  //   4. trim surrounding whitespace again.
  func cleanVersion(v string) string {
  	withoutPrefix := strings.TrimPrefix(strings.TrimPrefix(strings.TrimSpace(v), "v"), "V")
  	return strings.TrimSpace(strings.TrimSuffix(withoutPrefix, ","))
  }
  // compareVersions compares two dot-separated version strings numerically,
  // component by component. It returns -1 if a < b, 0 if they are equal, and
  // 1 if a > b. A missing trailing component counts as 0, so "1.2" equals
  // "1.2.0".
  //
  // Only the leading run of ASCII digits in each component is significant, so
  // "12-arch1", "0-35-generic" and "4 LTS" compare as 12, 0 and 4. Parsing the
  // whole component would fail on such suffixes and turn the component into 0,
  // ranking e.g. kernel "6.8.12-arch1" below "6.8.5". A component with no
  // leading digits, or one too large for an int, counts as 0.
  func compareVersions(a, b string) int {
  	pa := strings.Split(a, ".")
  	pb := strings.Split(b, ".")
  	// part returns the numeric value of the i-th component of parts, or 0 when
  	// parts has no i-th component.
  	part := func(parts []string, i int) int {
  		if i >= len(parts) {
  			return 0
  		}
  		s := parts[i]
  		end := 0
  		for end < len(s) && s[end] >= '0' && s[end] <= '9' {
  			end++
  		}
  		n, err := strconv.Atoi(s[:end])
  		if err != nil {
  			return 0 // no leading digits, or out of int range
  		}
  		return n
  	}
  	n := len(pa)
  	if len(pb) > n {
  		n = len(pb)
  	}
  	for i := 0; i < n; i++ {
  		ai, bi := part(pa, i), part(pb, i)
  		if ai > bi {
  			return 1
  		}
  		if ai < bi {
  			return -1
  		}
  	}
  	return 0
  }
  449. func ValidatePostureCheck(pc *schema.PostureCheck) error {
  450. if pc.Name == "" {
  451. return errors.New("name cannot be empty")
  452. }
  453. _, err := logic.GetNetwork(pc.NetworkID.String())
  454. if err != nil {
  455. return errors.New("invalid network")
  456. }
  457. allowedAttrvaluesMap, ok := schema.PostureCheckAttrValuesMap[pc.Attribute]
  458. if !ok {
  459. return errors.New("unkown attribute")
  460. }
  461. if len(pc.Values) == 0 {
  462. return errors.New("attribute value cannot be empty")
  463. }
  464. for i, valueI := range pc.Values {
  465. pc.Values[i] = strings.ToLower(valueI)
  466. }
  467. if pc.Attribute == schema.ClientLocation {
  468. for i, loc := range pc.Values {
  469. if countries.ByName(loc) == countries.Unknown {
  470. return errors.New("invalid country code")
  471. }
  472. pc.Values[i] = strings.ToUpper(loc)
  473. }
  474. }
  475. if pc.Attribute == schema.AutoUpdate || pc.Attribute == schema.OS ||
  476. pc.Attribute == schema.OSFamily {
  477. for _, valueI := range pc.Values {
  478. if _, ok := allowedAttrvaluesMap[valueI]; !ok {
  479. return errors.New("invalid attribute value")
  480. }
  481. }
  482. }
  483. if pc.Attribute == schema.ClientVersion || pc.Attribute == schema.OSVersion ||
  484. pc.Attribute == schema.KernelVersion {
  485. if len(pc.Values) != 1 {
  486. return errors.New("version attribute must have exactly one value (minimum version)")
  487. }
  488. if !logic.IsValidVersion(pc.Values[0]) {
  489. return errors.New("invalid attribute version value")
  490. }
  491. pc.Values[0] = logic.CleanVersion(pc.Values[0])
  492. }
  493. if len(pc.Tags) > 0 {
  494. for tagID := range pc.Tags {
  495. if tagID == "*" {
  496. continue
  497. }
  498. _, err := GetTag(models.TagID(tagID))
  499. if err != nil {
  500. return errors.New("unknown tag")
  501. }
  502. }
  503. } else {
  504. pc.Tags = make(datatypes.JSONMap)
  505. }
  506. if len(pc.UserGroups) > 0 {
  507. for userGrpID := range pc.UserGroups {
  508. if userGrpID == "*" {
  509. continue
  510. }
  511. _, err := GetUserGroup(models.UserGroupID(userGrpID))
  512. if err != nil {
  513. return errors.New("unknown tag")
  514. }
  515. }
  516. } else {
  517. pc.UserGroups = make(datatypes.JSONMap)
  518. }
  519. return nil
  520. }
  521. func CountryNameFromISO(code string) string {
  522. c := countries.ByName(code) // works with ISO2, ISO3, full name
  523. if c == countries.Unknown {
  524. return ""
  525. }
  526. return c.Info().Name
  527. }