Browse Source

Tighten up the inside handlers with a bit of DRY

Dave Russell 5 years ago
parent
commit
55d72ac46f
2 changed files with 26 additions and 36 deletions
  1. 19 29
      handler.go
  2. 7 7
      interface.go

+ 19 - 29
handler.go

@@ -1,22 +1,30 @@
 package nebula
 
-func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	if !f.handleEncrypted(ci, addr, header) {
-		return
-	}
+func (f *Interface) encrypted(h InsideHandler) InsideHandler {
+	return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
+		if !f.handleEncrypted(ci, addr, header) {
+			return
+		}
 
-	f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
+		h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
 
-	f.handleHostRoaming(hostInfo, addr)
-	f.connectionManager.In(hostInfo.hostId)
+		f.handleHostRoaming(hostInfo, addr)
+		f.connectionManager.In(hostInfo.hostId)
+	}
 }
 
-func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-	if !f.handleEncrypted(ci, addr, header) {
-		return
+func (f *Interface) rxMetrics(h InsideHandler) InsideHandler {
+	return func(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
+		f.messageMetrics.Rx(header.Type, header.Subtype, 1)
+		h(hostInfo, ci, addr, header, out, packet, fwPacket, nb)
 	}
+}
+
+func (f *Interface) handleMessagePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
+	f.decryptToTun(hostInfo, header.MessageCounter, out, packet, fwPacket, nb)
+}
 
+func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
 	d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
 	if err != nil {
 		hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@@ -29,17 +37,9 @@ func (f *Interface) handleLighthousePacket(hostInfo *HostInfo, ci *ConnectionSta
 	}
 
 	f.lightHouse.HandleRequest(addr, hostInfo.hostId, d, hostInfo.GetCert(), f)
-
-	f.handleHostRoaming(hostInfo, addr)
-	f.connectionManager.In(hostInfo.hostId)
 }
 
 func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-	if !f.handleEncrypted(ci, addr, header) {
-		return
-	}
-
 	d, err := f.decrypt(hostInfo, header.MessageCounter, out, packet, header, nb)
 	if err != nil {
 		hostInfo.logger().WithError(err).WithField("udpAddr", addr).
@@ -57,28 +57,18 @@ func (f *Interface) handleTestPacket(hostInfo *HostInfo, ci *ConnectionState, ad
 		f.handleHostRoaming(hostInfo, addr)
 		f.send(test, testReply, ci, hostInfo, hostInfo.remote, d, nb, out)
 	}
-
-	f.handleHostRoaming(hostInfo, addr)
-	f.connectionManager.In(hostInfo.hostId)
 }
 
 func (f *Interface) handleHandshakePacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	f.messageMetrics.Rx(header.Type, header.Subtype, 1)
 	HandleIncomingHandshake(f, addr, packet, header, hostInfo)
 }
 
 func (f *Interface) handleRecvErrorPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	f.messageMetrics.Rx(header.Type, header.Subtype, 1)
 	// TODO: Remove this with recv_error deprecation
 	f.handleRecvError(addr, header)
 }
 
 func (f *Interface) handleCloseTunnelPacket(hostInfo *HostInfo, ci *ConnectionState, addr *udpAddr, header *Header, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte) {
-	f.messageMetrics.Rx(header.Type, header.Subtype, 1)
-	if !f.handleEncrypted(ci, addr, header) {
-		return
-	}
-
 	hostInfo.logger().WithField("udpAddr", addr).
 		Info("Close tunnel received, tearing down.")
 

+ 7 - 7
interface.go

@@ -111,23 +111,23 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 	ifce.handlers = map[uint8]map[NebulaMessageType]map[NebulaMessageSubType]InsideHandler{
 		Version: {
 			handshake: {
-				handshakeIXPSK0: ifce.handleHandshakePacket,
+				handshakeIXPSK0: ifce.rxMetrics(ifce.handleHandshakePacket),
 			},
 			message: {
-				subTypeNone: ifce.handleMessagePacket,
+				subTypeNone: ifce.encrypted(ifce.handleMessagePacket),
 			},
 			recvError: {
-				subTypeNone: ifce.handleRecvErrorPacket,
+				subTypeNone: ifce.rxMetrics(ifce.handleRecvErrorPacket),
 			},
 			lightHouse: {
-				subTypeNone: ifce.handleLighthousePacket,
+				subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleLighthousePacket)),
 			},
 			test: {
-				testRequest: ifce.handleTestPacket,
-				testReply:   ifce.handleTestPacket,
+				testRequest: ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
+				testReply:   ifce.rxMetrics(ifce.encrypted(ifce.handleTestPacket)),
 			},
 			closeTunnel: {
-				subTypeNone: ifce.handleCloseTunnelPacket,
+				subTypeNone: ifce.rxMetrics(ifce.encrypted(ifce.handleCloseTunnelPacket)),
 			},
 		},
 	}