wireguard_tun_linux.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. //go:build linux && !android && !e2e_testing
  2. package overlay
  3. import (
  4. "fmt"
  5. "sync"
  6. wgtun "github.com/slackhq/nebula/wgstack/tun"
  7. )
  8. type wireguardTunIO struct {
  9. dev wgtun.Device
  10. mtu int
  11. batchSize int
  12. readMu sync.Mutex
  13. readBuffers [][]byte
  14. readLens []int
  15. legacyBuf []byte
  16. writeMu sync.Mutex
  17. writeBuf []byte
  18. writeWrap [][]byte
  19. writeBuffers [][]byte
  20. }
  21. func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
  22. batch := dev.BatchSize()
  23. if batch <= 0 {
  24. batch = 1
  25. }
  26. if mtu <= 0 {
  27. mtu = DefaultMTU
  28. }
  29. return &wireguardTunIO{
  30. dev: dev,
  31. mtu: mtu,
  32. batchSize: batch,
  33. readLens: make([]int, batch),
  34. legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
  35. writeBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
  36. writeWrap: make([][]byte, 1),
  37. }
  38. }
  39. func (w *wireguardTunIO) Read(p []byte) (int, error) {
  40. w.readMu.Lock()
  41. defer w.readMu.Unlock()
  42. bufs := w.readBuffers
  43. if len(bufs) == 0 {
  44. bufs = [][]byte{w.legacyBuf}
  45. w.readBuffers = bufs
  46. }
  47. n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
  48. if err != nil {
  49. return 0, err
  50. }
  51. if n == 0 {
  52. return 0, nil
  53. }
  54. length := w.readLens[0]
  55. copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
  56. return length, nil
  57. }
  58. func (w *wireguardTunIO) Write(p []byte) (int, error) {
  59. if len(p) > w.mtu {
  60. return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
  61. }
  62. w.writeMu.Lock()
  63. defer w.writeMu.Unlock()
  64. buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
  65. for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
  66. buf[i] = 0
  67. }
  68. copy(buf[wgtun.VirtioNetHdrLen:], p)
  69. w.writeWrap[0] = buf
  70. n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
  71. if err != nil {
  72. return n, err
  73. }
  74. return len(p), nil
  75. }
  76. func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
  77. if pool == nil {
  78. return nil, fmt.Errorf("wireguard tun: packet pool is nil")
  79. }
  80. w.readMu.Lock()
  81. defer w.readMu.Unlock()
  82. if len(w.readBuffers) < w.batchSize {
  83. w.readBuffers = make([][]byte, w.batchSize)
  84. }
  85. if len(w.readLens) < w.batchSize {
  86. w.readLens = make([]int, w.batchSize)
  87. }
  88. packets := make([]*Packet, w.batchSize)
  89. requiredHeadroom := w.BatchHeadroom()
  90. requiredPayload := w.BatchPayloadCap()
  91. headroom := 0
  92. for i := 0; i < w.batchSize; i++ {
  93. pkt := pool.Get()
  94. if pkt == nil {
  95. releasePackets(packets[:i])
  96. return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
  97. }
  98. if pkt.Capacity() < requiredPayload {
  99. pkt.Release()
  100. releasePackets(packets[:i])
  101. return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
  102. }
  103. if i == 0 {
  104. headroom = pkt.Offset
  105. if headroom < requiredHeadroom {
  106. pkt.Release()
  107. releasePackets(packets[:i])
  108. return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
  109. }
  110. } else if pkt.Offset != headroom {
  111. pkt.Release()
  112. releasePackets(packets[:i])
  113. return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
  114. }
  115. packets[i] = pkt
  116. w.readBuffers[i] = pkt.Buf
  117. }
  118. n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
  119. if err != nil {
  120. releasePackets(packets)
  121. return nil, err
  122. }
  123. if n == 0 {
  124. releasePackets(packets)
  125. return nil, nil
  126. }
  127. for i := 0; i < n; i++ {
  128. packets[i].Len = w.readLens[i]
  129. }
  130. for i := n; i < w.batchSize; i++ {
  131. packets[i].Release()
  132. packets[i] = nil
  133. }
  134. return packets[:n], nil
  135. }
  136. func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
  137. if len(packets) == 0 {
  138. return 0, nil
  139. }
  140. requiredHeadroom := w.BatchHeadroom()
  141. offset := packets[0].Offset
  142. if offset < requiredHeadroom {
  143. releasePackets(packets)
  144. return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
  145. }
  146. for _, pkt := range packets {
  147. if pkt == nil {
  148. continue
  149. }
  150. if pkt.Offset != offset {
  151. releasePackets(packets)
  152. return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
  153. }
  154. limit := pkt.Offset + pkt.Len
  155. if limit > len(pkt.Buf) {
  156. releasePackets(packets)
  157. return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
  158. }
  159. }
  160. w.writeMu.Lock()
  161. defer w.writeMu.Unlock()
  162. if len(w.writeBuffers) < len(packets) {
  163. w.writeBuffers = make([][]byte, len(packets))
  164. }
  165. for i, pkt := range packets {
  166. if pkt == nil {
  167. w.writeBuffers[i] = nil
  168. continue
  169. }
  170. limit := pkt.Offset + pkt.Len
  171. w.writeBuffers[i] = pkt.Buf[:limit]
  172. }
  173. n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
  174. if err != nil {
  175. return n, err
  176. }
  177. releasePackets(packets)
  178. return n, nil
  179. }
  180. func (w *wireguardTunIO) BatchHeadroom() int {
  181. return wgtun.VirtioNetHdrLen
  182. }
  183. func (w *wireguardTunIO) BatchPayloadCap() int {
  184. return w.mtu
  185. }
  186. func (w *wireguardTunIO) BatchSize() int {
  187. return w.batchSize
  188. }
  189. func (w *wireguardTunIO) Close() error {
  190. return nil
  191. }
  192. func releasePackets(pkts []*Packet) {
  193. for _, pkt := range pkts {
  194. if pkt != nil {
  195. pkt.Release()
  196. }
  197. }
  198. }