split_virtqueue.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package virtqueue
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "github.com/slackhq/nebula/overlay/eventfd"
  8. "golang.org/x/sys/unix"
  9. )
// SplitQueue is a virtqueue that consists of several parts, where each part is
// writeable by either the driver or the device, but not both.
type SplitQueue struct {
	// size is the size of the queue.
	size int
	// buf is the underlying memory used for the queue.
	buf []byte
	// descriptorTable, availableRing and usedRing are views into disjoint
	// regions of buf (see NewSplitQueue for the layout and alignment).
	descriptorTable *DescriptorTable
	availableRing   *AvailableRing
	usedRing        *UsedRing
	// kickEventFD is used to signal the device when descriptor chains were
	// added to the available ring.
	kickEventFD eventfd.EventFD
	// callEventFD is used by the device to signal when it has used descriptor
	// chains and put them in the used ring.
	callEventFD eventfd.EventFD
	// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
	// used buffer notifications. It blocks until the goroutine ended.
	stop func() error
	// itemSize is the buffer size per descriptor; NewSplitQueue enforces it
	// to be a multiple of the system page size.
	itemSize int
	// epoll waits for readiness signals on callEventFD.
	epoll eventfd.Epoll
	// more caches the number of used elements that were left untaken by the
	// last capped take (see BlockAndGetHeadsCapped).
	more int
}
  33. // NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
  34. // specifies the number of entries/buffers the queue can hold. This also affects
  35. // the memory consumption.
  36. func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
  37. if err = CheckQueueSize(queueSize); err != nil {
  38. return nil, err
  39. }
  40. if itemSize%os.Getpagesize() != 0 {
  41. return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
  42. }
  43. sq := SplitQueue{
  44. size: queueSize,
  45. itemSize: itemSize,
  46. }
  47. // Clean up a partially initialized queue when something fails.
  48. defer func() {
  49. if err != nil {
  50. _ = sq.Close()
  51. }
  52. }()
  53. // There are multiple ways for how the memory for the virtqueue could be
  54. // allocated. We could use Go native structs with arrays inside them, but
  55. // this wouldn't allow us to make the queue size configurable. And including
  56. // a slice in the Go structs wouldn't work, because this would just put the
  57. // Go slice descriptor into the memory region which the virtio device will
  58. // not understand.
  59. // Additionally, Go does not allow us to ensure a correct alignment of the
  60. // parts of the virtqueue, as it is required by the virtio specification.
  61. //
  62. // To resolve this, let's just allocate the memory manually by allocating
  63. // one or more memory pages, depending on the queue size. Making the
  64. // virtqueue start at the beginning of a page is not strictly necessary, as
  65. // the virtio specification does not require it to be continuous in the
  66. // physical memory of the host (e.g. the vhost implementation in the kernel
  67. // always uses copy_from_user to access it), but this makes it very easy to
  68. // guarantee the alignment. Also, it is not required for the virtqueue parts
  69. // to be in the same memory region, as we pass separate pointers to them to
  70. // the device, but this design just makes things easier to implement.
  71. //
  72. // One added benefit of allocating the memory manually is, that we have full
  73. // control over its lifetime and don't risk the garbage collector to collect
  74. // our valuable structures while the device still works with them.
  75. // The descriptor table is at the start of the page, so alignment is not an
  76. // issue here.
  77. descriptorTableStart := 0
  78. descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
  79. availableRingStart := align(descriptorTableEnd, availableRingAlignment)
  80. availableRingEnd := availableRingStart + availableRingSize(queueSize)
  81. usedRingStart := align(availableRingEnd, usedRingAlignment)
  82. usedRingEnd := usedRingStart + usedRingSize(queueSize)
  83. sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
  84. unix.PROT_READ|unix.PROT_WRITE,
  85. unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
  86. if err != nil {
  87. return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
  88. }
  89. sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
  90. sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
  91. sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
  92. sq.kickEventFD, err = eventfd.New()
  93. if err != nil {
  94. return nil, fmt.Errorf("create kick event file descriptor: %w", err)
  95. }
  96. sq.callEventFD, err = eventfd.New()
  97. if err != nil {
  98. return nil, fmt.Errorf("create call event file descriptor: %w", err)
  99. }
  100. if err = sq.descriptorTable.initializeDescriptors(); err != nil {
  101. return nil, fmt.Errorf("initialize descriptors: %w", err)
  102. }
  103. sq.epoll, err = eventfd.NewEpoll()
  104. if err != nil {
  105. return nil, err
  106. }
  107. err = sq.epoll.AddEvent(sq.callEventFD.FD())
  108. if err != nil {
  109. return nil, err
  110. }
  111. sq.stop = sq.kickSelfToExit()
  112. return &sq, nil
  113. }
