allow_list.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package nebula
  2. import (
  3. "fmt"
  4. "net"
  5. "regexp"
  6. )
  7. type AllowList struct {
  8. // The values of this cidrTree are `bool`, signifying allow/deny
  9. cidrTree *CIDR6Tree
  10. // To avoid ambiguity, all rules must be true, or all rules must be false.
  11. nameRules []AllowListNameRule
  12. }
  13. type AllowListNameRule struct {
  14. Name *regexp.Regexp
  15. Allow bool
  16. }
  17. func (al *AllowList) Allow(ip net.IP) bool {
  18. if al == nil {
  19. return true
  20. }
  21. result := al.cidrTree.MostSpecificContains(ip)
  22. switch v := result.(type) {
  23. case bool:
  24. return v
  25. default:
  26. panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
  27. }
  28. }
  29. func (al *AllowList) AllowIpV4(ip uint32) bool {
  30. if al == nil {
  31. return true
  32. }
  33. result := al.cidrTree.MostSpecificContainsIpV4(ip)
  34. switch v := result.(type) {
  35. case bool:
  36. return v
  37. default:
  38. panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
  39. }
  40. }
  41. func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
  42. if al == nil {
  43. return true
  44. }
  45. result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
  46. switch v := result.(type) {
  47. case bool:
  48. return v
  49. default:
  50. panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
  51. }
  52. }
  53. func (al *AllowList) AllowName(name string) bool {
  54. if al == nil || len(al.nameRules) == 0 {
  55. return true
  56. }
  57. for _, rule := range al.nameRules {
  58. if rule.Name.MatchString(name) {
  59. return rule.Allow
  60. }
  61. }
  62. // If no rules match, return the default, which is the inverse of the rules
  63. return !al.nameRules[0].Allow
  64. }