|
@@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) {
|
|
|
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
|
return nil, fmt.Errorf("set transmit queue backend: %w", err)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Fully populate the receive queue with available buffers which the device
|
|
|
|
|
|
|
+ // Fully populate the rx queue with available buffers which the device
|
|
|
// can write new packets into.
|
|
// can write new packets into.
|
|
|
if err = dev.refillReceiveQueue(); err != nil {
|
|
if err = dev.refillReceiveQueue(); err != nil {
|
|
|
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
return nil, fmt.Errorf("refill receive queue: %w", err)
|
|
@@ -198,11 +198,8 @@ func (dev *Device) Close() error {
|
|
|
// createQueue creates a new virtqueue and registers it with the vhost device
|
|
// createQueue creates a new virtqueue and registers it with the vhost device
|
|
|
// using the given index.
|
|
// using the given index.
|
|
|
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
|
func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
|
|
|
- var (
|
|
|
|
|
- queue *virtqueue.SplitQueue
|
|
|
|
|
- err error
|
|
|
|
|
- )
|
|
|
|
|
- if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
|
|
|
|
|
|
|
+ queue, err := virtqueue.NewSplitQueue(queueSize, itemSize)
|
|
|
|
|
+ if err != nil {
|
|
|
return nil, fmt.Errorf("create virtqueue: %w", err)
|
|
return nil, fmt.Errorf("create virtqueue: %w", err)
|
|
|
}
|
|
}
|
|
|
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
|
if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
|
|
@@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
|
|
|
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
|
idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
|
|
|
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
|
if err == virtqueue.ErrNotEnoughFreeDescriptors {
|
|
|
dev.fullTable = true
|
|
dev.fullTable = true
|
|
|
- idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
|
|
|
|
|
+ idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
- idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
|
|
|
|
|
|
|
+ idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
|
|
|
}
|
|
}
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
|
return 0, nil, fmt.Errorf("transmit queue: %w", err)
|
|
@@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
|
|
|
|
|
-func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
|
|
|
|
|
|
|
+// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
|
|
|
|
|
+func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
|
|
|
//read first element to see how many descriptors we need:
|
|
//read first element to see how many descriptors we need:
|
|
|
pkt.Reset()
|
|
pkt.Reset()
|
|
|
-
|
|
|
|
|
- err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
|
|
|
|
|
|
|
+ idx := uint16(chain.DescriptorIndex)
|
|
|
|
|
+ buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
|
return 0, fmt.Errorf("get descriptor chain: %w", err)
|
|
|
}
|
|
}
|
|
|
- if len(pkt.ChainRefs) == 0 {
|
|
|
|
|
- return 1, nil
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
// The specification requires that the first descriptor chain starts
|
|
// The specification requires that the first descriptor chain starts
|
|
|
// with a virtio-net header. It is not clear, whether it is also
|
|
// with a virtio-net header. It is not clear, whether it is also
|
|
@@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
|
|
// descriptor chain, but it is reasonable to assume that this is
|
|
// descriptor chain, but it is reasonable to assume that this is
|
|
|
// always the case.
|
|
// always the case.
|
|
|
// The decode method already does the buffer length check.
|
|
// The decode method already does the buffer length check.
|
|
|
- if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
|
|
|
|
|
|
|
+ if err = pkt.header.Decode(buf); err != nil {
|
|
|
// The device misbehaved. There is no way we can gracefully
|
|
// The device misbehaved. There is no way we can gracefully
|
|
|
// recover from this, because we don't know how many of the
|
|
// recover from this, because we don't know how many of the
|
|
|
// following descriptor chains belong to this packet.
|
|
// following descriptor chains belong to this packet.
|
|
@@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
//we have the header now: what do we need to do?
|
|
//we have the header now: what do we need to do?
|
|
|
- if int(pkt.Header.NumBuffers) > len(chains) {
|
|
|
|
|
- return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
|
|
|
|
|
|
|
+ if int(pkt.header.NumBuffers) > 1 {
|
|
|
|
|
+ return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
|
|
|
}
|
|
}
|
|
|
- if int(pkt.Header.NumBuffers) != 1 {
|
|
|
|
|
- return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
|
|
|
|
|
|
|
+ if int(pkt.header.NumBuffers) != 1 {
|
|
|
|
|
+ return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
|
|
|
}
|
|
}
|
|
|
- if chains[0].Length > 16000 {
|
|
|
|
|
|
|
+ if chain.Length > 16000 {
|
|
|
//todo!
|
|
//todo!
|
|
|
- return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
|
|
|
|
|
|
|
+ return 1, fmt.Errorf("too big packet length: %d", chain.Length)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
//shift the buffer out of out:
|
|
//shift the buffer out of out:
|
|
|
- pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
|
|
|
|
|
- pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
|
|
|
|
|
|
|
+ pkt.payload = buf[virtio.NetHdrSize:chain.Length]
|
|
|
|
|
+ pkt.Chains = append(pkt.Chains, idx)
|
|
|
return 1, nil
|
|
return 1, nil
|
|
|
-
|
|
|
|
|
- //cursor := n - virtio.NetHdrSize
|
|
|
|
|
- //
|
|
|
|
|
- //if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
|
|
|
|
|
- // pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
|
|
|
|
|
- // return 1, nil
|
|
|
|
|
- //}
|
|
|
|
|
- //
|
|
|
|
|
- //i := 1
|
|
|
|
|
- //// we used chain 0 already
|
|
|
|
|
- //for i = 1; i < len(chains); i++ {
|
|
|
|
|
- // n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
|
|
|
|
|
- // if err != nil {
|
|
|
|
|
- // // When this fails we may miss to free some descriptor chains. We
|
|
|
|
|
- // // could try to mitigate this by deferring the freeing somehow, but
|
|
|
|
|
- // // it's not worth the hassle. When this method fails, the queue will
|
|
|
|
|
- // // be in a broken state anyway.
|
|
|
|
|
- // return i, fmt.Errorf("get descriptor chain: %w", err)
|
|
|
|
|
- // }
|
|
|
|
|
- // cursor += n
|
|
|
|
|
- //}
|
|
|
|
|
- ////todo this has to be wrong
|
|
|
|
|
- //pkt.Payload = pkt.Payload[:cursor]
|
|
|
|
|
- //return i, nil
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
|
|
|
|
|
- //todo optimize?
|
|
|
|
|
- var chains []virtqueue.UsedElement
|
|
|
|
|
- var err error
|
|
|
|
|
|
|
+type VirtIOPacket struct {
|
|
|
|
|
+ payload []byte
|
|
|
|
|
+ header virtio.NetHdr
|
|
|
|
|
+ Chains []uint16
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return 0, err
|
|
|
|
|
- }
|
|
|
|
|
- if len(chains) == 0 {
|
|
|
|
|
- return 0, nil
|
|
|
|
|
- }
|
|
|
|
|
|
|
+func NewVIO() *VirtIOPacket {
|
|
|
|
|
+ out := new(VirtIOPacket)
|
|
|
|
|
+ out.payload = nil
|
|
|
|
|
+ out.Chains = make([]uint16, 0, 8)
|
|
|
|
|
+ return out
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- numPackets := 0
|
|
|
|
|
- chainsIdx := 0
|
|
|
|
|
- for numPackets = 0; chainsIdx < len(chains); numPackets++ {
|
|
|
|
|
- if numPackets >= len(out) {
|
|
|
|
|
- return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
|
|
|
|
|
- }
|
|
|
|
|
- numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return 0, err
|
|
|
|
|
- }
|
|
|
|
|
- chainsIdx += numChains
|
|
|
|
|
- }
|
|
|
|
|
|
|
+func (v *VirtIOPacket) Reset() {
|
|
|
|
|
+ v.payload = nil
|
|
|
|
|
+ v.Chains = v.Chains[:0]
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- return numPackets, nil
|
|
|
|
|
|
|
+func (v *VirtIOPacket) GetPayload() []byte {
|
|
|
|
|
+ return v.payload
|
|
|
|
|
+}
|
|
|
|
|
+func (v *VirtIOPacket) SetPayload(x []byte) {
|
|
|
|
|
+ v.payload = x //todo?
|
|
|
}
|
|
}
|