Răsfoiți Sursa

resolve merge conflicts

abhishek9686 9 luni în urmă
părinte
comite
602a8ec92f

+ 1 - 1
compose/docker-compose.yml

@@ -12,7 +12,7 @@ services:
       - sqldata:/root/data
     environment:
       # config-dependant vars
-      - STUN_LIST=stun1.netmaker.io:3478,stun2.netmaker.io:3478,stun1.l.google.com:19302,stun2.l.google.com:19302
+      - STUN_SERVERS=stun1.netmaker.io:3478,stun2.netmaker.io:3478,stun1.l.google.com:19302,stun2.l.google.com:19302
       # The domain/host IP indicating the mq broker address
       - BROKER_ENDPOINT=wss://broker.${NM_DOMAIN} # For EMQX broker use `BROKER_ENDPOINT=wss://broker.${NM_DOMAIN}/mqtt`
       # For EMQX broker (uncomment the two lines below)

+ 2 - 0
config/config.go

@@ -101,6 +101,8 @@ type ServerConfig struct {
 	SmtpPort                   int           `json:"smtp_port"`
 	MetricInterval             string        `yaml:"metric_interval"`
 	ManageDNS                  bool          `yaml:"manage_dns"`
+	Stun                       bool          `yaml:"stun"`
+	StunServers                string        `yaml:"stun_servers"`
 	DefaultDomain              string        `yaml:"default_domain"`
 }
 

+ 1 - 1
controllers/acls.go

@@ -158,7 +158,7 @@ func getAcls(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	acls, err := logic.ListAcls(models.NetworkID(netID))
+	acls, err := logic.ListAclsByNetwork(models.NetworkID(netID))
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to get all network acl entries: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))

+ 9 - 10
logic/acls.go

@@ -57,7 +57,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 	if netID.String() == "" {
 		return
 	}
