// cidr_radix.go
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. )
  7. type CIDRNode struct {
  8. left *CIDRNode
  9. right *CIDRNode
  10. parent *CIDRNode
  11. value interface{}
  12. }
  13. type CIDRTree struct {
  14. root *CIDRNode
  15. }
  16. const (
  17. startbit = uint32(0x80000000)
  18. )
  19. func NewCIDRTree() *CIDRTree {
  20. tree := new(CIDRTree)
  21. tree.root = &CIDRNode{}
  22. return tree
  23. }
  24. func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
  25. bit := startbit
  26. node := tree.root
  27. next := tree.root
  28. ip := ip2int(cidr.IP)
  29. mask := ip2int(cidr.Mask)
  30. // Find our last ancestor in the tree
  31. for bit&mask != 0 {
  32. if ip&bit != 0 {
  33. next = node.right
  34. } else {
  35. next = node.left
  36. }
  37. if next == nil {
  38. break
  39. }
  40. bit = bit >> 1
  41. node = next
  42. }
  43. // We already have this range so update the value
  44. if next != nil {
  45. node.value = val
  46. return
  47. }
  48. // Build up the rest of the tree we don't already have
  49. for bit&mask != 0 {
  50. next = &CIDRNode{}
  51. next.parent = node
  52. if ip&bit != 0 {
  53. node.right = next
  54. } else {
  55. node.left = next
  56. }
  57. bit >>= 1
  58. node = next
  59. }
  60. // Final node marks our cidr, set the value
  61. node.value = val
  62. }
  63. // Finds the first match, which may be the least specific
  64. func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
  65. bit := startbit
  66. node := tree.root
  67. for node != nil {
  68. if node.value != nil {
  69. return node.value
  70. }
  71. if ip&bit != 0 {
  72. node = node.right
  73. } else {
  74. node = node.left
  75. }
  76. bit >>= 1
  77. }
  78. return value
  79. }
  80. // Finds the most specific match
  81. func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
  82. bit := startbit
  83. node := tree.root
  84. for node != nil {
  85. if node.value != nil {
  86. value = node.value
  87. }
  88. if ip&bit != 0 {
  89. node = node.right
  90. } else {
  91. node = node.left
  92. }
  93. bit >>= 1
  94. }
  95. return value
  96. }
  97. // Finds the most specific match
  98. func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
  99. bit := startbit
  100. node := tree.root
  101. lastNode := node
  102. for node != nil {
  103. lastNode = node
  104. if ip&bit != 0 {
  105. node = node.right
  106. } else {
  107. node = node.left
  108. }
  109. bit >>= 1
  110. }
  111. if bit == 0 && lastNode != nil {
  112. value = lastNode.value
  113. }
  114. return value
  115. }
  116. // A helper type to avoid converting to IP when logging
  117. type IntIp uint32
  118. func (ip IntIp) String() string {
  119. return fmt.Sprintf("%v", int2ip(uint32(ip)))
  120. }
  121. func (ip IntIp) MarshalJSON() ([]byte, error) {
  122. return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
  123. }
  124. func ip2int(ip []byte) uint32 {
  125. if len(ip) == 16 {
  126. return binary.BigEndian.Uint32(ip[12:16])
  127. }
  128. return binary.BigEndian.Uint32(ip)
  129. }
  130. func int2ip(nn uint32) net.IP {
  131. ip := make(net.IP, 4)
  132. binary.BigEndian.PutUint32(ip, nn)
  133. return ip
  134. }