tree6.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package cidr
  2. import (
  3. "net"
  4. "github.com/slackhq/nebula/iputil"
  5. )
  // startbit6 has only the most significant bit of a uint64 set. The IPv6
// lookup starts each 64-bit half of the address here and shifts it right
// once per tree level.
const startbit6 uint64 = 1 << 63
  7. type Tree6[T any] struct {
  8. root4 *Node[T]
  9. root6 *Node[T]
  10. }
  11. func NewTree6[T any]() *Tree6[T] {
  12. tree := new(Tree6[T])
  13. tree.root4 = &Node[T]{}
  14. tree.root6 = &Node[T]{}
  15. return tree
  16. }
  // AddCIDR stores val on the node for cidr, creating any missing intermediate
// nodes along the way. IPv4 and IPv4-mapped IPv6 networks go into the IPv4
// root; all other networks go into the IPv6 root. Adding the same network
// again overwrites its previous value.
//
// The address is consumed 32 bits at a time. Within each word only the bits
// covered by the mask are walked, so for a contiguous mask the depth of the
// final node equals the prefix length.
//
// NOTE(review): mask lengths are not validated. An IPv6 network with a 4-byte
// mask panics on the mask slice bounds, as it did before.
func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
	cidrIP, ipv4 := isIPV4(cidr.IP)
	mask := cidr.Mask

	node := tree.root6
	if ipv4 {
		node = tree.root4
		// An IPv4 network may carry a 16-byte mask. The net package's IPNet
		// methods apply its last 4 bytes in that case, and so do we.
		// Otherwise the first 4 bytes are read instead: they are all ones for
		// a typical CIDRMask(n, 128) with n >= 96. The network would then be
		// inserted with the wrong prefix length.
		if len(mask) == net.IPv6len {
			mask = mask[12:]
		}
	}

	for i := 0; i < len(cidrIP); i += 4 {
		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
		wordMask := iputil.Ip2VpnIp(mask[i : i+4])
		bit := startbit

		// Find our last ancestor in the tree
		for bit&wordMask != 0 {
			var next *Node[T]
			if ip&bit != 0 {
				next = node.right
			} else {
				next = node.left
			}
			if next == nil {
				break
			}
			bit >>= 1
			node = next
		}

		// Build up the rest of the tree we don't already have
		for bit&wordMask != 0 {
			next := &Node[T]{parent: node}
			if ip&bit != 0 {
				node.right = next
			} else {
				node.left = next
			}
			bit >>= 1
			node = next
		}
	}

	// Final node marks our cidr, set the value
	node.value = val
	node.hasValue = true
}
  61. // Finds the most specific match
  62. func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
  63. var node *Node[T]
  64. wholeIP, ipv4 := isIPV4(ip)
  65. if ipv4 {
  66. node = tree.root4
  67. } else {
  68. node = tree.root6
  69. }
  70. for i := 0; i < len(wholeIP); i += 4 {
  71. ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
  72. bit := startbit
  73. for node != nil {
  74. if node.hasValue {
  75. value = node.value
  76. ok = true
  77. }
  78. if bit == 0 {
  79. break
  80. }
  81. if ip&bit != 0 {
  82. node = node.right
  83. } else {
  84. node = node.left
  85. }
  86. bit >>= 1
  87. }
  88. }
  89. return ok, value
  90. }
  91. func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
  92. bit := startbit
  93. node := tree.root4
  94. for node != nil {
  95. if node.hasValue {
  96. value = node.value
  97. ok = true
  98. }
  99. if ip&bit != 0 {
  100. node = node.right
  101. } else {
  102. node = node.left
  103. }
  104. bit >>= 1
  105. }
  106. return ok, value
  107. }
  108. func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
  109. ip := hi
  110. node := tree.root6
  111. for i := 0; i < 2; i++ {
  112. bit := startbit6
  113. for node != nil {
  114. if node.hasValue {
  115. value = node.value
  116. ok = true
  117. }
  118. if bit == 0 {
  119. break
  120. }
  121. if ip&bit != 0 {
  122. node = node.right
  123. } else {
  124. node = node.left
  125. }
  126. bit >>= 1
  127. }
  128. ip = lo
  129. }
  130. return ok, value
  131. }
  // isIPV4 reports whether ip is an IPv4 address, either in 4-byte form or as
// an IPv4-mapped IPv6 address (::ffff:a.b.c.d). For IPv4 it returns the 4-byte
// form, which for the mapped case is a subslice of ip. Otherwise it returns ip
// unchanged together with false.
func isIPV4(ip net.IP) (net.IP, bool) {
	// net.IP.To4 performs exactly these checks. A 4-byte input is returned
	// as-is. A 16-byte input with ten zero bytes followed by 0xff 0xff yields
	// its last four bytes. Anything else yields nil.
	if ip4 := ip.To4(); ip4 != nil {
		return ip4, true
	}
	return ip, false
}
  // isZeros reports whether every byte of p is zero. An empty or nil p
// reports true.
func isZeros(p net.IP) bool {
	for _, b := range p {
		if b != 0 {
			return false
		}
	}
	return true
}