| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655 | package logicimport (	"crypto/md5"	"encoding/json"	"errors"	"fmt"	"os"	"reflect"	"sort"	"sync"	"github.com/google/uuid"	"golang.org/x/crypto/bcrypt"	"golang.org/x/exp/slog"	"github.com/gravitl/netmaker/database"	"github.com/gravitl/netmaker/logger"	"github.com/gravitl/netmaker/models"	"github.com/gravitl/netmaker/servercfg")var (	hostCacheMutex = &sync.RWMutex{}	hostsCacheMap  = make(map[string]models.Host))var (	// ErrHostExists error indicating that host exists when trying to create new host	ErrHostExists error = errors.New("host already exists")	// ErrInvalidHostID	ErrInvalidHostID error = errors.New("invalid host id"))var GetHostLocInfo = func(ip, token string) string { return "" }func getHostsFromCache() (hosts []models.Host) {	hostCacheMutex.RLock()	for _, host := range hostsCacheMap {		hosts = append(hosts, host)	}	hostCacheMutex.RUnlock()	return}func getHostsMapFromCache() (hostsMap map[string]models.Host) {	hostCacheMutex.RLock()	hostsMap = hostsCacheMap	hostCacheMutex.RUnlock()	return}func getHostFromCache(hostID string) (host models.Host, ok bool) {	hostCacheMutex.RLock()	host, ok = hostsCacheMap[hostID]	hostCacheMutex.RUnlock()	return}func storeHostInCache(h models.Host) {	hostCacheMutex.Lock()	hostsCacheMap[h.ID.String()] = h	hostCacheMutex.Unlock()}func deleteHostFromCache(hostID string) {	hostCacheMutex.Lock()	delete(hostsCacheMap, hostID)	hostCacheMutex.Unlock()}func loadHostsIntoCache(hMap map[string]models.Host) {	hostCacheMutex.Lock()	hostsCacheMap = hMap	hostCacheMutex.Unlock()}const (	maxPort = 1<<16 - 1	minPort = 1025)// GetAllHosts - returns all hosts in flat list or errorfunc GetAllHosts() ([]models.Host, error) {	var currHosts []models.Host	if servercfg.CacheEnabled() {		currHosts := getHostsFromCache()		if len(currHosts) != 0 {			return currHosts, nil		}	}	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)	if err != nil && !database.IsEmptyRecord(err) {		return nil, err	}	currHostsMap := make(map[string]models.Host)	if servercfg.CacheEnabled() {		defer loadHostsIntoCache(currHostsMap)	}	for k := range records {		var h models.Host		err = json.Unmarshal([]byte(records[k]), &h)		if err != nil {			return nil, err		}		currHosts = append(currHosts, h)		currHostsMap[h.ID.String()] = h	}	return currHosts, nil}// GetAllHostsWithStatus - returns all hosts with at least one// node with given status.func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {	hosts, err := GetAllHosts()	if err != nil {		return nil, err	}	var validHosts []models.Host	for _, host := range hosts {		if len(host.Nodes) == 0 {			continue		}		nodes := GetHostNodes(&host)		for _, node := range nodes {			getNodeCheckInStatus(&node, false)			if node.Status == status {				validHosts = append(validHosts, host)				break			}		}	}	return validHosts, nil}// GetAllHostsAPI - get's all the hosts in an API usable formatfunc GetAllHostsAPI(hosts []models.Host) []models.ApiHost {	apiHosts := []models.ApiHost{}	for i := range hosts {		newApiHost := hosts[i].ConvertNMHostToAPI()		apiHosts = append(apiHosts, *newApiHost)	}	return apiHosts[:]}// GetHostsMap - gets all the current hosts on machine in a mapfunc GetHostsMap() (map[string]models.Host, error) {	if servercfg.CacheEnabled() {		hostsMap := getHostsMapFromCache()		if len(hostsMap) != 0 {			return hostsMap, nil		}	}	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)	if err != nil && !database.IsEmptyRecord(err) {		return nil, err	}	currHostMap := make(map[string]models.Host)	if servercfg.CacheEnabled() {		defer loadHostsIntoCache(currHostMap)	}	for k := range records {		var h models.Host		err = json.Unmarshal([]byte(records[k]), &h)		if err != nil {			return nil, err		}		currHostMap[h.ID.String()] = h	}	return currHostMap, nil}func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID) bool {	if len(h.Nodes) > 0 {		for _, nodeID := range h.Nodes {			node, err := GetNodeByID(nodeID)			if err == nil && node.Network == network.String() {				return true			}		}	}	return false}// GetHost - gets a host from db given idfunc GetHost(hostid string) (*models.Host, error) {	if servercfg.CacheEnabled() {		if host, ok := getHostFromCache(hostid); ok {			return &host, nil		}	}	record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)	if err != nil {		return nil, err	}	var h models.Host	if err = json.Unmarshal([]byte(record), &h); err != nil {		return nil, err	}	if servercfg.CacheEnabled() {		storeHostInCache(h)	}	return &h, nil}// GetHostByPubKey - gets a host from db given pubkeyfunc GetHostByPubKey(hostPubKey string) (*models.Host, error) {	hosts, err := GetAllHosts()	if err != nil {		return nil, err	}	for _, host := range hosts {		if host.PublicKey.String() == hostPubKey {			return &host, nil		}	}	return nil, errors.New("host not found")}// CreateHost - creates a host if not existfunc CreateHost(h *models.Host) error {	hosts, hErr := GetAllHosts()	clients, cErr := GetAllExtClients()	if (hErr != nil && !database.IsEmptyRecord(hErr)) ||		(cErr != nil && !database.IsEmptyRecord(cErr)) ||		len(hosts)+len(clients) >= MachinesLimit {		return errors.New("free tier limits exceeded on machines")	}	_, err := GetHost(h.ID.String())	if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {		return ErrHostExists	}	// encrypt that password so we never see it	hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5)	if err != nil {		return err	}	h.HostPass = string(hash)	h.AutoUpdate = AutoUpdateEnabled()	if GetServerSettings().ManageDNS {		h.DNS = "yes"	} else {		h.DNS = "no"	}	if h.EndpointIP != nil {		h.Location = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))	} else if h.EndpointIPv6 != nil {		h.Location = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))	}	checkForZombieHosts(h)	return UpsertHost(h)}// UpdateHost - updates host data by fieldfunc UpdateHost(newHost, currentHost *models.Host) {	// unchangeable fields via API here	newHost.DaemonInstalled = currentHost.DaemonInstalled	newHost.OS = currentHost.OS	newHost.IPForwarding = currentHost.IPForwarding	newHost.HostPass = currentHost.HostPass	newHost.MacAddress = currentHost.MacAddress	newHost.Debug = currentHost.Debug	newHost.Nodes = currentHost.Nodes	newHost.PublicKey = currentHost.PublicKey	newHost.TrafficKeyPublic = currentHost.TrafficKeyPublic	// changeable fields	if len(newHost.Version) == 0 {		newHost.Version = currentHost.Version	}	if len(newHost.Name) == 0 {		newHost.Name = currentHost.Name	}	if newHost.MTU == 0 {		newHost.MTU = currentHost.MTU	}	if newHost.ListenPort == 0 {		newHost.ListenPort = currentHost.ListenPort	}	if newHost.PersistentKeepalive == 0 {		newHost.PersistentKeepalive = currentHost.PersistentKeepalive	}}// UpdateHostFromClient - used for updating host on server with update recieved from clientfunc UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool) {	if newHost.PublicKey != currHost.PublicKey {		currHost.PublicKey = newHost.PublicKey		sendPeerUpdate = true	}	if newHost.ListenPort != 0 && currHost.ListenPort != newHost.ListenPort {		currHost.ListenPort = newHost.ListenPort		sendPeerUpdate = true	}	if newHost.WgPublicListenPort != 0 &&		currHost.WgPublicListenPort != newHost.WgPublicListenPort {		currHost.WgPublicListenPort = newHost.WgPublicListenPort		sendPeerUpdate = true	}	isEndpointChanged := false	if currHost.EndpointIP.String() != newHost.EndpointIP.String() {		currHost.EndpointIP = newHost.EndpointIP		sendPeerUpdate = true		isEndpointChanged = true	}	if currHost.EndpointIPv6.String() != newHost.EndpointIPv6.String() {		currHost.EndpointIPv6 = newHost.EndpointIPv6		sendPeerUpdate = true		isEndpointChanged = true	}	if !reflect.DeepEqual(currHost.Interfaces, newHost.Interfaces) {		currHost.Interfaces = newHost.Interfaces		sendPeerUpdate = true	}	if isEndpointChanged {		for _, nodeID := range currHost.Nodes {			node, err := GetNodeByID(nodeID)			if err != nil {				slog.Error("failed to get node:", "id", node.ID, "error", err)				continue			}			if node.FailedOverBy != uuid.Nil {				ResetFailedOverPeer(&node)			}		}	}	currHost.DaemonInstalled = newHost.DaemonInstalled	currHost.Debug = newHost.Debug	currHost.Verbosity = newHost.Verbosity	currHost.Version = newHost.Version	currHost.IsStaticPort = newHost.IsStaticPort	currHost.IsStatic = newHost.IsStatic	currHost.MTU = newHost.MTU	currHost.Name = newHost.Name	if len(newHost.NatType) > 0 && newHost.NatType != currHost.NatType {		currHost.NatType = newHost.NatType		sendPeerUpdate = true	}	return}// UpsertHost - upserts into DB a given host model, does not check for existence*func UpsertHost(h *models.Host) error {	data, err := json.Marshal(h)	if err != nil {		return err	}	err = database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)	if err != nil {		return err	}	if servercfg.CacheEnabled() {		storeHostInCache(*h)	}	return nil}// RemoveHost - removes a given host from serverfunc RemoveHost(h *models.Host, forceDelete bool) error {	if !forceDelete && len(h.Nodes) > 0 {		return fmt.Errorf("host still has associated nodes")	}	if len(h.Nodes) > 0 {		if err := DisassociateAllNodesFromHost(h.ID.String()); err != nil {			return err		}	}	err := database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())	if err != nil {		return err	}	if servercfg.CacheEnabled() {		deleteHostFromCache(h.ID.String())	}	go func() {		if servercfg.IsDNSMode() {			SetDNS()		}	}()	return nil}// RemoveHostByID - removes a given host by id from serverfunc RemoveHostByID(hostID string) error {	err := database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)	if err != nil {		return err	}	if servercfg.CacheEnabled() {		deleteHostFromCache(hostID)	}	return nil}// UpdateHostNetwork - adds/deletes host from a networkfunc UpdateHostNetwork(h *models.Host, network string, add bool) (*models.Node, error) {	for _, nodeID := range h.Nodes {		node, err := GetNodeByID(nodeID)		if err != nil || node.PendingDelete {			continue		}		if node.Network == network {			if !add {				return &node, nil			} else {				return &node, errors.New("host already part of network " + network)			}		}	}	if !add {		return nil, errors.New("host not part of the network " + network)	} else {		newNode := models.Node{}		newNode.Server = servercfg.GetServer()		newNode.Network = network		newNode.HostID = h.ID		if err := AssociateNodeToHost(&newNode, h); err != nil {			return nil, err		}		return &newNode, nil	}}// AssociateNodeToHost - associates and creates a node with a given host// should be the only way nodes get created as of 0.18func AssociateNodeToHost(n *models.Node, h *models.Host) error {	if len(h.ID.String()) == 0 || h.ID == uuid.Nil {		return ErrInvalidHostID	}	n.HostID = h.ID	err := createNode(n)	if err != nil {		return err	}	currentHost, err := GetHost(h.ID.String())	if err != nil {		return err	}	h.HostPass = currentHost.HostPass	h.Nodes = append(currentHost.Nodes, n.ID.String())	return UpsertHost(h)}// DissasociateNodeFromHost - deletes a node and removes from host nodes// should be the only way nodes are deleted as of 0.18func DissasociateNodeFromHost(n *models.Node, h *models.Host) error {	if len(h.ID.String()) == 0 || h.ID == uuid.Nil {		return ErrInvalidHostID	}	if n.HostID != h.ID { // check if node actually belongs to host		return fmt.Errorf("node is not associated with host")	}	if len(h.Nodes) == 0 {		return fmt.Errorf("no nodes present in given host")	}	nList := []string{}	for i := range h.Nodes {		if h.Nodes[i] != n.ID.String() {			nList = append(nList, h.Nodes[i])		}	}	h.Nodes = nList	go func() {		if servercfg.IsPro {			if clients, err := GetNetworkExtClients(n.Network); err != nil {				for i := range clients {					AllowClientNodeAccess(&clients[i], n.ID.String())				}			}		}	}()	if err := DeleteNodeByID(n); err != nil {		return err	}	return UpsertHost(h)}// DisassociateAllNodesFromHost - deletes all nodes of the hostfunc DisassociateAllNodesFromHost(hostID string) error {	host, err := GetHost(hostID)	if err != nil {		return err	}	for _, nodeID := range host.Nodes {		node, err := GetNodeByID(nodeID)		if err != nil {			logger.Log(0, "failed to get host node, node id:", nodeID, err.Error())			continue		}		if err := DeleteNode(&node, true); err != nil {			logger.Log(0, "failed to delete node", node.ID.String(), err.Error())			continue		}		logger.Log(3, "deleted node", node.ID.String(), "of host", host.ID.String())	}	host.Nodes = []string{}	return UpsertHost(host)}// GetDefaultHosts - retrieve all hosts marked as default from DBfunc GetDefaultHosts() []models.Host {	defaultHostList := []models.Host{}	hosts, err := GetAllHosts()	if err != nil {		return defaultHostList	}	for i := range hosts {		if hosts[i].IsDefault {			defaultHostList = append(defaultHostList, hosts[i])		}	}	return defaultHostList[:]}// GetHostNetworks - fetches all the networksfunc GetHostNetworks(hostID string) []string {	currHost, err := GetHost(hostID)	if err != nil {		return nil	}	nets := []string{}	for i := range currHost.Nodes {		n, err := GetNodeByID(currHost.Nodes[i])		if err != nil {			return nil		}		nets = append(nets, n.Network)	}	return nets}// GetRelatedHosts - fetches related hosts of a given hostfunc GetRelatedHosts(hostID string) []models.Host {	relatedHosts := []models.Host{}	networks := GetHostNetworks(hostID)	networkMap := make(map[string]struct{})	for _, network := range networks {		networkMap[network] = struct{}{}	}	hosts, err := GetAllHosts()	if err == nil {		for _, host := range hosts {			if host.ID.String() == hostID {				continue			}			networks := GetHostNetworks(host.ID.String())			for _, network := range networks {				if _, ok := networkMap[network]; ok {					relatedHosts = append(relatedHosts, host)					break				}			}		}	}	return relatedHosts}// CheckHostPort checks host endpoints to ensures that hosts on the same server// with the same endpoint have different listen ports// in the case of 64535 hosts or more with same endpoint, ports will not be changedfunc CheckHostPorts(h *models.Host) (changed bool) {	portsInUse := make(map[int]bool, 0)	hosts, err := GetAllHosts()	if err != nil {		return	}	originalPort := h.ListenPort	defer func() {		if originalPort != h.ListenPort {			changed = true		}	}()	if h.EndpointIP == nil {		return	}	for _, host := range hosts {		if host.ID.String() == h.ID.String() {			// skip self			continue		}		if host.EndpointIP == nil {			continue		}		if !host.EndpointIP.Equal(h.EndpointIP) {			continue		}		portsInUse[host.ListenPort] = true	}	// iterate until port is not found or max iteration is reached	for i := 0; portsInUse[h.ListenPort] && i < maxPort-minPort+1; i++ {		if h.ListenPort == 443 {			h.ListenPort = 51821		} else {			h.ListenPort++		}		if h.ListenPort > maxPort {			h.ListenPort = minPort		}	}	return}// HostExists - checks if given host already existsfunc HostExists(h *models.Host) bool {	_, err := GetHost(h.ID.String())	return (err != nil && !database.IsEmptyRecord(err)) || (err == nil)}// GetHostByNodeID - returns a host if found to have a node's ID, else nilfunc GetHostByNodeID(id string) *models.Host {	hosts, err := GetAllHosts()	if err != nil {		return nil	}	for i := range hosts {		h := hosts[i]		if StringSliceContains(h.Nodes, id) {			return &h		}	}	return nil}// ConvHostPassToHash - converts password to md5 hashfunc ConvHostPassToHash(hostPass string) string {	return fmt.Sprintf("%x", md5.Sum([]byte(hostPass)))}// SortApiHosts - Sorts slice of ApiHosts by their ID alphabetically with numbers firstfunc SortApiHosts(unsortedHosts []models.ApiHost) {	sort.Slice(unsortedHosts, func(i, j int) bool {		return unsortedHosts[i].ID < unsortedHosts[j].ID	})}
 |