cidr6_radix.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "net"
  5. )
  6. type CIDR6Tree struct {
  7. root4 *CIDRNode
  8. root6 *CIDRNode
  9. }
  10. func NewCIDR6Tree() *CIDR6Tree {
  11. tree := new(CIDR6Tree)
  12. tree.root4 = &CIDRNode{}
  13. tree.root6 = &CIDRNode{}
  14. return tree
  15. }
  16. func (tree *CIDR6Tree) AddCIDR(cidr *net.IPNet, val interface{}) {
  17. var node, next *CIDRNode
  18. cidrIP, ipv4 := isIPV4(cidr.IP)
  19. if ipv4 {
  20. node = tree.root4
  21. next = tree.root4
  22. } else {
  23. node = tree.root6
  24. next = tree.root6
  25. }
  26. for i := 0; i < len(cidrIP); i += 4 {
  27. ip := binary.BigEndian.Uint32(cidrIP[i : i+4])
  28. mask := binary.BigEndian.Uint32(cidr.Mask[i : i+4])
  29. bit := startbit
  30. // Find our last ancestor in the tree
  31. for bit&mask != 0 {
  32. if ip&bit != 0 {
  33. next = node.right
  34. } else {
  35. next = node.left
  36. }
  37. if next == nil {
  38. break
  39. }
  40. bit = bit >> 1
  41. node = next
  42. }
  43. // Build up the rest of the tree we don't already have
  44. for bit&mask != 0 {
  45. next = &CIDRNode{}
  46. next.parent = node
  47. if ip&bit != 0 {
  48. node.right = next
  49. } else {
  50. node.left = next
  51. }
  52. bit >>= 1
  53. node = next
  54. }
  55. }
  56. // Final node marks our cidr, set the value
  57. node.value = val
  58. }
  59. // Finds the first match, which may be the least specific
  60. func (tree *CIDR6Tree) Contains(ip net.IP) (value interface{}) {
  61. var node *CIDRNode
  62. wholeIP, ipv4 := isIPV4(ip)
  63. if ipv4 {
  64. node = tree.root4
  65. } else {
  66. node = tree.root6
  67. }
  68. for i := 0; i < len(wholeIP); i += 4 {
  69. ip := ip2int(wholeIP[i : i+4])
  70. bit := startbit
  71. for node != nil {
  72. if node.value != nil {
  73. return node.value
  74. }
  75. // Check if we have reached the end and the above return did not trigger, move to the next uint32 if available
  76. if bit == 0 {
  77. break
  78. }
  79. if ip&bit != 0 {
  80. node = node.right
  81. } else {
  82. node = node.left
  83. }
  84. bit >>= 1
  85. }
  86. }
  87. // Nothing found
  88. return
  89. }
  90. // Finds the most specific match
  91. func (tree *CIDR6Tree) MostSpecificContains(ip net.IP) (value interface{}) {
  92. var node *CIDRNode
  93. wholeIP, ipv4 := isIPV4(ip)
  94. if ipv4 {
  95. node = tree.root4
  96. } else {
  97. node = tree.root6
  98. }
  99. for i := 0; i < len(wholeIP); i += 4 {
  100. ip := ip2int(wholeIP[i : i+4])
  101. bit := startbit
  102. for node != nil {
  103. if node.value != nil {
  104. value = node.value
  105. }
  106. if bit == 0 {
  107. break
  108. }
  109. if ip&bit != 0 {
  110. node = node.right
  111. } else {
  112. node = node.left
  113. }
  114. bit >>= 1
  115. }
  116. }
  117. return value
  118. }
  119. // Finds the most specific match
  120. func (tree *CIDR6Tree) Match(ip net.IP) (value interface{}) {
  121. var node *CIDRNode
  122. var bit uint32
  123. wholeIP, ipv4 := isIPV4(ip)
  124. if ipv4 {
  125. node = tree.root4
  126. } else {
  127. node = tree.root6
  128. }
  129. for i := 0; i < len(wholeIP); i += 4 {
  130. ip := ip2int(wholeIP[i : i+4])
  131. bit = startbit
  132. for node != nil && bit > 0 {
  133. if ip&bit != 0 {
  134. node = node.right
  135. } else {
  136. node = node.left
  137. }
  138. bit >>= 1
  139. }
  140. }
  141. if bit == 0 && node != nil {
  142. value = node.value
  143. }
  144. return value
  145. }
  146. func isIPV4(ip net.IP) (net.IP, bool) {
  147. if len(ip) == net.IPv4len {
  148. return ip, true
  149. }
  150. if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
  151. return ip[12:16], true
  152. }
  153. return ip, false
  154. }
  155. func isZeros(p net.IP) bool {
  156. for i := 0; i < len(p); i++ {
  157. if p[i] != 0 {
  158. return false
  159. }
  160. }
  161. return true
  162. }