Ver código fonte

More correct ipv6 header parsing (#1323)

Nate Brown 5 meses atrás
pai
commit
fbff6a1487
2 arquivos alterados com 545 adições e 81 exclusões
  1. 60 43
      outside.go
  2. 485 38
      outside_test.go

+ 60 - 43
outside.go

@@ -3,7 +3,6 @@ package nebula
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
-	"fmt"
 	"net/netip"
 	"net/netip"
 	"time"
 	"time"
 
 
@@ -271,10 +270,19 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
 	return true
 	return true
 }
 }
 
 
+var (
+	ErrPacketTooShort          = errors.New("packet is too short")
+	ErrUnknownIPVersion        = errors.New("packet is an unknown ip version")
+	ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
+	ErrIPv4PacketTooShort      = errors.New("ipv4 packet is too short")
+	ErrIPv6PacketTooShort      = errors.New("ipv6 packet is too short")
+	ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
+)
+
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 	if len(data) < 1 {
 	if len(data) < 1 {
-		return errors.New("packet too short")
+		return ErrPacketTooShort
 	}
 	}
 
 
 	version := int((data[0] >> 4) & 0x0f)
 	version := int((data[0] >> 4) & 0x0f)
@@ -284,13 +292,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 	case ipv6.Version:
 	case ipv6.Version:
 		return parseV6(data, incoming, fp)
 		return parseV6(data, incoming, fp)
 	}
 	}
-	return fmt.Errorf("packet is an unknown ip version: %v", version)
+	return ErrUnknownIPVersion
 }
 }
 
 
 func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 	dataLen := len(data)
 	dataLen := len(data)
 	if dataLen < ipv6.HeaderLen {
 	if dataLen < ipv6.HeaderLen {
-		return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
+		return ErrIPv6PacketTooShort
 	}
 	}
 
 
 	if incoming {
 	if incoming {
@@ -301,11 +309,10 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
 		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
 	}
 	}
 
 
-	//TODO: CERT-V2 whats a reasonable number of extension headers to attempt to parse?
-	//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
-	protoAt := 6
-	offset := 40
-	for i := 0; i < 24; i++ {
+	protoAt := 6             // NextHeader is at 6 bytes into the ipv6 header
+	offset := ipv6.HeaderLen // Start at the end of the ipv6 header
+	next := 0
+	for {
 		if dataLen < offset {
 		if dataLen < offset {
 			break
 			break
 		}
 		}
@@ -313,17 +320,18 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 		proto := layers.IPProtocol(data[protoAt])
 		proto := layers.IPProtocol(data[protoAt])
 		//fmt.Println(proto, protoAt)
 		//fmt.Println(proto, protoAt)
 		switch proto {
 		switch proto {
-		case layers.IPProtocolICMPv6:
+		case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
 			fp.Protocol = uint8(proto)
 			fp.Protocol = uint8(proto)
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
 			fp.Fragment = false
 			fp.Fragment = false
 			return nil
 			return nil
 
 
-		case layers.IPProtocolTCP:
+		case layers.IPProtocolTCP, layers.IPProtocolUDP:
 			if dataLen < offset+4 {
 			if dataLen < offset+4 {
-				return fmt.Errorf("ipv6 packet was too small")
+				return ErrIPv6PacketTooShort
 			}
 			}
+
 			fp.Protocol = uint8(proto)
 			fp.Protocol = uint8(proto)
 			if incoming {
 			if incoming {
 				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
 				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
@@ -332,62 +340,71 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
 				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
 				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
 				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
 				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
 			}
 			}
+
 			fp.Fragment = false
 			fp.Fragment = false
 			return nil
 			return nil
 
 
-		case layers.IPProtocolUDP:
-			if dataLen < offset+4 {
-				return fmt.Errorf("ipv6 packet was too small")
+		case layers.IPProtocolIPv6Fragment:
+			// Fragment header is 8 bytes, need at least offset+4 to read the offset field
+			if dataLen < offset+8 {
+				return ErrIPv6PacketTooShort
 			}
 			}
-			fp.Protocol = uint8(proto)
-			if incoming {
-				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
-				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
-			} else {
-				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
-				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+
+			// Check if this is the first fragment
+			fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
+			if fragmentOffset != 0 {
+				// Non-first fragment, use what we have now and stop processing
+				fp.Protocol = data[offset]
+				fp.Fragment = true
+				fp.RemotePort = 0
+				fp.LocalPort = 0
+				return nil
 			}
 			}
-			fp.Fragment = false
-			return nil
 
 
-		case layers.IPProtocolIPv6Fragment:
-			//TODO: CERT-V2 can we determine the protocol?
-			fp.RemotePort = 0
-			fp.LocalPort = 0
-			fp.Fragment = true
-			return nil
+			// The next loop should be the transport layer since we are the first fragment
+			next = 8 // Fragment headers are always 8 bytes
 
 
-		default:
+		case layers.IPProtocolAH:
+			// Auth headers, used by IPSec, have a different meaning for header length
 			if dataLen < offset+1 {
 			if dataLen < offset+1 {
 				break
 				break
 			}
 			}
 
 
-			next := int(data[offset+1]) * 8
-			if next == 0 {
-				// each extension is at least 8 bytes
-				next = 8
+			next = int(data[offset+1]+2) << 2
+
+		default:
+			// Normal ipv6 header length processing
+			if dataLen < offset+1 {
+				break
 			}
 			}
 
 
-			protoAt = offset
-			offset = offset + next
+			next = int(data[offset+1]+1) << 3
 		}
 		}
+
+		if next <= 0 {
+			// Safety check, each ipv6 header has to be at least 8 bytes
+			next = 8
+		}
+
+		protoAt = offset
+		offset = offset + next
 	}
 	}
 
 