-	_, _ = ListAcls(netID)
+	_, _ = ListAclsByNetwork(netID)
 	if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-nodes")) {
 		defaultDeviceAcl := models.Acl{
 			ID:        fmt.Sprintf("%s.%s", netID, "all-nodes"),
@@ -146,7 +146,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 
 // DeleteDefaultNetworkPolicies - deletes all default network acl policies
 func DeleteDefaultNetworkPolicies(netId models.NetworkID) {
-	acls, _ := ListAcls(netId)
+	acls, _ := ListAclsByNetwork(netId)
 	for _, acl := range acls {
 		if acl.NetworkID == netId && acl.Default {
 			DeleteAcl(acl)
@@ -394,7 +394,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
 		return acl, nil
 	}
 	// check if there are any custom all policies
-	policies, _ := ListAcls(netID)
+	policies, _ := ListAclsByNetwork(netID)
 	for _, policy := range policies {
 		if !policy.Enabled {
 			continue
@@ -414,8 +414,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
 	return acl, nil
 }
 
-func listAcls() (acls []models.Acl) {
-
+func ListAcls() (acls []models.Acl) {
 	if servercfg.CacheEnabled() && len(aclCacheMap) > 0 {
 		return listAclFromCache()
 	}
@@ -440,7 +439,7 @@ func listAcls() (acls []models.Acl) {
 
 // ListUserPolicies - lists all acl policies enforced on an user
 func ListUserPolicies(u models.User) []models.Acl {
-	allAcls := listAcls()
+	allAcls := ListAcls()
 	userAcls := []models.Acl{}
 	for _, acl := range allAcls {
 
@@ -465,7 +464,7 @@ func ListUserPolicies(u models.User) []models.Acl {
 
 // listPoliciesOfUser - lists all user acl policies applied to user in an network
 func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl {
-	allAcls := listAcls()
+	allAcls := ListAcls()
 	userAcls := []models.Acl{}
 	for _, acl := range allAcls {
 		if acl.NetworkID == netID && acl.RuleType == models.UserPolicy {
@@ -494,7 +493,7 @@ func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl {
 
 // listDevicePolicies - lists all device policies in a network
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
-	allAcls := listAcls()
+	allAcls := ListAcls()
 	deviceAcls := []models.Acl{}
 	for _, acl := range allAcls {
 		if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
@@ -505,9 +504,9 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl {
 }
 
 // ListAcls - lists all acl policies
-func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
+func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) {
 
-	allAcls := listAcls()
+	allAcls := ListAcls()
 	netAcls := []models.Acl{}
 	for _, acl := range allAcls {
 		if acl.NetworkID == netID {

+ 2 - 0
logic/peers.go

@@ -444,6 +444,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 	}
 
 	hostPeerUpdate.ManageDNS = servercfg.GetManageDNS()
+	hostPeerUpdate.Stun = servercfg.IsStunEnabled()
+	hostPeerUpdate.StunServers = servercfg.GetStunServers()
 	return hostPeerUpdate, nil
 }
 

+ 9 - 0
main.go

@@ -99,6 +99,15 @@ func initialize() { // Client Mode Prereq Check
 		logger.FatalLog("Error connecting to database: ", err.Error())
 	}
 	logger.Log(0, "database successfully connected")
+
+	//initialize cache
+	_, _ = logic.GetNetworks()
+	_, _ = logic.GetAllNodes()
+	_, _ = logic.GetAllHosts()
+	_, _ = logic.GetAllExtClients()
+	_ = logic.ListAcls()
+	_, _ = logic.GetAllEnrollmentKeys()
+
 	migrate.Run()
 
 	logic.SetJWTSecret()

+ 0 - 2
migrate/migrate.go

@@ -20,8 +20,6 @@ import (
 
 // Run - runs all migrations
 func Run() {
-	_, _ = logic.GetAllNodes()
-	_, _ = logic.GetAllHosts()
 	updateEnrollmentKeys()
 	assignSuperAdmin()
 	createDefaultTagsAndPolicies()

+ 2 - 0
models/mqtt.go

@@ -25,6 +25,8 @@ type HostPeerUpdate struct {
 	ReplacePeers      bool                  `json:"replace_peers"`
 	EndpointDetection bool                  `json:"endpoint_detection"`
 	ManageDNS         bool                  `yaml:"manage_dns"`
+	Stun              bool                  `yaml:"stun"`
+	StunServers       string                `yaml:"stun_servers"`
 }
 
 type FwRule struct {

+ 2 - 0
models/structs.go

@@ -267,6 +267,8 @@ type ServerConfig struct {
 	TrafficKey     []byte `yaml:"traffickey"`
 	MetricInterval string `yaml:"metric_interval"`
 	ManageDNS      bool   `yaml:"manage_dns"`
+	Stun           bool   `yaml:"stun"`
+	StunServers    string `yaml:"stun_servers"`
 	DefaultDomain  string `yaml:"default_domain"`
 }
 

+ 2 - 0
scripts/netmaker.default.env

@@ -94,3 +94,5 @@ PEER_UPDATE_BATCH_SIZE=50
 DEFAULT_DOMAIN=netmaker.hosted
 # managed dns setting, set to true to resolve dns entries on netmaker network
 MANAGE_DNS=false
+# if STUN is set to true, hole punch is called
+STUN=true

+ 17 - 0
servercfg/serverconf.go

@@ -94,6 +94,8 @@ func GetServerConfig() config.ServerConfig {
 	cfg.RacAutoDisable = GetRacAutoDisable()
 	cfg.MetricInterval = GetMetricInterval()
 	cfg.ManageDNS = GetManageDNS()
+	cfg.Stun = IsStunEnabled()
+	cfg.StunServers = GetStunServers()
 	cfg.DefaultDomain = GetDefaultDomain()
 	return cfg
 }
@@ -140,6 +142,8 @@ func GetServerInfo() models.ServerConfig {
 	cfg.IsPro = IsPro
 	cfg.MetricInterval = GetMetricInterval()
 	cfg.ManageDNS = GetManageDNS()
+	cfg.Stun = IsStunEnabled()
+	cfg.StunServers = GetStunServers()
 	cfg.DefaultDomain = GetDefaultDomain()
 	return cfg
 }
@@ -805,6 +809,19 @@ func IsEndpointDetectionEnabled() bool {
 	return enabled
 }
 
+// IsStunEnabled - returns true if STUN set to on
+func IsStunEnabled() bool {
+	var enabled = true
+	if os.Getenv("STUN") != "" {
+		enabled = os.Getenv("STUN") == "true"
+	}
+	return enabled
+}
+
+func GetStunServers() string {
+	return os.Getenv("STUN_SERVERS")
+}
+
 // GetEnvironment returns the environment the server is running in (e.g. dev, staging, prod...)
 func GetEnvironment() string {
 	if env := os.Getenv("ENVIRONMENT"); env != "" {