peers.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package logic
  2. import (
  3. "log"
  4. "net"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/gravitl/netmaker/models"
  9. "github.com/gravitl/netmaker/netclient/ncutils"
  10. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  11. )
  12. // GetPeerUpdate - gets a wireguard peer config for each peer of a node
  13. func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
  14. var peerUpdate models.PeerUpdate
  15. var peers []wgtypes.PeerConfig
  16. networkNodes, err := GetNetworkNodes(node.Network)
  17. if err != nil {
  18. return models.PeerUpdate{}, err
  19. }
  20. var serverNodeAddresses = []models.ServerAddr{}
  21. for _, peer := range networkNodes {
  22. if peer.ID == node.ID {
  23. //skip yourself
  24. continue
  25. }
  26. pubkey, err := wgtypes.ParseKey(peer.PublicKey)
  27. if err != nil {
  28. return models.PeerUpdate{}, err
  29. }
  30. if node.Endpoint == peer.Endpoint {
  31. //peer is on same network
  32. if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
  33. peer.Endpoint = peer.LocalAddress
  34. } else {
  35. continue
  36. }
  37. }
  38. endpoint := peer.Endpoint + ":" + strconv.FormatInt(int64(peer.ListenPort), 10)
  39. address, err := net.ResolveUDPAddr("udp", endpoint)
  40. if err != nil {
  41. return models.PeerUpdate{}, err
  42. }
  43. allowedips := GetAllowedIPs(node, &peer)
  44. var keepalive time.Duration
  45. if node.PersistentKeepalive != 0 {
  46. keepalive, _ = time.ParseDuration(strconv.FormatInt(int64(node.PersistentKeepalive), 10) + "s")
  47. }
  48. var peerData = wgtypes.PeerConfig{
  49. PublicKey: pubkey,
  50. Endpoint: address,
  51. ReplaceAllowedIPs: true,
  52. AllowedIPs: allowedips,
  53. PersistentKeepaliveInterval: &keepalive,
  54. }
  55. peers = append(peers, peerData)
  56. if peer.IsServer == "yes" {
  57. serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address})
  58. }
  59. }
  60. peerUpdate.Network = node.Network
  61. peerUpdate.Peers = peers
  62. peerUpdate.ServerAddrs = serverNodeAddresses
  63. return peerUpdate, nil
  64. }
  65. // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
  66. func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
  67. var allowedips []net.IPNet
  68. var gateways []string
  69. var peeraddr = net.IPNet{
  70. IP: net.ParseIP(peer.Address),
  71. Mask: net.CIDRMask(32, 32),
  72. }
  73. dualstack := false
  74. allowedips = append(allowedips, peeraddr)
  75. // handle manually set peers
  76. for _, allowedIp := range node.AllowedIPs {
  77. if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
  78. nodeEndpointArr := strings.Split(node.Endpoint, ":")
  79. if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
  80. allowedips = append(allowedips, *ipnet)
  81. }
  82. } else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
  83. ipnet := net.IPNet{
  84. IP: net.ParseIP(allowedIp),
  85. Mask: net.CIDRMask(32, 32),
  86. }
  87. allowedips = append(allowedips, ipnet)
  88. }
  89. }
  90. // handle egress gateway peers
  91. if node.IsEgressGateway == "yes" {
  92. //hasGateway = true
  93. ranges := node.EgressGatewayRanges
  94. for _, iprange := range ranges { // go through each cidr for egress gateway
  95. _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
  96. if err != nil {
  97. ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
  98. continue // if can't parse CIDR
  99. }
  100. nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
  101. if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
  102. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
  103. continue // skip adding egress range if overlaps with node's ip
  104. }
  105. if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
  106. ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
  107. continue // skip adding egress range if overlaps with node's local ip
  108. }
  109. gateways = append(gateways, iprange)
  110. if err != nil {
  111. log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
  112. } else {
  113. allowedips = append(allowedips, *ipnet)
  114. }
  115. }
  116. }
  117. if node.Address6 != "" && dualstack {
  118. var addr6 = net.IPNet{
  119. IP: net.ParseIP(node.Address6),
  120. Mask: net.CIDRMask(128, 128),
  121. }
  122. allowedips = append(allowedips, addr6)
  123. }
  124. return allowedips
  125. }