Browse Source

update EncReader and EncWriter interface function args to have concrete types (#844)

* Update LightHouseHandlerFunc to remove EncWriter param.
* Move EncWriter to interface
* EncReader, too
brad-defined 2 năm trước cách đây
mục cha
commit
9b03053191
14 tập tin đã thay đổi với 69 bổ sung54 xóa
  1. 1 1
      handshake.go
  2. 11 14
      handshake_ix.go
  3. 3 3
      handshake_manager.go
  4. 1 1
      handshake_manager_test.go
  5. 2 5
      inside.go
  6. 14 1
      interface.go
  7. 14 8
      lighthouse.go
  8. 1 1
      lighthouse_test.go
  9. 18 2
      outside.go
  10. 0 1
      udp/conn.go
  11. 1 14
      udp/temp.go
  12. 1 1
      udp/udp_generic.go
  13. 1 1
      udp/udp_linux.go
  14. 1 1
      udp/udp_tester.go

+ 1 - 1
handshake.go

@@ -5,7 +5,7 @@ import (
 	"github.com/slackhq/nebula/udp"
 )
 
-func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H, hostinfo *HostInfo) {
+func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
 	// First remote allow list check before we know the vpnIp
 	if addr != nil {
 		if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {

+ 11 - 14
handshake_ix.go

@@ -68,7 +68,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hostinfo.handshakeStart = time.Now()
 }
 
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
 	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
@@ -240,14 +240,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				}
 				return
 			} else {
-				via2 := via.(*ViaSender)
-				if via2 == nil {
+				if via == nil {
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 				}
-				hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
-				f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via2.relayHI.vpnIp).
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
+				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 				return
@@ -315,14 +314,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				Info("Handshake message sent")
 		}
 	} else {
-		via2 := via.(*ViaSender)
-		if via2 == nil {
+		if via == nil {
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 		}
-		hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
-		f.SendVia(via2.relayHI, via2.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via2.relayHI.vpnIp).
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
+		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -338,7 +336,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *HostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
 	if hostinfo == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -482,8 +480,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	if addr != nil {
 		hostinfo.SetRemote(addr)
 	} else {
-		via2 := via.(*ViaSender)
-		hostinfo.relayState.InsertRelayTo(via2.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
 	}
 
 	// Build up the radix for the firewall if we have subnets in the cert

+ 3 - 3
handshake_manager.go

@@ -73,7 +73,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
+func (c *HandshakeManager) Run(ctx context.Context, f EncWriter) {
 	clockSource := time.NewTicker(c.config.tryInterval)
 	defer clockSource.Stop()
 
@@ -89,7 +89,7 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
+func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWriter) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
 		vpnIp, has := c.OutboundHandshakeTimer.Purge()
@@ -100,7 +100,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
+func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
 	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 		return

+ 1 - 1
handshake_manager_test.go

@@ -84,7 +84,7 @@ func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	return
 }
 
-func (mw *mockEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 	return
 }
 

+ 2 - 5
inside.go

@@ -248,16 +248,13 @@ func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *C
 // nb is a buffer used to store the nonce value, re-used for performance reasons.
 // out is a buffer used to store the result of the Encrypt operation
 // q indicates which writer to use to send the packet.
-func (f *Interface) SendVia(viaIfc interface{},
-	relayIfc interface{},
+func (f *Interface) SendVia(via *HostInfo,
+	relay *Relay,
 	ad,
 	nb,
 	out []byte,
 	nocopy bool,
 ) {
-	via := viaIfc.(*HostInfo)
-	relay := relayIfc.(*Relay)
-
 	if noiseutil.EncryptLockNeeded {
 		// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
 		via.ConnectionState.writeLock.Lock()

+ 14 - 1
interface.go

@@ -16,6 +16,7 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
+	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
@@ -89,6 +90,18 @@ type Interface struct {
 	l *logrus.Logger
 }
 
+type EncWriter interface {
+	SendVia(via *HostInfo,
+		relay *Relay,
+		ad,
+		nb,
+		out []byte,
+		nocopy bool,
+	)
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	Handshake(vpnIp iputil.VpnIp)
+}
+
 type sendRecvErrorConfig uint8
 
 const (
@@ -238,7 +251,7 @@ func (f *Interface) listenOut(i int) {
 
 	lhh := f.lightHouse.NewRequestHandler()
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(f.readOutsidePackets, lhh.HandleRequest, conntrackCache, i)
+	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {

+ 14 - 8
lighthouse.go

@@ -65,7 +65,7 @@ type LightHouse struct {
 	interval        atomic.Int64
 	updateCancel    context.CancelFunc
 	updateParentCtx context.Context
-	updateUdp       udp.EncWriter
+	updateUdp       EncWriter
 	nebulaPort      uint32 // 32 bits because protobuf does not have a uint16
 
 	advertiseAddrs atomic.Pointer[[]netIpAndPort]
@@ -382,7 +382,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 }
 
-func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 	}
@@ -396,7 +396,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f udp.EncWriter) *RemoteList {
 }
 
 // This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f udp.EncWriter) {
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
 	if lh.amLighthouse {
 		return
 	}
@@ -629,7 +629,7 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
-func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
+func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f EncWriter) {
 	lh.updateParentCtx = ctx
 	lh.updateUdp = f
 
@@ -655,7 +655,7 @@ func (lh *LightHouse) LhUpdateWorker(ctx context.Context, f udp.EncWriter) {
 	}
 }
 
-func (lh *LightHouse) SendUpdate(f udp.EncWriter) {
+func (lh *LightHouse) SendUpdate(f EncWriter) {
 	var v4 []*Ip4AndPort
 	var v6 []*Ip6AndPort
 
@@ -760,7 +760,13 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 }
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w udp.EncWriter) {
+func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
+	return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+		lhh.HandleRequest(rAddr, vpnIp, p, f)
+	}
+}
+
+func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	if err != nil {
@@ -795,7 +801,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w udp.EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -928,7 +934,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Unlock()
 }
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w udp.EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}

+ 1 - 1
lighthouse_test.go

@@ -372,7 +372,7 @@ type testEncWriter struct {
 	metaFilter *NebulaMeta_MessageType
 }
 
-func (tw *testEncWriter) SendVia(via interface{}, relay interface{}, ad, nb, out []byte, nocopy bool) {
+func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 }
 func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
 }

+ 18 - 2
outside.go

@@ -21,7 +21,23 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func readOutsidePackets(f *Interface) udp.EncReader {
+	return func(
+		addr *udp.Addr,
+		out []byte,
+		packet []byte,
+		header *header.H,
+		fwPacket *firewall.Packet,
+		lhh udp.LightHouseHandlerFunc,
+		nb []byte,
+		q int,
+		localCache firewall.ConntrackCache,
+	) {
+		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
+	}
+}
+
+func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// TODO: best if we return this and let caller log
@@ -149,7 +165,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
 			return
 		}
 
-		lhf(addr, hostinfo.vpnIp, d, f)
+		lhf(addr, hostinfo.vpnIp, d)
 
 		// Fallthrough to the bottom to record incoming traffic
 

+ 0 - 1
udp/conn.go

@@ -9,7 +9,6 @@ const MTU = 9001
 
 type EncReader func(
 	addr *Addr,
-	via interface{},
 	out []byte,
 	packet []byte,
 	header *header.H,

+ 1 - 14
udp/temp.go

@@ -1,22 +1,9 @@
 package udp
 
 import (
-	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 )
 
-type EncWriter interface {
-	SendVia(via interface{},
-		relay interface{},
-		ad,
-		nb,
-		out []byte,
-		nocopy bool,
-	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
-	Handshake(vpnIp iputil.VpnIp)
-}
-
 //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
 
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter)
+type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)

+ 1 - 1
udp/udp_generic.go

@@ -86,6 +86,6 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 
 		udpAddr.IP = rua.IP
 		udpAddr.Port = uint16(rua.Port)
-		r(udpAddr, nil, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }

+ 1 - 1
udp/udp_linux.go

@@ -145,7 +145,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 		for i := 0; i < n; i++ {
 			udpAddr.IP = names[i][8:24]
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			r(udpAddr, nil, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
 		}
 	}
 }

+ 1 - 1
udp/udp_tester.go

@@ -122,7 +122,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
 		}
 		ua.Port = p.FromPort
 		copy(ua.IP, p.FromIp.To16())
-		r(ua, nil, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }