cidr_radix.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. )
  // CIDRNode and CIDRTree together form a binary radix tree (trie) over 32-bit
  // IPv4 addresses, walked one bit at a time from the most significant bit.
  type (
  	// CIDRNode is one node of the tree. A node at depth d stands for a /d
  	// prefix: left is followed when the next address bit is 0, right when it
  	// is 1. value holds whatever AddCIDR stored for the prefix ending here;
  	// the lookups treat a nil value as "no CIDR ends at this node". parent is
  	// the back-pointer AddCIDR sets on every node it creates.
  	CIDRNode struct {
  		left, right, parent *CIDRNode
  		value               interface{}
  	}

  	// CIDRTree maps IPv4 CIDR ranges to arbitrary values. Create one with
  	// NewCIDRTree: the zero value has no root node, so AddCIDR on it panics.
  	CIDRTree struct {
  		root *CIDRNode
  	}
  )
  // startbit is the most significant bit of a 32-bit address. Every tree walk
  // starts testing at this bit and shifts it right by one per level.
  const startbit uint32 = 1 << 31
  19. func NewCIDRTree() *CIDRTree {
  20. tree := new(CIDRTree)
  21. tree.root = &CIDRNode{}
  22. return tree
  23. }
  24. func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
  25. bit := startbit
  26. node := tree.root
  27. next := tree.root
  28. ip := ip2int(cidr.IP)
  29. mask := ip2int(cidr.Mask)
  30. // Find our last ancestor in the tree
  31. for bit&mask != 0 {
  32. if ip&bit != 0 {
  33. next = node.right
  34. } else {
  35. next = node.left
  36. }
  37. if next == nil {
  38. break
  39. }
  40. bit = bit >> 1
  41. node = next
  42. }
  43. // We already have this range so update the value
  44. if next != nil {
  45. node.value = val
  46. return
  47. }
  48. // Build up the rest of the tree we don't already have
  49. for bit&mask != 0 {
  50. next = &CIDRNode{}
  51. next.parent = node
  52. if ip&bit != 0 {
  53. node.right = next
  54. } else {
  55. node.left = next
  56. }
  57. bit >>= 1
  58. node = next
  59. }
  60. // Final node marks our cidr, set the value
  61. node.value = val
  62. }
  63. // Finds the first match, which way be the least specific
  64. func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
  65. bit := startbit
  66. node := tree.root
  67. for node != nil {
  68. if node.value != nil {
  69. return node.value
  70. }
  71. if ip&bit != 0 {
  72. node = node.right
  73. } else {
  74. node = node.left
  75. }
  76. bit >>= 1
  77. }
  78. return value
  79. }
  80. // Finds the most specific match
  81. func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
  82. bit := startbit
  83. node := tree.root
  84. lastNode := node
  85. for node != nil {
  86. lastNode = node
  87. if ip&bit != 0 {
  88. node = node.right
  89. } else {
  90. node = node.left
  91. }
  92. bit >>= 1
  93. }
  94. if bit == 0 && lastNode != nil {
  95. value = lastNode.value
  96. }
  97. return value
  98. }
  99. // A helper type to avoid converting to IP when logging
  100. type IntIp uint32
  101. func (ip IntIp) String() string {
  102. return fmt.Sprintf("%v", int2ip(uint32(ip)))
  103. }
  104. func (ip IntIp) MarshalJSON() ([]byte, error) {
  105. return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
  106. }
  // ip2int packs an address or mask into a uint32, first byte most significant.
  // A 16-byte slice (net.IPv6len) contributes only its last four bytes. That
  // covers IPv4-mapped addresses and IPv4 masks in 16-byte form, but it silently
  // drops the upper 96 bits of a real IPv6 address. Any other length is read
  // from its first four bytes; binary.BigEndian.Uint32 panics when fewer than
  // four bytes are present.
  func ip2int(ip []byte) uint32 {
  	raw := ip
  	if len(raw) == net.IPv6len {
  		raw = raw[net.IPv6len-net.IPv4len:]
  	}
  	return binary.BigEndian.Uint32(raw)
  }
  // int2ip unpacks a uint32 (first octet in the most significant byte) into a
  // 4-byte net.IP. It is the inverse of ip2int for 4-byte input.
  func int2ip(nn uint32) net.IP {
  	return net.IP{byte(nn >> 24), byte(nn >> 16), byte(nn >> 8), byte(nn)}
  }