Browse Source

update peer action in proxy manager

Abhishek Kondur 2 years ago
parent
commit
104fe8824f

+ 1 - 1
nm-proxy/common/common.go

@@ -70,7 +70,7 @@ type RemotePeer struct {
 
 
 var WgIFaceMap = make(map[string]map[string]*Conn)
 var WgIFaceMap = make(map[string]map[string]*Conn)
 
 
-var RemoteEndpointsMap = make(map[string][]RemotePeer)
+var PeerKeyHashMap = make(map[string]RemotePeer)
 
 
 // RunCmd - runs a local command
 // RunCmd - runs a local command
 func RunCmd(command string, printerr bool) (string, error) {
 func RunCmd(command string, printerr bool) (string, error) {

+ 34 - 16
nm-proxy/manager/manager.go

@@ -1,7 +1,9 @@
 package manager
 package manager
 
 
 import (
 import (
+	"crypto/md5"
 	"errors"
 	"errors"
+	"fmt"
 	"log"
 	"log"
 	"runtime"
 	"runtime"
 
 
@@ -22,6 +24,7 @@ type ManagerPayload struct {
 const (
 const (
 	AddInterface    ProxyAction = "ADD_INTERFACE"
 	AddInterface    ProxyAction = "ADD_INTERFACE"
 	DeleteInterface ProxyAction = "DELETE_INTERFACE"
 	DeleteInterface ProxyAction = "DELETE_INTERFACE"
+	UpdatePeer      ProxyAction = "UPDATE_PEER"
 )
 )
 
 
 type ManagerAction struct {
 type ManagerAction struct {
@@ -37,7 +40,12 @@ func StartProxyManager(manageChan chan *ManagerAction) {
 			log.Printf("-------> PROXY-MANAGER: %+v\n", mI)
 			log.Printf("-------> PROXY-MANAGER: %+v\n", mI)
 			switch mI.Action {
 			switch mI.Action {
 			case AddInterface:
 			case AddInterface:
-				mI.AddInterfaceToProxy()
+				err := mI.AddInterfaceToProxy()
+				if err != nil {
+					log.Printf("failed to add interface: [%s] to proxy: %v\n  ", mI.Payload.InterfaceName, err)
+				}
+			case UpdatePeer:
+				mI.UpdatePeerProxy()
 			}
 			}
 
 
 		}
 		}
@@ -53,6 +61,25 @@ func cleanUp(iface string) {
 	delete(common.WgIFaceMap, iface)
 	delete(common.WgIFaceMap, iface)
 }
 }
 
 
+func (m *ManagerAction) UpdatePeerProxy() error {
+	if len(m.Payload.Peers) == 0 {
+		log.Println("No Peers to add...")
+		return nil
+	}
+	for _, peerI := range m.Payload.Peers {
+		if peers, ok := common.WgIFaceMap[m.Payload.InterfaceName]; ok {
+			if peerConf, ok := peers[peerI.PublicKey.String()]; ok {
+
+				peerConf.Config.RemoteWgPort = peerI.Endpoint.Port
+				peers[peerI.PublicKey.String()] = peerConf
+				common.WgIFaceMap[m.Payload.InterfaceName] = peers
+				log.Printf("---->####### UPdated PEER: %+v\n", peerConf)
+			}
+		}
+	}
+	return nil
+}
+
 func (m *ManagerAction) AddInterfaceToProxy() error {
 func (m *ManagerAction) AddInterfaceToProxy() error {
 	var err error
 	var err error
 	if m.Payload.InterfaceName == "" {
 	if m.Payload.InterfaceName == "" {
@@ -77,23 +104,14 @@ func (m *ManagerAction) AddInterfaceToProxy() error {
 		log.Fatal("Failed init new interface: ", err)
 		log.Fatal("Failed init new interface: ", err)
 	}
 	}
 	log.Printf("wg: %+v\n", wgInterface)
 	log.Printf("wg: %+v\n", wgInterface)
-	for _, peerI := range m.Payload.Peers {
 
 
-		peerpkg.AddNewPeer(wgInterface, &peerI)
-		if val, ok := common.RemoteEndpointsMap[peerI.Endpoint.IP.String()]; ok {
-
-			val = append(val, common.RemotePeer{
-				Interface: ifaceName,
-				PeerKey:   peerI.PublicKey.String(),
-			})
-			common.RemoteEndpointsMap[peerI.Endpoint.IP.String()] = val
-		} else {
-			common.RemoteEndpointsMap[peerI.Endpoint.IP.String()] = []common.RemotePeer{{
-				Interface: ifaceName,
-				PeerKey:   peerI.PublicKey.String(),
-			}}
+	for _, peerI := range m.Payload.Peers {
+		common.PeerKeyHashMap[fmt.Sprintf("%x", md5.Sum([]byte(peerI.PublicKey.String())))] = common.RemotePeer{
+			Interface: ifaceName,
+			PeerKey:   peerI.PublicKey.String(),
 		}
 		}
-
+		peerpkg.AddNewPeer(wgInterface, &peerI)
 	}
 	}
+	log.Printf("------> PEERHASHMAP: %+v\n", common.PeerKeyHashMap)
 	return nil
 	return nil
 }
 }

+ 4 - 1
nm-proxy/nm-proxy.go

@@ -15,12 +15,15 @@ import (
    2. Delete - remove close all conns for the interface,cleanup
    2. Delete - remove close all conns for the interface,cleanup
 
 
 */
 */
-func Start(mgmChan chan *manager.ManagerAction) {
+func Start(mgmChan chan *manager.ManagerAction, isServer bool) {
 	log.Println("Starting Proxy...")
 	log.Println("Starting Proxy...")
 	go manager.StartProxyManager(mgmChan)
 	go manager.StartProxyManager(mgmChan)
 	hInfo := stun.GetHostInfo()
 	hInfo := stun.GetHostInfo()
 	stun.Host = hInfo
 	stun.Host = hInfo
 	log.Printf("HOSTINFO: %+v", hInfo)
 	log.Printf("HOSTINFO: %+v", hInfo)
+	if IsPublicIP(hInfo.PrivIp) {
+		log.Println("Host is public facing!!!")
+	}
 	// start the netclient proxy server
 	// start the netclient proxy server
 	err := server.NmProxyServer.CreateProxyServer(0, 0, hInfo.PrivIp.String())
 	err := server.NmProxyServer.CreateProxyServer(0, 0, hInfo.PrivIp.String())
 	if err != nil {
 	if err != nil {

+ 14 - 6
nm-proxy/packet/packet.go

@@ -2,38 +2,46 @@ package packet
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"crypto/md5"
 	"encoding/binary"
 	"encoding/binary"
+	"fmt"
 	"log"
 	"log"
 )
 )
 
 
 var udpHeaderLen = 8
 var udpHeaderLen = 8
 
 
-func ProcessPacketBeforeSending(buf []byte, n, dstPort int) ([]byte, int, error) {
+func ProcessPacketBeforeSending(buf []byte, srckey string, n, dstPort int) ([]byte, int, error) {
 	log.Println("@###### DST Port: ", dstPort)
 	log.Println("@###### DST Port: ", dstPort)
 	portbuf := new(bytes.Buffer)
 	portbuf := new(bytes.Buffer)
 	binary.Write(portbuf, binary.BigEndian, uint16(dstPort))
 	binary.Write(portbuf, binary.BigEndian, uint16(dstPort))
-	if n > len(buf)-2 {
+	hmd5 := md5.Sum([]byte(srckey))
+	log.Printf("---> HASH: %x ", hmd5)
+	if n > len(buf)-18 {
 		buf = append(buf, portbuf.Bytes()[0])
 		buf = append(buf, portbuf.Bytes()[0])
 		buf = append(buf, portbuf.Bytes()[1])
 		buf = append(buf, portbuf.Bytes()[1])
+		buf = append(buf, hmd5[:]...)
 	} else {
 	} else {
 		buf[n] = portbuf.Bytes()[0]
 		buf[n] = portbuf.Bytes()[0]
 		buf[n+1] = portbuf.Bytes()[1]
 		buf[n+1] = portbuf.Bytes()[1]
+		copy(buf[n+2:n+2+len(hmd5)], hmd5[:])
 	}
 	}
 
 
 	n += 2
 	n += 2
+	n += len(hmd5)
 
 
 	return buf, n, nil
 	return buf, n, nil
 }
 }
 
 
-func ExtractInfo(buffer []byte, n int) (int, int, error) {
+func ExtractInfo(buffer []byte, n int) (int, int, string, error) {
 	data := buffer[:n]
 	data := buffer[:n]
 	var localWgPort uint16
 	var localWgPort uint16
-	portBuf := data[n-2 : n+1]
+	portBuf := data[n-18 : n-18+3]
+	keyHash := data[n-16:]
 	reader := bytes.NewReader(portBuf)
 	reader := bytes.NewReader(portBuf)
 	err := binary.Read(reader, binary.BigEndian, &localWgPort)
 	err := binary.Read(reader, binary.BigEndian, &localWgPort)
 	if err != nil {
 	if err != nil {
 		log.Println("Failed to read port buffer: ", err)
 		log.Println("Failed to read port buffer: ", err)
 	}
 	}
-	n -= 2
-	return int(localWgPort), n, err
+	n -= 18
+	return int(localWgPort), n, fmt.Sprintf("%x", keyHash), err
 }
 }

+ 3 - 1
nm-proxy/peer/peer.go

@@ -45,6 +45,7 @@ func AddNewPeer(wgInterface *wg.WGIface, peer *wgtypes.PeerConfig) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	log.Printf("----> Established Remote Conn with RPeer: %s, LAddr: %s ----> RAddr: %s", peer.PublicKey, remoteConn.LocalAddr().String(), remoteConn.RemoteAddr().String())
 	log.Printf("Starting proxy for Peer: %s\n", peer.PublicKey.String())
 	log.Printf("Starting proxy for Peer: %s\n", peer.PublicKey.String())
 	err = p.Start(remoteConn)
 	err = p.Start(remoteConn)
 	if err != nil {
 	if err != nil {
@@ -52,7 +53,8 @@ func AddNewPeer(wgInterface *wg.WGIface, peer *wgtypes.PeerConfig) error {
 	}
 	}
 	connConf := common.ConnConfig{
 	connConf := common.ConnConfig{
 		Key:             peer.PublicKey.String(),
 		Key:             peer.PublicKey.String(),
-		LocalKey:        "",
+		LocalKey:        wgInterface.Device.PublicKey.String(),
+		LocalWgPort:     wgInterface.Device.ListenPort,
 		RemoteProxyIP:   net.ParseIP(peer.Endpoint.IP.String()),
 		RemoteProxyIP:   net.ParseIP(peer.Endpoint.IP.String()),
 		RemoteWgPort:    peer.Endpoint.Port,
 		RemoteWgPort:    peer.Endpoint.Port,
 		RemoteProxyPort: common.NmProxyPort,
 		RemoteProxyPort: common.NmProxyPort,

+ 26 - 0
nm-proxy/proxy/proxy.go

@@ -2,6 +2,8 @@ package proxy
 
 
 import (
 import (
 	"context"
 	"context"
+	"errors"
+	"fmt"
 	"net"
 	"net"
 
 
 	"github.com/gravitl/netmaker/nm-proxy/wg"
 	"github.com/gravitl/netmaker/nm-proxy/wg"
@@ -18,6 +20,7 @@ type Config struct {
 	BodySize     int
 	BodySize     int
 	Addr         string
 	Addr         string
 	RemoteKey    string
 	RemoteKey    string
+	LocalKey     string
 	WgInterface  *wg.WGIface
 	WgInterface  *wg.WGIface
 	AllowedIps   []net.IPNet
 	AllowedIps   []net.IPNet
 	PreSharedKey *wgtypes.Key
 	PreSharedKey *wgtypes.Key
@@ -32,3 +35,26 @@ type Proxy struct {
 	RemoteConn net.Conn
 	RemoteConn net.Conn
 	LocalConn  net.Conn
 	LocalConn  net.Conn
 }
 }
+
+func GetInterfaceIpv4Addr(interfaceName string) (addr string, err error) {
+	var (
+		ief      *net.Interface
+		addrs    []net.Addr
+		ipv4Addr net.IP
+	)
+	if ief, err = net.InterfaceByName(interfaceName); err != nil { // get interface
+		return
+	}
+	if addrs, err = ief.Addrs(); err != nil { // get addresses
+		return
+	}
+	for _, addr := range addrs { // get ipv4 address
+		if ipv4Addr = addr.(*net.IPNet).IP.To4(); ipv4Addr != nil {
+			break
+		}
+	}
+	if ipv4Addr == nil {
+		return "", errors.New(fmt.Sprintf("interface %s don't have an ipv4 address\n", interfaceName))
+	}
+	return ipv4Addr.String(), nil
+}

+ 11 - 11
nm-proxy/proxy/wireguard.go

@@ -62,7 +62,7 @@ func (p *Proxy) ProxyToRemote() {
 			if peerI, ok := peers[p.Config.RemoteKey]; ok {
 			if peerI, ok := peers[p.Config.RemoteKey]; ok {
 				log.Println("PROCESSING PKT BEFORE SENDING")
 				log.Println("PROCESSING PKT BEFORE SENDING")
 
 
-				buf, n, err = packet.ProcessPacketBeforeSending(buf, n, peerI.Config.RemoteWgPort)
+				buf, n, err = packet.ProcessPacketBeforeSending(buf, peerI.Config.LocalKey, n, peerI.Config.RemoteWgPort)
 				if err != nil {
 				if err != nil {
 					log.Println("failed to process pkt before sending: ", err)
 					log.Println("failed to process pkt before sending: ", err)
 				}
 				}
@@ -109,12 +109,11 @@ func (p *Proxy) Start(remoteConn net.Conn) error {
 	// 	log.Println("Failed to get iface: ", p.Config.WgInterface.Name, err)
 	// 	log.Println("Failed to get iface: ", p.Config.WgInterface.Name, err)
 	// 	return err
 	// 	return err
 	// }
 	// }
-	wgPort, err := p.Config.WgInterface.GetListenPort()
-	if err != nil {
-		log.Printf("Failed to get listen port for iface: %s,Err: %v\n", p.Config.WgInterface.Name, err)
-		return err
-	}
-	p.Config.WgInterface.Port = *wgPort
+	// wgAddr, err := GetInterfaceIpv4Addr(p.Config.WgInterface.Name)
+	// if err != nil {
+	// 	log.Println("failed to get interface addr: ", err)
+	// 	return err
+	// }
 	log.Printf("----> WGIFACE: %+v\n", p.Config.WgInterface)
 	log.Printf("----> WGIFACE: %+v\n", p.Config.WgInterface)
 	addr, err := GetFreeIp("127.0.0.1/8", p.Config.WgInterface.Port)
 	addr, err := GetFreeIp("127.0.0.1/8", p.Config.WgInterface.Port)
 	if err != nil {
 	if err != nil {
@@ -179,10 +178,11 @@ func GetFreeIp(cidrAddr string, dstPort int) (string, error) {
 		})
 		})
 		if err != nil {
 		if err != nil {
 			log.Println("----> GetFreeIP ERR: ", err)
 			log.Println("----> GetFreeIP ERR: ", err)
-			if strings.Contains(err.Error(), "can't assign requested address") {
-				newAddrs, err = net4.NextIP(newAddrs)
-				if err != nil {
-					return "", err
+			if strings.Contains(err.Error(), "can't assign requested address") || strings.Contains(err.Error(), "address already in use") {
+				var nErr error
+				newAddrs, nErr = net4.NextIP(newAddrs)
+				if nErr != nil {
+					return "", nErr
 				}
 				}
 			} else {
 			} else {
 				return "", err
 				return "", err

+ 15 - 16
nm-proxy/server/server.go

@@ -44,30 +44,29 @@ func (p *ProxyServer) Listen() {
 			continue
 			continue
 		}
 		}
 		var localWgPort int
 		var localWgPort int
-		localWgPort, n, err = packet.ExtractInfo(buffer, n)
+		var srcPeerKeyHash string
+		localWgPort, n, srcPeerKeyHash, err = packet.ExtractInfo(buffer, n)
 		if err != nil {
 		if err != nil {
 			log.Println("failed to extract info: ", err)
 			log.Println("failed to extract info: ", err)
 			continue
 			continue
 		}
 		}
-		log.Println("--------> RECV PKT: ", source.IP.String(), localWgPort)
-		if val, ok := common.RemoteEndpointsMap[source.IP.String()]; ok {
-			for _, remotePeer := range val {
-				if peers, ok := common.WgIFaceMap[remotePeer.Interface]; ok {
-					if peerI, ok := peers[remotePeer.PeerKey]; ok {
-						if peerI.Config.LocalWgPort == int(localWgPort) {
-							log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s\n", peerI.Proxy.LocalConn.RemoteAddr(),
-								peerI.Proxy.LocalConn.LocalAddr(), fmt.Sprintf("%s:%d", source.IP.String(), source.Port))
-							_, err = peerI.Proxy.LocalConn.Write(buffer[:n])
-							if err != nil {
-								log.Println("Failed to proxy to Wg local interface: ", err)
-								continue
-							}
-
+		log.Printf("--------> RECV PKT [DSTPORT: %d], [SRCKEYHASH: %s] \n", localWgPort, srcPeerKeyHash)
+		if peerInfo, ok := common.PeerKeyHashMap[srcPeerKeyHash]; ok {
+			if peers, ok := common.WgIFaceMap[peerInfo.Interface]; ok {
+				if peerI, ok := peers[peerInfo.PeerKey]; ok {
+					if peerI.Config.LocalWgPort == int(localWgPort) {
+						log.Printf("PROXING TO LOCAL!!!---> %s <<<< %s <<<<<<<< %s\n", peerI.Proxy.LocalConn.RemoteAddr(),
+							peerI.Proxy.LocalConn.LocalAddr(), fmt.Sprintf("%s:%d", source.IP.String(), source.Port))
+						_, err = peerI.Proxy.LocalConn.Write(buffer[:n])
+						if err != nil {
+							log.Println("Failed to proxy to Wg local interface: ", err)
+							continue
 						}
 						}
+
 					}
 					}
 				}
 				}
-
 			}
 			}
+
 		}
 		}
 
 
 	}
 	}

+ 5 - 2
nm-proxy/wg/wg.go

@@ -22,6 +22,7 @@ type WGIface struct {
 	Name      string
 	Name      string
 	Port      int
 	Port      int
 	MTU       int
 	MTU       int
+	Device    *wgtypes.Device
 	Address   WGAddress
 	Address   WGAddress
 	Interface NetInterface
 	Interface NetInterface
 	mu        sync.Mutex
 	mu        sync.Mutex
@@ -52,7 +53,7 @@ func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) {
 	}
 	}
 
 
 	wgIface.Address = wgAddress
 	wgIface.Address = wgAddress
-
+	wgIface.GetWgIface(iface)
 	return wgIface, nil
 	return wgIface, nil
 }
 }
 
 
@@ -65,8 +66,10 @@ func (w *WGIface) GetWgIface(iface string) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	log.Printf("----> DEVICE: %+v\n", dev)
 
 
+	log.Printf("----> DEVICE: %+v\n", dev)
+	w.Device = dev
+	w.Port = dev.ListenPort
 	return nil
 	return nil
 }
 }