Parcourir la source

Store relay states in a slice for consistent ordering (#1422)

brad-defined il y a 1 mois
Parent
commit
a1498ca8f8
5 fichiers modifiés avec 48 ajouts et 13 suppressions
  1. 2 2
      control_test.go
  2. 1 1
      handshake_ix.go
  3. 1 1
      handshake_manager.go
  4. 15 9
      hostmap.go
  5. 29 0
      hostmap_test.go

+ 2 - 2
control_test.go

@@ -66,7 +66,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnIp:         vpnIp,
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
+			relays:        nil,
 			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
@@ -85,7 +85,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		localIndexId:  201,
 		vpnIp:         vpnIp2,
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
+			relays:        nil,
 			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},

+ 1 - 1
handshake_ix.go

@@ -151,7 +151,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
+			relays:        nil,
 			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},

+ 1 - 1
handshake_manager.go

@@ -403,7 +403,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 		vpnIp:           vpnIp,
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
+			relays:        nil,
 			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},

+ 15 - 9
hostmap.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"net"
 	"net/netip"
+	"slices"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -69,15 +70,20 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
+	relays        []netip.Addr          // Ordered set of VpnIp's of Hosts to use as relays to access this peer
+	relayForByIp  map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx map[uint32]*Relay     // Maps a local index to some Relay info
 }
 
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	delete(rs.relays, ip)
+	for idx, val := range rs.relays {
+		if val == ip {
+			rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...)
+			return
+		}
+	}
 }
 
 func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
@@ -122,16 +128,16 @@ func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
 func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relays[ip] = struct{}{}
+	if !slices.Contains(rs.relays, ip) {
+		rs.relays = append(rs.relays, ip)
+	}
 }
 
 func (rs *RelayState) CopyRelayIps() []netip.Addr {
+	ret := make([]netip.Addr, len(rs.relays))
 	rs.RLock()
 	defer rs.RUnlock()
-	ret := make([]netip.Addr, 0, len(rs.relays))
-	for ip := range rs.relays {
-		ret = append(ret, ip)
-	}
+	copy(ret, rs.relays)
 	return ret
 }
 

+ 29 - 0
hostmap_test.go

@@ -7,6 +7,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestHostMap_MakePrimary(t *testing.T) {
@@ -225,3 +226,31 @@ func TestHostMap_reload(t *testing.T) {
 	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
 	assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
 }
+
+func TestHostMap_RelayState(t *testing.T) {
+	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+	a1 := netip.MustParseAddr("::1")
+	a2 := netip.MustParseAddr("2001::1")
+
+	h1.relayState.InsertRelayTo(a1)
+	assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
+	h1.relayState.InsertRelayTo(a2)
+	assert.Equal(t, h1.relayState.relays, []netip.Addr{a1, a2})
+	// Ensure that the first relay added is the first one returned in the copy
+	currentRelays := h1.relayState.CopyRelayIps()
+	require.Len(t, currentRelays, 2)
+	assert.Equal(t, currentRelays[0], a1)
+
+	// Deleting the last one in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
+
+	// Deleting an element not in the list works ok
+	h1.relayState.DeleteRelay(a2)
+	assert.Equal(t, h1.relayState.relays, []netip.Addr{a1})
+
+	// Deleting the only element in the list works ok
+	h1.relayState.DeleteRelay(a1)
+	assert.Equal(t, h1.relayState.relays, []netip.Addr{})
+
+}