peers.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package logic
  2. import (
  3. "log"
  4. "net"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/gravitl/netmaker/models"
  9. "github.com/gravitl/netmaker/netclient/ncutils"
  10. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  11. )
  12. // GetPeerUpdate - gets a wireguard peer config for each peer of a node
  13. func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
  14. var peerUpdate models.PeerUpdate
  15. var peers []wgtypes.PeerConfig
  16. networkNodes, err := GetNetworkNodes(node.Network)
  17. if err != nil {
  18. return models.PeerUpdate{}, err
  19. }
  20. for _, peer := range networkNodes {
  21. if peer.ID == node.ID {
  22. //skip yourself
  23. continue
  24. }
  25. pubkey, err := wgtypes.ParseKey(peer.PublicKey)
  26. if err != nil {
  27. return models.PeerUpdate{}, err
  28. }
  29. if node.Endpoint == peer.Endpoint {
  30. //peer is on same network
  31. if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
  32. peer.Endpoint = peer.LocalAddress
  33. } else {
  34. continue
  35. }
  36. }
  37. endpoint := peer.Endpoint + ":" + strconv.FormatInt(int64(peer.ListenPort), 10)
  38. address, err := net.ResolveUDPAddr("udp", endpoint)
  39. if err != nil {
  40. return models.PeerUpdate{}, err
  41. }
  42. allowedips := GetAllowedIPs(node, &peer)
  43. var keepalive time.Duration
  44. if node.PersistentKeepalive != 0 {
  45. keepalive, _ = time.ParseDuration(strconv.FormatInt(int64(node.PersistentKeepalive), 10) + "s")
  46. }
  47. var peerData = wgtypes.PeerConfig{
  48. PublicKey: pubkey,
  49. Endpoint: address,
  50. ReplaceAllowedIPs: true,
  51. AllowedIPs: allowedips,
  52. PersistentKeepaliveInterval: &keepalive,
  53. }
  54. peers = append(peers, peerData)
  55. }
  56. peerUpdate.Network = node.Network
  57. peerUpdate.Peers = peers
  58. return peerUpdate, nil
  59. }
  60. // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
  61. func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
  62. var allowedips []net.IPNet
  63. var gateways []string
  64. var peeraddr = net.IPNet{
  65. IP: net.ParseIP(peer.Address),
  66. Mask: net.CIDRMask(32, 32),
  67. }
  68. dualstack := false
  69. allowedips = append(allowedips, peeraddr)
  70. // handle manually set peers
  71. for _, allowedIp := range node.AllowedIPs {
  72. if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
  73. nodeEndpointArr := strings.Split(node.Endpoint, ":")
  74. if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
  75. allowedips = append(allowedips, *ipnet)
  76. }
  77. } else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
  78. ipnet := net.IPNet{
  79. IP: net.ParseIP(allowedIp),
  80. Mask: net.CIDRMask(32, 32),
  81. }
  82. allowedips = append(allowedips, ipnet)
  83. }
  84. }
  85. // handle egress gateway peers
  86. if node.IsEgressGateway == "yes" {
  87. //hasGateway = true
  88. ranges := node.EgressGatewayRanges
  89. for _, iprange := range ranges { // go through each cidr for egress gateway
  90. _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
  91. if err != nil {
  92. ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
  93. continue // if can't parse CIDR
  94. }
  95. nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
  96. if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
  97. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
  98. continue // skip adding egress range if overlaps with node's ip
  99. }
  100. if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
  101. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
  102. continue // skip adding egress range if overlaps with node's local ip
  103. }
  104. gateways = append(gateways, iprange)
  105. if err != nil {
  106. log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
  107. } else {
  108. allowedips = append(allowedips, *ipnet)
  109. }
  110. }
  111. }
  112. if node.Address6 != "" && dualstack {
  113. var addr6 = net.IPNet{
  114. IP: net.ParseIP(node.Address6),
  115. Mask: net.CIDRMask(128, 128),
  116. }
  117. allowedips = append(allowedips, addr6)
  118. }
  119. return allowedips
  120. }