// Size returns the size of this queue, which is the number of entries/buffers
// this queue can hold.
func (sq *SplitQueue) Size() int {
	return sq.size
}
// DescriptorTable returns the [DescriptorTable] behind this queue.
func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
	return sq.descriptorTable
}
// AvailableRing returns the [AvailableRing] behind this queue.
func (sq *SplitQueue) AvailableRing() *AvailableRing {
	return sq.availableRing
}
// UsedRing returns the [UsedRing] behind this queue.
func (sq *SplitQueue) UsedRing() *UsedRing {
	return sq.usedRing
}
// KickEventFD returns the kick event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) KickEventFD() int {
	return sq.kickEventFD.FD()
}
// CallEventFD returns the call event file descriptor behind this queue.
// The returned file descriptor should be used with great care to not interfere
// with this implementation.
func (sq *SplitQueue) CallEventFD() int {
	return sq.callEventFD.FD()
}
  143. func (sq *SplitQueue) kickSelfToExit() func() error {
  144. return func() error {
  145. // The goroutine blocks until it receives a signal on the event file
  146. // descriptor, so it will never notice the context being canceled.
  147. // To resolve this, we can just produce a fake-signal ourselves to wake
  148. // it up.
  149. if err := sq.callEventFD.Kick(); err != nil {
  150. return fmt.Errorf("wake up goroutine: %w", err)
  151. }
  152. return nil
  153. }
  154. }
  155. func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
  156. element, err := sq.TakeSingle(ctx)
  157. if err != nil {
  158. return 0xffff, err
  159. }
  160. return element.GetHead(), nil
  161. }
  162. func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
  163. var n int
  164. var err error
  165. for ctx.Err() == nil {
  166. out, ok := sq.usedRing.takeOne()
  167. if ok {
  168. return out, nil
  169. }
  170. // Wait for a signal from the device.
  171. if n, err = sq.epoll.Block(); err != nil {
  172. return UsedElement{}, fmt.Errorf("wait: %w", err)
  173. }
  174. if n > 0 {
  175. out, ok = sq.usedRing.takeOne()
  176. if ok {
  177. _ = sq.epoll.Clear() //???
  178. return out, nil
  179. } else {
  180. continue //???
  181. }
  182. }
  183. }
  184. return UsedElement{}, ctx.Err()
  185. }
// TakeSingleNoBlock tries to take one element from the used ring without
// blocking. The boolean reports whether an element was available.
func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
	return sq.usedRing.takeOne()
}
  189. func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
  190. if sq.usedRing.availableToTake() != 0 {
  191. return nil
  192. }
  193. for ctx.Err() == nil {
  194. // Wait for a signal from the device.
  195. n, err := sq.epoll.Block()
  196. if err != nil {
  197. return fmt.Errorf("wait: %w", err)
  198. }
  199. if n > 0 {
  200. _ = sq.epoll.Clear()
  201. if sq.usedRing.availableToTake() != 0 {
  202. return nil
  203. }
  204. }
  205. }
  206. return ctx.Err()
  207. }
