tree4.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package cidr
  2. import (
  3. "net"
  4. "github.com/slackhq/nebula/iputil"
  5. )
  // Node is a single vertex of the binary trie. During a walk, left is taken
// when the current address bit is 0 and right when it is 1; parent points one
// level back toward the root (nil on the root). A non-nil value marks the end
// of a stored prefix.
type Node struct {
	left, right, parent *Node
	value               interface{}
}
  // entry is one record in a Tree4's list: a network passed to AddCIDR paired
// with a pointer to the value it was stored with.
type entry struct {
	CIDR  *net.IPNet   // the network exactly as the caller supplied it
	Value *interface{} // points at the value that was passed to AddCIDR
}
  16. type Tree4 struct {
  17. root *Node
  18. list []entry
  19. }
  20. const (
  21. startbit = iputil.VpnIp(0x80000000)
  22. )
  23. func NewTree4() *Tree4 {
  24. tree := new(Tree4)
  25. tree.root = &Node{}
  26. tree.list = []entry{}
  27. return tree
  28. }
  29. func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
  30. bit := startbit
  31. node := tree.root
  32. next := tree.root
  33. ip := iputil.Ip2VpnIp(cidr.IP)
  34. mask := iputil.Ip2VpnIp(cidr.Mask)
  35. // Find our last ancestor in the tree
  36. for bit&mask != 0 {
  37. if ip&bit != 0 {
  38. next = node.right
  39. } else {
  40. next = node.left
  41. }
  42. if next == nil {
  43. break
  44. }
  45. bit = bit >> 1
  46. node = next
  47. }
  48. // We already have this range so update the value
  49. if next != nil {
  50. addCIDR := cidr.String()
  51. for i, v := range tree.list {
  52. if addCIDR == v.CIDR.String() {
  53. tree.list = append(tree.list[:i], tree.list[i+1:]...)
  54. break
  55. }
  56. }
  57. tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
  58. node.value = val
  59. return
  60. }
  61. // Build up the rest of the tree we don't already have
  62. for bit&mask != 0 {
  63. next = &Node{}
  64. next.parent = node
  65. if ip&bit != 0 {
  66. node.right = next
  67. } else {
  68. node.left = next
  69. }
  70. bit >>= 1
  71. node = next
  72. }
  73. // Final node marks our cidr, set the value
  74. node.value = val
  75. tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
  76. }
  77. // Contains finds the first match, which may be the least specific
  78. func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
  79. bit := startbit
  80. node := tree.root
  81. for node != nil {
  82. if node.value != nil {
  83. return node.value
  84. }
  85. if ip&bit != 0 {
  86. node = node.right
  87. } else {
  88. node = node.left
  89. }
  90. bit >>= 1
  91. }
  92. return value
  93. }
  94. // MostSpecificContains finds the most specific match
  95. func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
  96. bit := startbit
  97. node := tree.root
  98. for node != nil {
  99. if node.value != nil {
  100. value = node.value
  101. }
  102. if ip&bit != 0 {
  103. node = node.right
  104. } else {
  105. node = node.left
  106. }
  107. bit >>= 1
  108. }
  109. return value
  110. }
  111. // Match finds the most specific match
  112. func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
  113. bit := startbit
  114. node := tree.root
  115. lastNode := node
  116. for node != nil {
  117. lastNode = node
  118. if ip&bit != 0 {
  119. node = node.right
  120. } else {
  121. node = node.left
  122. }
  123. bit >>= 1
  124. }
  125. if bit == 0 && lastNode != nil {
  126. value = lastNode.value
  127. }
  128. return value
  129. }
  130. // List will return all CIDRs and their current values. Do not modify the contents!
  131. func (tree *Tree4) List() []entry {
  132. return tree.list
  133. }