peers.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package logic
  2. import (
  3. "log"
  4. "net"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/gravitl/netmaker/models"
  9. "github.com/gravitl/netmaker/netclient/ncutils"
  10. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  11. )
  12. func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
  13. var peerUpdate models.PeerUpdate
  14. var peers []wgtypes.PeerConfig
  15. networkNodes, err := GetNetworkNodes(node.Network)
  16. if err != nil {
  17. return models.PeerUpdate{}, err
  18. }
  19. for _, peer := range networkNodes {
  20. if peer.ID == node.ID {
  21. //skip yourself
  22. continue
  23. }
  24. pubkey, err := wgtypes.ParseKey(peer.PublicKey)
  25. if err != nil {
  26. return models.PeerUpdate{}, err
  27. }
  28. if node.Endpoint == peer.Endpoint {
  29. //peer is on same network
  30. if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
  31. peer.Endpoint = peer.LocalAddress
  32. } else {
  33. continue
  34. }
  35. }
  36. endpoint := peer.Endpoint + ":" + strconv.FormatInt(int64(peer.ListenPort), 10)
  37. address, err := net.ResolveUDPAddr("udp", endpoint)
  38. if err != nil {
  39. return models.PeerUpdate{}, err
  40. }
  41. allowedips := GetAllowedIPs(node, &peer)
  42. var keepalive time.Duration
  43. if node.PersistentKeepalive != 0 {
  44. keepalive, _ = time.ParseDuration(strconv.FormatInt(int64(node.PersistentKeepalive), 10) + "s")
  45. }
  46. var peerData = wgtypes.PeerConfig{
  47. PublicKey: pubkey,
  48. Endpoint: address,
  49. ReplaceAllowedIPs: true,
  50. AllowedIPs: allowedips,
  51. PersistentKeepaliveInterval: &keepalive,
  52. }
  53. peers = append(peers, peerData)
  54. }
  55. peerUpdate.Network = node.Network
  56. peerUpdate.Peers = peers
  57. return peerUpdate, nil
  58. }
  59. func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
  60. var allowedips []net.IPNet
  61. var gateways []string
  62. var peeraddr = net.IPNet{
  63. IP: net.ParseIP(peer.Address),
  64. Mask: net.CIDRMask(32, 32),
  65. }
  66. dualstack := false
  67. allowedips = append(allowedips, peeraddr)
  68. // handle manually set peers
  69. for _, allowedIp := range node.AllowedIPs {
  70. if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
  71. nodeEndpointArr := strings.Split(node.Endpoint, ":")
  72. if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
  73. allowedips = append(allowedips, *ipnet)
  74. }
  75. } else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
  76. ipnet := net.IPNet{
  77. IP: net.ParseIP(allowedIp),
  78. Mask: net.CIDRMask(32, 32),
  79. }
  80. allowedips = append(allowedips, ipnet)
  81. }
  82. }
  83. // handle egress gateway peers
  84. if node.IsEgressGateway == "yes" {
  85. //hasGateway = true
  86. ranges := node.EgressGatewayRanges
  87. for _, iprange := range ranges { // go through each cidr for egress gateway
  88. _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
  89. if err != nil {
  90. ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
  91. continue // if can't parse CIDR
  92. }
  93. nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
  94. if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
  95. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
  96. continue // skip adding egress range if overlaps with node's ip
  97. }
  98. if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
  99. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
  100. continue // skip adding egress range if overlaps with node's local ip
  101. }
  102. gateways = append(gateways, iprange)
  103. if err != nil {
  104. log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
  105. } else {
  106. allowedips = append(allowedips, *ipnet)
  107. }
  108. }
  109. }
  110. if node.Address6 != "" && dualstack {
  111. var addr6 = net.IPNet{
  112. IP: net.ParseIP(node.Address6),
  113. Mask: net.CIDRMask(128, 128),
  114. }
  115. allowedips = append(allowedips, addr6)
  116. }
  117. return allowedips
  118. }