Browse Source

Use a list for relay IPs instead of a map (#1423)

* Use a list for relay IPs instead of a map

* linter
brad-defined 3 weeks ago
parent
commit
b158eb0c4c
5 changed files with 46 additions and 11 deletions
  1. 2 2
      control_test.go
  2. 1 1
      handshake_ix.go
  3. 1 1
      handshake_manager.go
  4. 13 7
      hostmap.go
  5. 29 0
      hostmap_test.go

+ 2 - 2
control_test.go

@@ -53,7 +53,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},
@@ -72,7 +72,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},

+ 1 - 1
handshake_ix.go

@@ -249,7 +249,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},

+ 1 - 1
handshake_manager.go

@@ -450,7 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
 		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:         map[netip.Addr]struct{}{},
+			relays:         nil,
 			relayForByAddr: map[netip.Addr]*Relay{},
 			relayForByIdx:  map[uint32]*Relay{},
 		},

+ 13 - 7
hostmap.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"net"
 	"net/netip"
+	"slices"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -68,7 +69,7 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer
 	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
 	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
 	// the RelayState Lock held)
@@ -79,7 +80,12 @@ type RelayState struct {
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	delete(rs.relays, ip)
+	for idx, val := range rs.relays {
+		if val == ip {
+			rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
+			return
+		}
+	}
 }
 
 func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -124,16 +130,16 @@ func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relays[ip] = struct{}{}
+	if !slices.Contains(rs.relays, ip) {
+		rs.relays = append(rs.relays, ip)
+	}
 }
 
 func (rs *RelayState) CopyRelayIps() []netip.Addr {
+	ret := make([]netip.Addr, len(rs.relays))
 	rs.RLock()
 	defer rs.RUnlock()
-	ret := make([]netip.Addr, 0, len(rs.relays))
-	for ip := range rs.relays {
-		ret = append(ret, ip)
-	}
+	copy(ret, rs.relays)
 	return ret
 }
 

+ 29 - 0
hostmap_test.go

@@ -7,6 +7,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestHostMap_MakePrimary(t *testing.T) {
@@ -215,3 +216,31 @@ func TestHostMap_reload(t *testing.T) {
 	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
 	assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
 }
+
+func TestHostMap_RelayState(t *testing.T) {
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	a1 := netip.MustParseAddr("::1")
+	a2 := netip.MustParseAddr("2001::1")
+
+	h1.relayState.InsertRelayTo(a1)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+	h1.relayState.InsertRelayTo(a2)
+	assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays)
+	// Ensure that the first relay added is the first one returned in the copy
+	currentRelays := h1.relayState.CopyRelayIps()
+	require.Len(t, currentRelays, 2)
+	assert.Equal(t, a1, currentRelays[0])
+
+	// Deleting the last one in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+
+	// Deleting an element not in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays)
+
+	// Deleting the only element in the list works ok
+	h1.relayState.DeleteRelay(a1)
+	assert.Equal(t, []netip.Addr{}, h1.relayState.relays)
+
+}