Browse Source

firewall reject packets: cleanup error cases (#957)

Wade Simmons 1 year ago
parent
commit
fe16ea566d
4 changed files with 130 additions and 14 deletions
  1. 20 6
      inside.go
  2. 34 7
      iputil/packet.go
  3. 73 0
      iputil/packet_test.go
  4. 3 1
      outside.go

+ 20 - 6
inside.go

@@ -83,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 	}
 
 	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
 	_, err := f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
@@ -94,12 +98,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	// Use some out buffer space to build the packet before encryption
-	// Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet)
-	// Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet)
-	out = out[:140]
-	outPacket := iputil.CreateRejectPacket(packet, out[100:])
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q)
+	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
+	if len(out) > iputil.MaxRejectPacketSize {
+		if f.l.GetLevel() >= logrus.InfoLevel {
+			f.l.
+				WithField("packet", packet).
+				WithField("outPacket", out).
+				Info("rejectOutside: packet too big, not sending")
+		}
+		return
+	}
+
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
 }
 
 func (f *Interface) Handshake(vpnIp iputil.VpnIp) {

+ 34 - 7
iputil/packet.go

@@ -6,8 +6,19 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
+const (
+	// Need 96 bytes for the largest reject packet:
+	// - 20 byte ipv4 header
+	// - 8 byte icmpv4 header
+	// - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header)
+	MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8
+)
+
 func CreateRejectPacket(packet []byte, out []byte) []byte {
-	// TODO ipv4 only, need to fix when inside supports ipv6
+	if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version {
+		return nil
+	}
+
 	switch packet[9] {
 	case 6: // tcp
 		return ipv4CreateRejectTCPPacket(packet, out)
@@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte {
 func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 
-	// ICMP reply includes header and first 8 bytes of the packet
+	if len(packet) < ihl {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+
+	// ICMP reply includes original header and first 8 bytes of the packet
 	packetLen := len(packet)
 	if packetLen > ihl+8 {
 		packetLen = ihl + 8
 	}
 
 	outLen := ipv4.HeaderLen + 8 + packetLen
+	if outLen > cap(out) {
+		return nil
+	}
 
-	out = out[:(outLen)]
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
-	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)                        // version, ihl
-	ipHdr[1] = 0                                                              // DSCP, ECN
-	binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length
+	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
+	ipHdr[1] = 0                                          // DSCP, ECN
+	binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length
 
 	ipHdr[4] = 0  // id
 	ipHdr[5] = 0  //  .
@@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 	outLen := ipv4.HeaderLen + tcpLen
 
-	out = out[:(outLen)]
+	if len(packet) < ihl+tcpLen {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+	if outLen > cap(out) {
+		return nil
+	}
+
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
 	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl

+ 73 - 0
iputil/packet_test.go

@@ -0,0 +1,73 @@
+package iputil
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/net/ipv4"
+)
+
+func Test_CreateRejectPacket(t *testing.T) {
+	h := ipv4.Header{
+		Len:      20,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+	}
+
+	b, err := h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+
+	expectedLen := ipv4.HeaderLen + 8 + h.Len + 4
+	out := make([]byte, expectedLen)
+	rejectPacket := CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// ICMP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...)
+
+	expectedLen = MaxRejectPacketSize
+	out = make([]byte, MaxRejectPacketSize)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// TCP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 6, // TCP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+	b = append(b, make([]byte, 16)...)
+
+	expectedLen = ipv4.HeaderLen + 20
+	out = make([]byte, expectedLen)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+}

+ 3 - 1
outside.go

@@ -406,7 +406,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
-		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
+		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
+		// This gives us a buffer to build the reject packet in
+		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).