Browse Source

Merge pull request #3750 from gravitl/v1.2.0-listen-port-fix

Add debug logs
Abhishek Kondur 2 weeks ago
parent
commit
a47eb22a09
3 changed files with 104 additions and 11 deletions
  1. 9 1
      controllers/hosts.go
  2. 87 7
      logic/hosts.go
  3. 8 3
      utils/utils.go

+ 9 - 1
controllers/hosts.go

@@ -232,7 +232,15 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	_ = logic.CheckHostPorts(host)
+
+	portChanged := logic.CheckHostPorts(host)
+	if portChanged {
+		// Save the port change to database immediately to prevent conflicts
+		if err := logic.UpsertHost(host); err != nil {
+			slog.Error("failed to save host port change", "host", host.Name, "error", err)
+		}
+	}
+
 	response := models.HostPull{
 		Host:              *host,
 		Nodes:             logic.GetHostNodes(host),

+ 87 - 7
logic/hosts.go

@@ -24,6 +24,7 @@ import (
 var (
 	hostCacheMutex = &sync.RWMutex{}
 	hostsCacheMap  = make(map[string]models.Host)
+	hostPortMutex  = &sync.Mutex{}
 )
 
 var (
@@ -617,20 +618,43 @@ func GetRelatedHosts(hostID string) []models.Host {
 // with the same endpoint have different listen ports
 // in the case of 64535 hosts or more with same endpoint, ports will not be changed
 func CheckHostPorts(h *models.Host) (changed bool) {
-	portsInUse := make(map[int]bool, 0)
-	hosts, err := GetAllHosts()
-	if err != nil {
+	if h.IsStaticPort {
+		return false
+	}
+	if h.EndpointIP == nil {
 		return
 	}
+
+	// Get the current host from database to check if it already has a valid port assigned
+	// This check happens before the mutex to avoid unnecessary locking
+	currentHost, err := GetHost(h.ID.String())
+	if err == nil && currentHost.ListenPort > 0 {
+		// If the host already has a port in the database, use that instead of the incoming port
+		// This prevents the host from being reassigned when the client sends the old port
+		if currentHost.ListenPort != h.ListenPort {
+			h.ListenPort = currentHost.ListenPort
+		}
+	}
+
+	// Only acquire mutex when we need to check for port conflicts
+	// This reduces contention for the common case where ports are already valid
+	hostPortMutex.Lock()
+	defer hostPortMutex.Unlock()
+
 	originalPort := h.ListenPort
 	defer func() {
 		if originalPort != h.ListenPort {
 			changed = true
 		}
 	}()
-	if h.EndpointIP == nil {
+
+	hosts, err := GetAllHosts()
+	if err != nil {
 		return
 	}
+
+	// Build map of ports in use by other hosts with the same endpoint
+	portsInUse := make(map[int]bool)
 	for _, host := range hosts {
 		if host.ID.String() == h.ID.String() {
 			// skip self
@@ -642,19 +666,75 @@ func CheckHostPorts(h *models.Host) (changed bool) {
 		if !host.EndpointIP.Equal(h.EndpointIP) {
 			continue
 		}
-		portsInUse[host.ListenPort] = true
+		if host.ListenPort > 0 {
+			portsInUse[host.ListenPort] = true
+		}
 	}
-	// iterate until port is not found or max iteration is reached
-	for i := 0; portsInUse[h.ListenPort] && i < maxPort-minPort+1; i++ {
+
+	// If current port is not in use, no change needed
+	if !portsInUse[h.ListenPort] {
+		return
+	}
+
+	// Find an available port
+	maxIterations := maxPort - minPort + 1
+	checkedPorts := make(map[int]bool)
+	initialPort := h.ListenPort
+
+	for i := 0; i < maxIterations; i++ {
+		// Special case: skip port 443 by jumping to 51821
 		if h.ListenPort == 443 {
 			h.ListenPort = 51821
 		} else {
 			h.ListenPort++
 		}
+
+		// Wrap around if we exceed maxPort
 		if h.ListenPort > maxPort {
 			h.ListenPort = minPort
 		}
+
+		// Avoid infinite loop - if we've checked this port before, we've cycled through all
+		if checkedPorts[h.ListenPort] {
+			// All ports are in use, keep original port
+			h.ListenPort = originalPort
+			break
+		}
+		checkedPorts[h.ListenPort] = true
+
+		// Re-read hosts to get the latest state (in case another host just changed its port)
+		// This is important to avoid conflicts when multiple hosts are being processed
+		latestHosts, err := GetAllHosts()
+		if err == nil {
+			// Update portsInUse with latest state
+			for _, host := range latestHosts {
+				if host.ID.String() == h.ID.String() {
+					continue
+				}
+				if host.EndpointIP == nil {
+					continue
+				}
+				if !host.EndpointIP.Equal(h.EndpointIP) {
+					continue
+				}
+				if host.ListenPort > 0 {
+					portsInUse[host.ListenPort] = true
+				}
+			}
+		}
+
+		// If this port is not in use, we found an available port
+		if !portsInUse[h.ListenPort] {
+			break
+		}
+
+		// If we've wrapped back to the initial port, all ports are in use
+		if h.ListenPort == initialPort && i > 0 {
+			h.ListenPort = originalPort
+			break
+		}
 	}
+
 	return
 }
 

+ 8 - 3
utils/utils.go

@@ -57,13 +57,18 @@ func TraceCaller() {
 		slog.Debug("Unable to get caller information")
 		return
 	}
-
+	tracePc, _, _, ok := runtime.Caller(1)
+	if !ok {
+		slog.Debug("Unable to get caller information")
+		return
+	}
+	traceFuncName := runtime.FuncForPC(tracePc).Name()
 	// Get function name from the program counter (pc)
 	funcName := runtime.FuncForPC(pc).Name()
 
 	// Print trace details
-	slog.Debug("Called from function: %s\n", "func", funcName)
-	slog.Debug("File: %s, Line: %d\n", "file", file, "line", line)
+	slog.Debug("## TRACE -> Called from function: ", "tracing-func-name", traceFuncName, "caller-func-name", funcName)
+	slog.Debug("## TRACE -> Caller File Info", "file", file, "line-no", line)
 }
 
 // NoEmptyStringToCsv takes a bunch of strings, filters out empty ones and returns a csv version of the string