outside_test.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "net"
  6. "net/netip"
  7. "testing"
  8. "github.com/google/gopacket"
  9. "github.com/google/gopacket/layers"
  10. "github.com/slackhq/nebula/firewall"
  11. "github.com/stretchr/testify/assert"
  12. "golang.org/x/net/ipv4"
  13. )
  14. func Test_newPacket(t *testing.T) {
  15. p := &firewall.Packet{}
  16. // length fails
  17. err := newPacket([]byte{}, true, p)
  18. assert.ErrorIs(t, err, ErrPacketTooShort)
  19. err = newPacket([]byte{0x40}, true, p)
  20. assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
  21. err = newPacket([]byte{0x60}, true, p)
  22. assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
  23. // length fail with ip options
  24. h := ipv4.Header{
  25. Version: 1,
  26. Len: 100,
  27. Src: net.IPv4(10, 0, 0, 1),
  28. Dst: net.IPv4(10, 0, 0, 2),
  29. Options: []byte{0, 1, 0, 2},
  30. }
  31. b, _ := h.Marshal()
  32. err = newPacket(b, true, p)
  33. assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
  34. // not an ipv4 packet
  35. err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
  36. assert.ErrorIs(t, err, ErrUnknownIPVersion)
  37. // invalid ihl
  38. err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
  39. assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
  40. // account for variable ip header length - incoming
  41. h = ipv4.Header{
  42. Version: 1,
  43. Len: 100,
  44. Src: net.IPv4(10, 0, 0, 1),
  45. Dst: net.IPv4(10, 0, 0, 2),
  46. Options: []byte{0, 1, 0, 2},
  47. Protocol: firewall.ProtoTCP,
  48. }
  49. b, _ = h.Marshal()
  50. b = append(b, []byte{0, 3, 0, 4}...)
  51. err = newPacket(b, true, p)
  52. assert.Nil(t, err)
  53. assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
  54. assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
  55. assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
  56. assert.Equal(t, uint16(3), p.RemotePort)
  57. assert.Equal(t, uint16(4), p.LocalPort)
  58. assert.False(t, p.Fragment)
  59. // account for variable ip header length - outgoing
  60. h = ipv4.Header{
  61. Version: 1,
  62. Protocol: 2,
  63. Len: 100,
  64. Src: net.IPv4(10, 0, 0, 1),
  65. Dst: net.IPv4(10, 0, 0, 2),
  66. Options: []byte{0, 1, 0, 2},
  67. }
  68. b, _ = h.Marshal()
  69. b = append(b, []byte{0, 5, 0, 6}...)
  70. err = newPacket(b, false, p)
  71. assert.Nil(t, err)
  72. assert.Equal(t, uint8(2), p.Protocol)
  73. assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
  74. assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
  75. assert.Equal(t, uint16(6), p.RemotePort)
  76. assert.Equal(t, uint16(5), p.LocalPort)
  77. assert.False(t, p.Fragment)
  78. }
  79. func Test_newPacket_v6(t *testing.T) {
  80. p := &firewall.Packet{}
  81. // invalid ipv6
  82. ip := layers.IPv6{
  83. Version: 6,
  84. HopLimit: 128,
  85. SrcIP: net.IPv6linklocalallrouters,
  86. DstIP: net.IPv6linklocalallnodes,
  87. }
  88. buffer := gopacket.NewSerializeBuffer()
  89. opt := gopacket.SerializeOptions{
  90. ComputeChecksums: false,
  91. FixLengths: false,
  92. }
  93. err := gopacket.SerializeLayers(buffer, opt, &ip)
  94. assert.NoError(t, err)
  95. err = newPacket(buffer.Bytes(), true, p)
  96. assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
  97. // A good ICMP packet
  98. ip = layers.IPv6{
  99. Version: 6,
  100. NextHeader: layers.IPProtocolICMPv6,
  101. HopLimit: 128,
  102. SrcIP: net.IPv6linklocalallrouters,
  103. DstIP: net.IPv6linklocalallnodes,
  104. }
  105. icmp := layers.ICMPv6{}
  106. buffer.Clear()
  107. err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
  108. if err != nil {
  109. panic(err)
  110. }
  111. err = newPacket(buffer.Bytes(), true, p)
  112. assert.Nil(t, err)
  113. assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
  114. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  115. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  116. assert.Equal(t, uint16(0), p.RemotePort)
  117. assert.Equal(t, uint16(0), p.LocalPort)
  118. assert.False(t, p.Fragment)
  119. // A good ESP packet
  120. b := buffer.Bytes()
  121. b[6] = byte(layers.IPProtocolESP)
  122. err = newPacket(b, true, p)
  123. assert.Nil(t, err)
  124. assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
  125. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  126. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  127. assert.Equal(t, uint16(0), p.RemotePort)
  128. assert.Equal(t, uint16(0), p.LocalPort)
  129. assert.False(t, p.Fragment)
  130. // A good None packet
  131. b = buffer.Bytes()
  132. b[6] = byte(layers.IPProtocolNoNextHeader)
  133. err = newPacket(b, true, p)
  134. assert.Nil(t, err)
  135. assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
  136. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  137. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  138. assert.Equal(t, uint16(0), p.RemotePort)
  139. assert.Equal(t, uint16(0), p.LocalPort)
  140. assert.False(t, p.Fragment)
  141. // An unknown protocol packet
  142. b = buffer.Bytes()
  143. b[6] = 255 // 255 is a reserved protocol number
  144. err = newPacket(b, true, p)
  145. assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
  146. // A good UDP packet
  147. ip = layers.IPv6{
  148. Version: 6,
  149. NextHeader: firewall.ProtoUDP,
  150. HopLimit: 128,
  151. SrcIP: net.IPv6linklocalallrouters,
  152. DstIP: net.IPv6linklocalallnodes,
  153. }
  154. udp := layers.UDP{
  155. SrcPort: layers.UDPPort(36123),
  156. DstPort: layers.UDPPort(22),
  157. }
  158. err = udp.SetNetworkLayerForChecksum(&ip)
  159. assert.NoError(t, err)
  160. buffer.Clear()
  161. err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
  162. if err != nil {
  163. panic(err)
  164. }
  165. b = buffer.Bytes()
  166. // incoming
  167. err = newPacket(b, true, p)
  168. assert.Nil(t, err)
  169. assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
  170. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  171. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  172. assert.Equal(t, uint16(36123), p.RemotePort)
  173. assert.Equal(t, uint16(22), p.LocalPort)
  174. assert.False(t, p.Fragment)
  175. // outgoing
  176. err = newPacket(b, false, p)
  177. assert.Nil(t, err)
  178. assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
  179. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
  180. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
  181. assert.Equal(t, uint16(36123), p.LocalPort)
  182. assert.Equal(t, uint16(22), p.RemotePort)
  183. assert.False(t, p.Fragment)
  184. // Too short UDP packet
  185. err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
  186. assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
  187. // A good TCP packet
  188. b[6] = byte(layers.IPProtocolTCP)
  189. // incoming
  190. err = newPacket(b, true, p)
  191. assert.Nil(t, err)
  192. assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
  193. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  194. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  195. assert.Equal(t, uint16(36123), p.RemotePort)
  196. assert.Equal(t, uint16(22), p.LocalPort)
  197. assert.False(t, p.Fragment)
  198. // outgoing
  199. err = newPacket(b, false, p)
  200. assert.Nil(t, err)
  201. assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
  202. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
  203. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
  204. assert.Equal(t, uint16(36123), p.LocalPort)
  205. assert.Equal(t, uint16(22), p.RemotePort)
  206. assert.False(t, p.Fragment)
  207. // Too short TCP packet
  208. err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
  209. assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
  210. // A good UDP packet with an AH header
  211. ip = layers.IPv6{
  212. Version: 6,
  213. NextHeader: layers.IPProtocolAH,
  214. HopLimit: 128,
  215. SrcIP: net.IPv6linklocalallrouters,
  216. DstIP: net.IPv6linklocalallnodes,
  217. }
  218. ah := layers.IPSecAH{
  219. AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
  220. }
  221. ah.NextHeader = layers.IPProtocolUDP
  222. udpHeader := []byte{
  223. 0x8d, 0x1b, // Source port 36123
  224. 0x00, 0x16, // Destination port 22
  225. 0x00, 0x00, // Length
  226. 0x00, 0x00, // Checksum
  227. }
  228. buffer.Clear()
  229. err = ip.SerializeTo(buffer, opt)
  230. if err != nil {
  231. panic(err)
  232. }
  233. b = buffer.Bytes()
  234. ahb := serializeAH(&ah)
  235. b = append(b, ahb...)
  236. b = append(b, udpHeader...)
  237. err = newPacket(b, true, p)
  238. assert.Nil(t, err)
  239. assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
  240. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  241. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  242. assert.Equal(t, uint16(36123), p.RemotePort)
  243. assert.Equal(t, uint16(22), p.LocalPort)
  244. assert.False(t, p.Fragment)
  245. // Invalid AH header
  246. b = buffer.Bytes()
  247. err = newPacket(b, true, p)
  248. assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
  249. }
  250. func Test_newPacket_ipv6Fragment(t *testing.T) {
  251. p := &firewall.Packet{}
  252. ip := &layers.IPv6{
  253. Version: 6,
  254. NextHeader: layers.IPProtocolIPv6Fragment,
  255. HopLimit: 64,
  256. SrcIP: net.IPv6linklocalallrouters,
  257. DstIP: net.IPv6linklocalallnodes,
  258. }
  259. // First fragment
  260. fragHeader1 := []byte{
  261. uint8(layers.IPProtocolUDP), // Next Header (UDP)
  262. 0x00, // Reserved
  263. 0x00, // Fragment Offset high byte (0)
  264. 0x01, // Fragment Offset low byte & flags (M=1)
  265. 0x00, 0x00, 0x00, 0x01, // Identification
  266. }
  267. udpHeader := []byte{
  268. 0x8d, 0x1b, // Source port 36123
  269. 0x00, 0x16, // Destination port 22
  270. 0x00, 0x00, // Length
  271. 0x00, 0x00, // Checksum
  272. }
  273. buffer := gopacket.NewSerializeBuffer()
  274. opts := gopacket.SerializeOptions{
  275. ComputeChecksums: true,
  276. FixLengths: true,
  277. }
  278. err := ip.SerializeTo(buffer, opts)
  279. if err != nil {
  280. t.Fatal(err)
  281. }
  282. firstFrag := buffer.Bytes()
  283. firstFrag = append(firstFrag, fragHeader1...)
  284. firstFrag = append(firstFrag, udpHeader...)
  285. firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
  286. // Test first fragment incoming
  287. err = newPacket(firstFrag, true, p)
  288. assert.NoError(t, err)
  289. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  290. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  291. assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
  292. assert.Equal(t, uint16(36123), p.RemotePort)
  293. assert.Equal(t, uint16(22), p.LocalPort)
  294. assert.False(t, p.Fragment)
  295. // Test first fragment outgoing
  296. err = newPacket(firstFrag, false, p)
  297. assert.NoError(t, err)
  298. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
  299. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
  300. assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
  301. assert.Equal(t, uint16(36123), p.LocalPort)
  302. assert.Equal(t, uint16(22), p.RemotePort)
  303. assert.False(t, p.Fragment)
  304. // Second fragment
  305. fragHeader2 := []byte{
  306. uint8(layers.IPProtocolUDP), // Next Header (UDP)
  307. 0x00, // Reserved
  308. 0xb9, // Fragment Offset high byte (185)
  309. 0x01, // Fragment Offset low byte & flags (M=1)
  310. 0x00, 0x00, 0x00, 0x01, // Identification
  311. }
  312. buffer.Clear()
  313. err = ip.SerializeTo(buffer, opts)
  314. if err != nil {
  315. t.Fatal(err)
  316. }
  317. secondFrag := buffer.Bytes()
  318. secondFrag = append(secondFrag, fragHeader2...)
  319. secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
  320. // Test second fragment incoming
  321. err = newPacket(secondFrag, true, p)
  322. assert.NoError(t, err)
  323. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
  324. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
  325. assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
  326. assert.Equal(t, uint16(0), p.RemotePort)
  327. assert.Equal(t, uint16(0), p.LocalPort)
  328. assert.True(t, p.Fragment)
  329. // Test second fragment outgoing
  330. err = newPacket(secondFrag, false, p)
  331. assert.NoError(t, err)
  332. assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
  333. assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
  334. assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
  335. assert.Equal(t, uint16(0), p.LocalPort)
  336. assert.Equal(t, uint16(0), p.RemotePort)
  337. assert.True(t, p.Fragment)
  338. // Too short of a fragment packet
  339. err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
  340. assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
  341. }
  342. func BenchmarkParseV6(b *testing.B) {
  343. // Regular UDP packet
  344. ip := &layers.IPv6{
  345. Version: 6,
  346. NextHeader: layers.IPProtocolUDP,
  347. HopLimit: 64,
  348. SrcIP: net.IPv6linklocalallrouters,
  349. DstIP: net.IPv6linklocalallnodes,
  350. }
  351. udp := &layers.UDP{
  352. SrcPort: layers.UDPPort(36123),
  353. DstPort: layers.UDPPort(22),
  354. }
  355. buffer := gopacket.NewSerializeBuffer()
  356. opts := gopacket.SerializeOptions{
  357. ComputeChecksums: false,
  358. FixLengths: true,
  359. }
  360. err := gopacket.SerializeLayers(buffer, opts, ip, udp)
  361. if err != nil {
  362. b.Fatal(err)
  363. }
  364. normalPacket := buffer.Bytes()
  365. // First Fragment packet
  366. ipFrag := &layers.IPv6{
  367. Version: 6,
  368. NextHeader: layers.IPProtocolIPv6Fragment,
  369. HopLimit: 64,
  370. SrcIP: net.IPv6linklocalallrouters,
  371. DstIP: net.IPv6linklocalallnodes,
  372. }
  373. fragHeader := []byte{
  374. uint8(layers.IPProtocolUDP), // Next Header (UDP)
  375. 0x00, // Reserved
  376. 0x00, // Fragment Offset high byte (0)
  377. 0x01, // Fragment Offset low byte & flags (M=1)
  378. 0x00, 0x00, 0x00, 0x01, // Identification
  379. }
  380. udpHeader := []byte{
  381. 0x8d, 0x7b, // Source port 36123
  382. 0x00, 0x16, // Destination port 22
  383. 0x00, 0x00, // Length
  384. 0x00, 0x00, // Checksum
  385. }
  386. buffer.Clear()
  387. err = ipFrag.SerializeTo(buffer, opts)
  388. if err != nil {
  389. b.Fatal(err)
  390. }
  391. firstFrag := buffer.Bytes()
  392. firstFrag = append(firstFrag, fragHeader...)
  393. firstFrag = append(firstFrag, udpHeader...)
  394. firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
  395. // Second Fragment packet
  396. fragHeader[2] = 0xb9 // offset 185
  397. buffer.Clear()
  398. err = ipFrag.SerializeTo(buffer, opts)
  399. if err != nil {
  400. b.Fatal(err)
  401. }
  402. secondFrag := buffer.Bytes()
  403. secondFrag = append(secondFrag, fragHeader...)
  404. secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
  405. fp := &firewall.Packet{}
  406. b.Run("Normal", func(b *testing.B) {
  407. for i := 0; i < b.N; i++ {
  408. if err = parseV6(normalPacket, true, fp); err != nil {
  409. b.Fatal(err)
  410. }
  411. }
  412. })
  413. b.Run("FirstFragment", func(b *testing.B) {
  414. for i := 0; i < b.N; i++ {
  415. if err = parseV6(firstFrag, true, fp); err != nil {
  416. b.Fatal(err)
  417. }
  418. }
  419. })
  420. b.Run("SecondFragment", func(b *testing.B) {
  421. for i := 0; i < b.N; i++ {
  422. if err = parseV6(secondFrag, true, fp); err != nil {
  423. b.Fatal(err)
  424. }
  425. }
  426. })
  427. // Evil packet
  428. evilPacket := &layers.IPv6{
  429. Version: 6,
  430. NextHeader: layers.IPProtocolIPv6HopByHop,
  431. HopLimit: 64,
  432. SrcIP: net.IPv6linklocalallrouters,
  433. DstIP: net.IPv6linklocalallnodes,
  434. }
  435. hopHeader := []byte{
  436. uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
  437. 0x00, // Length
  438. 0x00, 0x00, // Options and padding
  439. 0x00, 0x00, 0x00, 0x00, // More options and padding
  440. }
  441. lastHopHeader := []byte{
  442. uint8(layers.IPProtocolUDP), // Next Header (UDP)
  443. 0x00, // Length
  444. 0x00, 0x00, // Options and padding
  445. 0x00, 0x00, 0x00, 0x00, // More options and padding
  446. }
  447. buffer.Clear()
  448. err = evilPacket.SerializeTo(buffer, opts)
  449. if err != nil {
  450. b.Fatal(err)
  451. }
  452. evilBytes := buffer.Bytes()
  453. for i := 0; i < 200; i++ {
  454. evilBytes = append(evilBytes, hopHeader...)
  455. }
  456. evilBytes = append(evilBytes, lastHopHeader...)
  457. evilBytes = append(evilBytes, udpHeader...)
  458. evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
  459. b.Run("200 HopByHop headers", func(b *testing.B) {
  460. for i := 0; i < b.N; i++ {
  461. if err = parseV6(evilBytes, false, fp); err != nil {
  462. b.Fatal(err)
  463. }
  464. }
  465. })
  466. }
  // padAuthData returns authData extended with zero bytes so that its length
  // is a multiple of 8. Input that is already aligned, including empty or nil
  // input, is returned as-is. When padding is needed the result comes from
  // append, so it may share authData's backing array.
  func padAuthData(authData []byte) []byte {
	rem := len(authData) % 8
	if rem == 0 {
		return authData
	}
	return append(authData, make([]byte, 8-rem)...)
  }
  476. // Custom function to manually serialize IPSecAH for both IPv4 and IPv6
  477. func serializeAH(ah *layers.IPSecAH) []byte {
  478. buf := new(bytes.Buffer)
  479. // Ensure Authentication Data is a multiple of 8 bytes
  480. ah.AuthenticationData = padAuthData(ah.AuthenticationData)
  481. // Calculate Payload Length (in 32-bit words, minus 2)
  482. payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
  483. // Serialize fields
  484. if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
  485. panic(err)
  486. }
  487. if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
  488. panic(err)
  489. }
  490. if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
  491. panic(err)
  492. }
  493. if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
  494. panic(err)
  495. }
  496. if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
  497. panic(err)
  498. }
  499. if len(ah.AuthenticationData) > 0 {
  500. if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
  501. panic(err)
  502. }
  503. }
  504. return buf.Bytes()
  505. }