Browse Source

optimize usage of bart (#1395)

Use `bart.Lite` and `.Contains` as suggested by the bart maintainer:

- https://github.com/gaissmai/bart/commit/9455952eedcf59a6e755fc28ed16e906fa4f3066#commitcomment-155362580
Wade Simmons 3 months ago
parent
commit
b8ea55eb90
13 changed files with 61 additions and 80 deletions
  1. 1 2
      control.go
  2. 3 3
      dns_server.go
  3. 11 14
      firewall.go
  4. 3 4
      handshake_ix.go
  5. 1 2
      handshake_manager.go
  6. 4 4
      hostmap.go
  7. 3 6
      inside.go
  8. 5 5
      interface.go
  9. 8 15
      lighthouse.go
  10. 10 10
      lighthouse_test.go
  11. 1 2
      outside.go
  12. 9 9
      pki.go
  13. 2 4
      relay_manager.go

+ 1 - 2
control.go

@@ -131,8 +131,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
-	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
-	if found {
+	if c.f.myVpnAddrsTable.Contains(vpnIp) {
 		// Only returning the default certificate since its impossible
 		// for any other host but ourselves to have more than 1
 		return c.f.pki.getCertState().GetDefaultCertificate().Copy()

+ 3 - 3
dns_server.go

@@ -26,7 +26,7 @@ type dnsRecords struct {
 	dnsMap4         map[string]netip.Addr
 	dnsMap6         map[string]netip.Addr
 	hostMap         *HostMap
-	myVpnAddrsTable *bart.Table[struct{}]
+	myVpnAddrsTable *bart.Lite
 }
 
 func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
@@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
 		return true
 	}
 
-	_, found := d.myVpnAddrsTable.Lookup(b)
-	return found //if we found it in this table, it's good
+	//if we found it in this table, it's good
+	return d.myVpnAddrsTable.Contains(b)
 }
 
 func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {

+ 11 - 14
firewall.go

@@ -53,7 +53,7 @@ type Firewall struct {
 
 	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
 	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
-	routableNetworks *bart.Table[struct{}]
+	routableNetworks *bart.Lite
 
 	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
 	assignedNetworks  []netip.Prefix
@@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA
 
 type firewallLocalCIDR struct {
 	Any       bool
-	LocalCIDR *bart.Table[struct{}]
+	LocalCIDR *bart.Lite
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		tmax = defaultTimeout
 	}
 
-	routableNetworks := new(bart.Table[struct{}])
+	routableNetworks := new(bart.Lite)
 	var assignedNetworks []netip.Prefix
 	for _, network := range c.Networks() {
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		routableNetworks.Insert(nprefix, struct{}{})
+		routableNetworks.Insert(nprefix)
 		assignedNetworks = append(assignedNetworks, network)
 	}
 
 	hasUnsafeNetworks := false
 	for _, n := range c.UnsafeNetworks() {
-		routableNetworks.Insert(n, struct{}{})
+		routableNetworks.Insert(n)
 		hasUnsafeNetworks = true
 	}
 
@@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 
 	// Make sure remote address matches nebula certificate
 	if h.networks != nil {
-		_, ok := h.networks.Lookup(fp.RemoteAddr)
-		if !ok {
+		if !h.networks.Contains(fp.RemoteAddr) {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
-	if !ok {
+	if !f.routableNetworks.Contains(fp.LocalAddr) {
 		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 	}
@@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
 func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
-			LocalCIDR: new(bart.Table[struct{}]),
+			LocalCIDR: new(bart.Lite),
 		}
 	}
 
@@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 		}
 
 		for _, network := range f.assignedNetworks {
-			flc.LocalCIDR.Insert(network, struct{}{})
+			flc.LocalCIDR.Insert(network)
 		}
 		return nil
 
@@ -888,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 		return nil
 	}
 
-	flc.LocalCIDR.Insert(localIp, struct{}{})
+	flc.LocalCIDR.Insert(localIp)
 	return nil
 }
 
@@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
 		return true
 	}
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
-	return ok
+	return flc.LocalCIDR.Contains(p.LocalAddr)
 }
 
 type rule struct {

+ 3 - 4
handshake_ix.go

@@ -192,8 +192,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	for _, network := range remoteCert.Certificate.Networks() {
 		vpnAddr := network.Addr()
-		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
-		if found {
+		if f.myVpnAddrsTable.Contains(vpnAddr) {
 			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
@@ -204,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		}
 
 		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+		if !f.myVpnNetworksTable.Contains(vpnAddr) {
 			continue
 		}
 
@@ -579,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	for _, network := range vpnNetworks {
 		// vpnAddrs outside our vpn networks are of no use to us, filter them out
 		vpnAddr := network.Addr()
-		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+		if !f.myVpnNetworksTable.Contains(vpnAddr) {
 			continue
 		}
 

+ 1 - 2
handshake_manager.go

@@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 			}
 
 			// Don't relay through the host I'm trying to connect to
-			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
-			if found {
+			if hm.f.myVpnAddrsTable.Contains(relay) {
 				continue
 			}
 

+ 4 - 4
hostmap.go

@@ -223,7 +223,7 @@ type HostInfo struct {
 	recvError atomic.Uint32
 
 	// networks are both all vpn and unsafe networks assigned to this host
-	networks   *bart.Table[struct{}]
+	networks   *bart.Lite
 	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
 		return
 	}
 
-	i.networks = new(bart.Table[struct{}])
+	i.networks = new(bart.Lite)
 	for _, network := range networks {
-		i.networks.Insert(network, struct{}{})
+		i.networks.Insert(network)
 	}
 
 	for _, network := range unsafeNetworks {
-		i.networks.Insert(network, struct{}{})
+		i.networks.Insert(network)
 	}
 }
 

+ 3 - 6
inside.go

@@ -22,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 	// Ignore local broadcast packets
 	if f.dropLocalBroadcast {
-		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
-		if found {
+		if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
 			return
 		}
 	}
 
-	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
-	if found {
+	if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// routes packets from the Nebula addr to the Nebula addr through the Nebula
@@ -130,8 +128,7 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) {
 // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
 func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
-	if found {
+	if f.myVpnNetworksTable.Contains(vpnAddr) {
 		return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 	}
 

+ 5 - 5
interface.go

@@ -61,11 +61,11 @@ type Interface struct {
 	serveDns              bool
 	createTime            time.Time
 	lightHouse            *LightHouse
-	myBroadcastAddrsTable *bart.Table[struct{}]
-	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
-	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
-	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
-	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	myBroadcastAddrsTable *bart.Lite
+	myVpnAddrs            []netip.Addr // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Lite
+	myVpnNetworks         []netip.Prefix // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Lite
 	dropLocalBroadcast    bool
 	dropMulticast         bool
 	routines              int

+ 8 - 15
lighthouse.go

@@ -32,7 +32,7 @@ type LightHouse struct {
 	amLighthouse bool
 
 	myVpnNetworks      []netip.Prefix
-	myVpnNetworksTable *bart.Table[struct{}]
+	myVpnNetworksTable *bart.Lite
 	punchConn          udp.Conn
 	punchy             *Punchy
 
@@ -201,8 +201,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 
 			//TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used
 			addr := addrs[0].Unmap()
-			_, found := lh.myVpnNetworksTable.Lookup(addr)
-			if found {
+			if lh.myVpnNetworksTable.Contains(addr) {
 				lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
 					Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
 				continue
@@ -359,8 +358,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{
 			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
 
-		_, found := lh.myVpnNetworksTable.Lookup(addr)
-		if !found {
+		if !lh.myVpnNetworksTable.Contains(addr) {
 			return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
 		}
 		lhMap[addr] = struct{}{}
@@ -431,8 +429,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
 		}
 
-		_, found := lh.myVpnNetworksTable.Lookup(vpnAddr)
-		if !found {
+		if !lh.myVpnNetworksTable.Contains(vpnAddr) {
 			return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
 		}
 
@@ -653,8 +650,7 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(to)
-	if found {
+	if lh.myVpnNetworksTable.Contains(to) {
 		return false
 	}
 
@@ -674,8 +670,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
-	if found {
+	if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
 		return false
 	}
 
@@ -695,8 +690,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 		return false
 	}
 
-	_, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr())
-	if found {
+	if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
 		return false
 	}
 
@@ -856,8 +850,7 @@ func (lh *LightHouse) SendUpdate() {
 
 	lal := lh.GetLocalAllowList()
 	for _, e := range localAddrs(lh.l, lal) {
-		_, found := lh.myVpnNetworksTable.Lookup(e)
-		if found {
+		if lh.myVpnNetworksTable.Contains(e) {
 			continue
 		}
 

+ 10 - 10
lighthouse_test.go

@@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) {
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) {
 	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,
@@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) {
 	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
-	nt := new(bart.Table[struct{}])
-	nt.Insert(myVpnNet, struct{}{})
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
 	cs := &CertState{
 		myVpnNetworks:      []netip.Prefix{myVpnNet},
 		myVpnNetworksTable: nt,

+ 1 - 2
outside.go

@@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
-		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
-		if found {
+		if f.myVpnNetworksTable.Contains(ip.Addr()) {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}

+ 9 - 9
pki.go

@@ -39,10 +39,10 @@ type CertState struct {
 	cipher            string
 
 	myVpnNetworks            []netip.Prefix
-	myVpnNetworksTable       *bart.Table[struct{}]
+	myVpnNetworksTable       *bart.Lite
 	myVpnAddrs               []netip.Addr
-	myVpnAddrsTable          *bart.Table[struct{}]
-	myVpnBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrsTable          *bart.Lite
+	myVpnBroadcastAddrsTable *bart.Lite
 }
 
 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -345,9 +345,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 	cs := CertState{
 		privateKey:               privateKey,
 		pkcs11Backed:             pkcs11backed,
-		myVpnNetworksTable:       new(bart.Table[struct{}]),
-		myVpnAddrsTable:          new(bart.Table[struct{}]),
-		myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
+		myVpnNetworksTable:       new(bart.Lite),
+		myVpnAddrsTable:          new(bart.Lite),
+		myVpnBroadcastAddrsTable: new(bart.Lite),
 	}
 
 	if v1 != nil && v2 != nil {
@@ -415,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 
 	for _, network := range crt.Networks() {
 		cs.myVpnNetworks = append(cs.myVpnNetworks, network)
-		cs.myVpnNetworksTable.Insert(network, struct{}{})
+		cs.myVpnNetworksTable.Insert(network)
 
 		cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
-		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
+		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()))
 
 		if network.Addr().Is4() {
 			addr := network.Masked().Addr().As4()
 			mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
 			binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
-			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
+			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()))
 		}
 	}
 

+ 2 - 4
relay_manager.go

@@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	_, found := f.myVpnAddrsTable.Lookup(from)
-	if found {
+	if f.myVpnAddrsTable.Contains(from) {
 		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 	}
 
 	// Is the target of the relay me?
-	_, found = f.myVpnAddrsTable.Lookup(target)
-	if found {
+	if f.myVpnAddrsTable.Contains(target) {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 			switch existingRelay.State {