-	return fmt.Errorf("could not find payload in ipv6 packet")
+	return ErrIPv6CouldNotFindPayload
 }
 }
 
 
 func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
 func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
 	// Do we at least have an ipv4 header worth of data?
 	// Do we at least have an ipv4 header worth of data?
 	if len(data) < ipv4.HeaderLen {
 	if len(data) < ipv4.HeaderLen {
-		return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
+		return ErrIPv4PacketTooShort
 	}
 	}
 
 
 	// Adjust our start position based on the advertised ip header length
 	// Adjust our start position based on the advertised ip header length
 	ihl := int(data[0]&0x0f) << 2
 	ihl := int(data[0]&0x0f) << 2
 
 
-	// Well formed ip header length?
+	// Well-formed ip header length?
 	if ihl < ipv4.HeaderLen {
 	if ihl < ipv4.HeaderLen {
-		return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 	}
 
 
 	// Check if this is the second or further fragment of a fragmented packet.
 	// Check if this is the second or further fragment of a fragmented packet.
@@ -403,7 +420,7 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
 		minLen += minFwPacketLen
 		minLen += minFwPacketLen
 	}
 	}
 	if len(data) < minLen {
 	if len(data) < minLen {
-		return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 	}
 
 
 	// Firewall packets are locally oriented
 	// Firewall packets are locally oriented
@@ -501,7 +518,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
-	f.outside.WriteTo(b, endpoint)
+	_ = f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
 		f.l.WithField("index", index).
 			WithField("udpAddr", endpoint).
 			WithField("udpAddr", endpoint).

+ 485 - 38
outside_test.go

@@ -1,6 +1,8 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"bytes"
+	"encoding/binary"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
@@ -18,13 +20,13 @@ func Test_newPacket(t *testing.T) {
 
 
 	// length fails
 	// length fails
 	err := newPacket([]byte{}, true, p)
 	err := newPacket([]byte{}, true, p)
-	assert.EqualError(t, err, "packet too short")
+	assert.ErrorIs(t, err, ErrPacketTooShort)
 
 
 	err = newPacket([]byte{0x40}, true, p)
 	err = newPacket([]byte{0x40}, true, p)
-	assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
+	assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
 
 
 	err = newPacket([]byte{0x60}, true, p)
 	err = newPacket([]byte{0x60}, true, p)
-	assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 
 	// length fail with ip options
 	// length fail with ip options
 	h := ipv4.Header{
 	h := ipv4.Header{
@@ -37,16 +39,15 @@ func Test_newPacket(t *testing.T) {
 
 
 	b, _ := h.Marshal()
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
-
-	assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 
 	// not an ipv4 packet
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet is an unknown ip version: 0")
+	assert.ErrorIs(t, err, ErrUnknownIPVersion)
 
 
 	// invalid ihl
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 
 	// account for variable ip header length - incoming
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
 	h = ipv4.Header{
@@ -63,11 +64,12 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemotePort, uint16(3))
-	assert.Equal(t, p.LocalPort, uint16(4))
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
+	assert.Equal(t, uint16(3), p.RemotePort)
+	assert.Equal(t, uint16(4), p.LocalPort)
+	assert.False(t, p.Fragment)
 
 
 	// account for variable ip header length - outgoing
 	// account for variable ip header length - outgoing
 	h = ipv4.Header{
 	h = ipv4.Header{
@@ -84,17 +86,94 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, false, p)
 	err = newPacket(b, false, p)
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemotePort, uint16(6))
-	assert.Equal(t, p.LocalPort, uint16(5))
+	assert.Equal(t, uint8(2), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
+	assert.Equal(t, uint16(6), p.RemotePort)
+	assert.Equal(t, uint16(5), p.LocalPort)
+	assert.False(t, p.Fragment)
 }
 }
 
 
 func Test_newPacket_v6(t *testing.T) {
 func Test_newPacket_v6(t *testing.T) {
 	p := &firewall.Packet{}
 	p := &firewall.Packet{}
 
 
+	// invalid ipv6
 	ip := layers.IPv6{
 	ip := layers.IPv6{
+		Version:  6,
+		HopLimit: 128,
+		SrcIP:    net.IPv6linklocalallrouters,
+		DstIP:    net.IPv6linklocalallnodes,
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opt := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       false,
+	}
+	err := gopacket.SerializeLayers(buffer, opt, &ip)
+	assert.NoError(t, err)
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good ICMP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolICMPv6,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	icmp := layers.ICMPv6{}
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
+	if err != nil {
+		panic(err)
+	}
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good ESP packet
+	b := buffer.Bytes()
+	b[6] = byte(layers.IPProtocolESP)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good None packet
+	b = buffer.Bytes()
+	b[6] = byte(layers.IPProtocolNoNextHeader)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// An unknown protocol packet
+	b = buffer.Bytes()
+	b[6] = 255 // 255 is a reserved protocol number
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good UDP packet
+	ip = layers.IPv6{
 		Version:    6,
 		Version:    6,
 		NextHeader: firewall.ProtoUDP,
 		NextHeader: firewall.ProtoUDP,
 		HopLimit:   128,
 		HopLimit:   128,
@@ -106,39 +185,407 @@ func Test_newPacket_v6(t *testing.T) {
 		SrcPort: layers.UDPPort(36123),
 		SrcPort: layers.UDPPort(36123),
 		DstPort: layers.UDPPort(22),
 		DstPort: layers.UDPPort(22),
 	}
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err = udp.SetNetworkLayerForChecksum(&ip)
+	assert.NoError(t, err)
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+	if err != nil {
+		panic(err)
+	}
+	b = buffer.Bytes()
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short UDP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good TCP packet
+	b[6] = byte(layers.IPProtocolTCP)
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short TCP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good UDP packet with an AH header
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolAH,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	ah := layers.IPSecAH{
+		AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
+	}
+	ah.NextHeader = layers.IPProtocolUDP
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opt)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
 
 
+	b = buffer.Bytes()
+	ahb := serializeAH(&ah)
+	b = append(b, ahb...)
+	b = append(b, udpHeader...)
+
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Invalid AH header
+	b = buffer.Bytes()
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+}
+
+func Test_newPacket_ipv6Fragment(t *testing.T) {
+	p := &firewall.Packet{}
+
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	// First fragment
+	fragHeader1 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
 	buffer := gopacket.NewSerializeBuffer()
 	buffer := gopacket.NewSerializeBuffer()
-	opt := gopacket.SerializeOptions{
+	opts := gopacket.SerializeOptions{
 		ComputeChecksums: true,
 		ComputeChecksums: true,
 		FixLengths:       true,
 		FixLengths:       true,
 	}
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+
+	err := ip.SerializeTo(buffer, opts)
 	if err != nil {
 	if err != nil {
-		panic(err)
+		t.Fatal(err)
 	}
 	}
-	b := buffer.Bytes()
 
 
-	//test incoming
-	err = newPacket(b, true, p)
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader1...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
 
 
-	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
-	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
-	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
-	assert.Equal(t, p.RemotePort, uint16(36123))
-	assert.Equal(t, p.LocalPort, uint16(22))
+	// Test first fragment incoming
+	err = newPacket(firstFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
 
 
-	//test outgoing
-	err = newPacket(b, false, p)
+	// Test first fragment outgoing
+	err = newPacket(firstFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
 
 
-	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
-	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
-	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
-	assert.Equal(t, p.LocalPort, uint16(36123))
-	assert.Equal(t, p.RemotePort, uint16(22))
+	// Second fragment
+	fragHeader2 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0xb9,                        // Fragment Offset high byte (185)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader2...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test second fragment incoming
+	err = newPacket(secondFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.True(t, p.Fragment)
+
+	// Test second fragment outgoing
+	err = newPacket(secondFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.True(t, p.Fragment)
+
+	// Too short of a fragment packet
+	err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+}
+
+func BenchmarkParseV6(b *testing.B) {
+	// Regular UDP packet
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolUDP,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := &layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       true,
+	}
+
+	err := gopacket.SerializeLayers(buffer, opts, ip, udp)
+	if err != nil {
+		b.Fatal(err)
+	}
+	normalPacket := buffer.Bytes()
+
+	// First Fragment packet
+	ipFrag := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	fragHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x7b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Second Fragment packet
+	fragHeader[2] = 0xb9 // offset 185
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	fp := &firewall.Packet{}
+
+	b.Run("Normal", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(normalPacket, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("FirstFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(firstFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("SecondFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(secondFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	// Evil packet
+	evilPacket := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6HopByHop,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	hopHeader := []byte{
+		uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
+		0x00,                                 // Length
+		0x00, 0x00,                           // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	lastHopHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Length
+		0x00, 0x00,                  // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	buffer.Clear()
+	err = evilPacket.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	evilBytes := buffer.Bytes()
+	for i := 0; i < 200; i++ {
+		evilBytes = append(evilBytes, hopHeader...)
+	}
+	evilBytes = append(evilBytes, lastHopHeader...)
+	evilBytes = append(evilBytes, udpHeader...)
+	evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	b.Run("200 HopByHop headers", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(evilBytes, false, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+// Ensure authentication data is a multiple of 8 bytes by padding if necessary
+func padAuthData(authData []byte) []byte {
+	// Length of Authentication Data must be a multiple of 8 bytes
+	paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
+	if paddingLength > 0 {
+		authData = append(authData, make([]byte, paddingLength)...)
+	}
+	return authData
+}
+
+// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
+func serializeAH(ah *layers.IPSecAH) []byte {
+	buf := new(bytes.Buffer)
+
+	// Ensure Authentication Data is a multiple of 8 bytes
+	ah.AuthenticationData = padAuthData(ah.AuthenticationData)
+	// Calculate Payload Length (in 32-bit words, minus 2)
+	payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
+
+	// Serialize fields
+	if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
+		panic(err)
+	}
+	if len(ah.AuthenticationData) > 0 {
+		if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
+			panic(err)
+		}
+	}
+
+	return buf.Bytes()
 }
 }