// BlockAndGetHeadsCapped takes up to maxToTake used elements from the used
// ring, blocking on epoll until the device signals that elements are available
// or ctx is canceled. The number of elements left untaken by a capped take is
// remembered in sq.more, so the next call can fast-path without blocking.
//
// NOTE(review): sq.more is read and written without synchronization here —
// presumably this method is only called from a single goroutine; confirm
// against callers.
func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
	var n int
	var err error
	for ctx.Err() == nil {
		// Fast path: the previous capped take left elements behind, so the
		// ring is known to be non-empty — take without blocking.
		if sq.more > 0 {
			stillNeedToTake, out := sq.usedRing.take(maxToTake)
			sq.more = stillNeedToTake
			return out, nil
		}
		// Opportunistic take: the ring may have filled since the last call.
		stillNeedToTake, out := sq.usedRing.take(maxToTake)
		if len(out) > 0 {
			sq.more = stillNeedToTake
			return out, nil
		}
		// Nothing available.
		// Wait for a signal from the device.
		if n, err = sq.epoll.Block(); err != nil {
			return nil, fmt.Errorf("wait: %w", err)
		}
		if n > 0 {
			// Drain the event, then return whatever is available now. Note
			// that out may be empty here if the wakeup was spurious.
			_ = sq.epoll.Clear()
			stillNeedToTake, out = sq.usedRing.take(maxToTake)
			sq.more = stillNeedToTake
			return out, nil
		}
	}
	return nil, ctx.Err()
}
  238. // OfferDescriptorChain offers a descriptor chain to the device which contains a
  239. // number of device-readable buffers (out buffers) and device-writable buffers
  240. // (in buffers).
  241. //
  242. // All buffers in the outBuffers slice will be concatenated by chaining
  243. // descriptors, one for each buffer in the slice. When a buffer is too large to
  244. // fit into a single descriptor (limited by the system's page size), it will be
  245. // split up into multiple descriptors within the chain.
  246. // When numInBuffers is greater than zero, the given number of device-writable
  247. // descriptors will be appended to the end of the chain, each referencing a
  248. // whole memory page (see [os.Getpagesize]).
  249. //
  250. // When the queue is full and no more descriptor chains can be added, a wrapped
  251. // [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true,
  252. // this method will handle this error and will block instead until there are
  253. // enough free descriptors again.
  254. //
  255. // After defining the descriptor chain in the [DescriptorTable], the index of
  256. // the head of the chain will be made available to the device using the
  257. // [AvailableRing] and will be returned by this method.
  258. // Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be
  259. // notified when the descriptor chain was used by the device and should free the
  260. // used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
  261. // they're done with them. When this does not happen, the queue will run full
  262. // and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
  263. func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
  264. var (
  265. head uint16
  266. err error
  267. )
  268. for {
  269. head, err = sq.descriptorTable.createDescriptorForInputs()
  270. if err == nil {
  271. break
  272. }
  273. // I don't wanna use errors.Is, it's slow
  274. //goland:noinspection GoDirectComparisonOfErrors
  275. if err == ErrNotEnoughFreeDescriptors {
  276. return 0, err
  277. } else {
  278. return 0, fmt.Errorf("create descriptor chain: %w", err)
  279. }
  280. }
  281. // Make the descriptor chain available to the device.
  282. sq.availableRing.offerSingle(head)
  283. // Notify the device to make it process the updated available ring.
  284. if err = sq.kickEventFD.Kick(); err != nil {
  285. return head, fmt.Errorf("notify device: %w", err)
  286. }
  287. return head, nil
  288. }
  289. // GetDescriptorItem returns the buffer of a given index
  290. // The head index must be one that was returned by a previous call to
  291. // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
  292. // freed yet.
  293. //
  294. // Be careful to only access the returned buffer slices when the device is no
  295. // longer using them. They must not be accessed after
  296. // [SplitQueue.FreeDescriptorChain] has been called.
  297. func (sq *SplitQueue) GetDescriptorItem(head uint16) []byte {
  298. sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
  299. return sq.descriptorTable.getDescriptorItem(head)
  300. }
  301. func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
  302. //not called under lock
  303. sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
  304. }
  305. func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
  306. // Make the descriptor chain available to the device.
  307. sq.availableRing.offer(chains)
  308. // Notify the device to make it process the updated available ring.
  309. if kick {
  310. return sq.Kick()
  311. }
  312. return nil
  313. }
  314. func (sq *SplitQueue) Kick() error {
  315. if err := sq.kickEventFD.Kick(); err != nil {
  316. return fmt.Errorf("notify device: %w", err)
  317. }
  318. return nil
  319. }
  320. // Close releases all resources used for this queue.
  321. // The implementation will try to release as many resources as possible and
  322. // collect potential errors before returning them.
  323. func (sq *SplitQueue) Close() error {
  324. var errs []error
  325. if sq.stop != nil {
  326. // This has to happen before the event file descriptors may be closed.
  327. if err := sq.stop(); err != nil {
  328. errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
  329. }
  330. // Make sure that this code block is executed only once.
  331. sq.stop = nil
  332. }
  333. if err := sq.kickEventFD.Close(); err != nil {
  334. errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
  335. }
  336. if err := sq.callEventFD.Close(); err != nil {
  337. errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
  338. }
  339. if err := sq.descriptorTable.releaseBuffers(); err != nil {
  340. errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
  341. }
  342. if sq.buf != nil {
  343. if err := unix.Munmap(sq.buf); err == nil {
  344. sq.buf = nil
  345. } else {
  346. errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
  347. }
  348. }
  349. return errors.Join(errs...)
  350. }
  351. func align(index, alignment int) int {
  352. remainder := index % alignment
  353. if remainder == 0 {
  354. return index
  355. }
  356. return index + alignment - remainder
  357. }