tree6.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package cidr
  2. import (
  3. "net"
  4. "github.com/slackhq/nebula/iputil"
  5. )
  // startbit6 is the most significant bit of a uint64. MostSpecificContainsIpV6
// starts each 64-bit half of an IPv6 address at this bit and shifts right.
const startbit6 uint64 = 1 << 63
  7. type Tree6 struct {
  8. root4 *Node
  9. root6 *Node
  10. }
  11. func NewTree6() *Tree6 {
  12. tree := new(Tree6)
  13. tree.root4 = &Node{}
  14. tree.root6 = &Node{}
  15. return tree
  16. }
  17. func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
  18. var node, next *Node
  19. cidrIP, ipv4 := isIPV4(cidr.IP)
  20. if ipv4 {
  21. node = tree.root4
  22. next = tree.root4
  23. } else {
  24. node = tree.root6
  25. next = tree.root6
  26. }
  27. for i := 0; i < len(cidrIP); i += 4 {
  28. ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
  29. mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
  30. bit := startbit
  31. // Find our last ancestor in the tree
  32. for bit&mask != 0 {
  33. if ip&bit != 0 {
  34. next = node.right
  35. } else {
  36. next = node.left
  37. }
  38. if next == nil {
  39. break
  40. }
  41. bit = bit >> 1
  42. node = next
  43. }
  44. // Build up the rest of the tree we don't already have
  45. for bit&mask != 0 {
  46. next = &Node{}
  47. next.parent = node
  48. if ip&bit != 0 {
  49. node.right = next
  50. } else {
  51. node.left = next
  52. }
  53. bit >>= 1
  54. node = next
  55. }
  56. }
  57. // Final node marks our cidr, set the value
  58. node.value = val
  59. }
  60. // Finds the most specific match
  61. func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
  62. var node *Node
  63. wholeIP, ipv4 := isIPV4(ip)
  64. if ipv4 {
  65. node = tree.root4
  66. } else {
  67. node = tree.root6
  68. }
  69. for i := 0; i < len(wholeIP); i += 4 {
  70. ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
  71. bit := startbit
  72. for node != nil {
  73. if node.value != nil {
  74. value = node.value
  75. }
  76. if bit == 0 {
  77. break
  78. }
  79. if ip&bit != 0 {
  80. node = node.right
  81. } else {
  82. node = node.left
  83. }
  84. bit >>= 1
  85. }
  86. }
  87. return value
  88. }
  89. func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
  90. bit := startbit
  91. node := tree.root4
  92. for node != nil {
  93. if node.value != nil {
  94. value = node.value
  95. }
  96. if ip&bit != 0 {
  97. node = node.right
  98. } else {
  99. node = node.left
  100. }
  101. bit >>= 1
  102. }
  103. return value
  104. }
  105. func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
  106. ip := hi
  107. node := tree.root6
  108. for i := 0; i < 2; i++ {
  109. bit := startbit6
  110. for node != nil {
  111. if node.value != nil {
  112. value = node.value
  113. }
  114. if bit == 0 {
  115. break
  116. }
  117. if ip&bit != 0 {
  118. node = node.right
  119. } else {
  120. node = node.left
  121. }
  122. bit >>= 1
  123. }
  124. ip = lo
  125. }
  126. return value
  127. }
  // isIPV4 reports whether ip is an IPv4 address, either in 4-byte form or as
// an IPv4-mapped IPv6 address (::ffff:a.b.c.d). For IPv4 it returns the
// 4-byte form, which is a sub-slice of ip rather than a copy, and true.
// Otherwise it returns ip unchanged and false.
func isIPV4(ip net.IP) (net.IP, bool) {
	// net.IP.To4 performs exactly these checks: 4-byte input is returned
	// as-is, a 16-byte address with the ::ffff:0:0/96 prefix is reduced to its
	// last 4 bytes, and anything else yields nil.
	if v4 := ip.To4(); v4 != nil {
		return v4, true
	}
	return ip, false
}
  // isZeros reports whether every byte of p is zero. An empty or nil p counts
// as all zeros.
func isZeros(p net.IP) bool {
	for _, b := range p {
		if b != 0 {
			return false
		}
	}
	return true
}