| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664 |
- //go:build linux
- // SPDX-License-Identifier: MIT
- //
- // Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- package tun
- /* Implementation of the TUN device interface for linux
- */
- import (
- "errors"
- "fmt"
- "os"
- "sync"
- "syscall"
- "time"
- "unsafe"
- wgconn "github.com/slackhq/nebula/wgstack/conn"
- "golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/rwcancel"
- )
- const (
- cloneDevicePath = "/dev/net/tun"
- ifReqSize = unix.IFNAMSIZ + 64
- )
- type NativeTun struct {
- tunFile *os.File
- index int32 // if index
- errors chan error // async error handling
- events chan Event // device related events
- netlinkSock int
- netlinkCancel *rwcancel.RWCancel
- hackListenerClosed sync.Mutex
- statusListenersShutdown chan struct{}
- batchSize int
- vnetHdr bool
- closeOnce sync.Once
- nameOnce sync.Once // guards calling initNameCache, which sets following fields
- nameCache string // name of interface
- nameErr error
- readOpMu sync.Mutex // readOpMu guards readBuff
- readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
- writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
- toWrite []int
- tcp4GROTable, tcp6GROTable *tcpGROTable
- }
- func (tun *NativeTun) File() *os.File {
- return tun.tunFile
- }
- func (tun *NativeTun) routineHackListener() {
- defer tun.hackListenerClosed.Unlock()
- /* This is needed for the detection to work across network namespaces
- * If you are reading this and know a better method, please get in touch.
- */
- last := 0
- const (
- up = 1
- down = 2
- )
- for {
- sysconn, err := tun.tunFile.SyscallConn()
- if err != nil {
- return
- }
- err2 := sysconn.Control(func(fd uintptr) {
- _, err = unix.Write(int(fd), nil)
- })
- if err2 != nil {
- return
- }
- switch err {
- case unix.EINVAL:
- if last != up {
- // If the tunnel is up, it reports that write() is
- // allowed but we provided invalid data.
- tun.events <- EventUp
- last = up
- }
- case unix.EIO:
- if last != down {
- // If the tunnel is down, it reports that no I/O
- // is possible, without checking our provided data.
- tun.events <- EventDown
- last = down
- }
- default:
- return
- }
- select {
- case <-time.After(time.Second):
- // nothing
- case <-tun.statusListenersShutdown:
- return
- }
- }
- }
- func createNetlinkSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
- if err != nil {
- return -1, err
- }
- saddr := &unix.SockaddrNetlink{
- Family: unix.AF_NETLINK,
- Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
- }
- err = unix.Bind(sock, saddr)
- if err != nil {
- return -1, err
- }
- return sock, nil
- }
- func (tun *NativeTun) routineNetlinkListener() {
- defer func() {
- unix.Close(tun.netlinkSock)
- tun.hackListenerClosed.Lock()
- close(tun.events)
- tun.netlinkCancel.Close()
- }()
- for msg := make([]byte, 1<<16); ; {
- var err error
- var msgn int
- for {
- msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
- if err == nil || !rwcancel.RetryAfterError(err) {
- break
- }
- if !tun.netlinkCancel.ReadyRead() {
- tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
- return
- }
- }
- if err != nil {
- tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
- return
- }
- select {
- case <-tun.statusListenersShutdown:
- return
- default:
- }
- wasEverUp := false
- for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
- hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
- if int(hdr.Len) > len(remain) {
- break
- }
- switch hdr.Type {
- case unix.NLMSG_DONE:
- remain = []byte{}
- case unix.RTM_NEWLINK:
- info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
- remain = remain[hdr.Len:]
- if info.Index != tun.index {
- // not our interface
- continue
- }
- if info.Flags&unix.IFF_RUNNING != 0 {
- tun.events <- EventUp
- wasEverUp = true
- }
- if info.Flags&unix.IFF_RUNNING == 0 {
- // Don't emit EventDown before we've ever emitted EventUp.
- // This avoids a startup race with HackListener, which
- // might detect Up before we have finished reporting Down.
- if wasEverUp {
- tun.events <- EventDown
- }
- }
- tun.events <- EventMTUUpdate
- default:
- remain = remain[hdr.Len:]
- }
- }
- }
- }
- func getIFIndex(name string) (int32, error) {
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return 0, err
- }
- defer unix.Close(fd)
- var ifr [ifReqSize]byte
- copy(ifr[:], name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFINDEX),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return 0, errno
- }
- return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
- }
- func (tun *NativeTun) setMTU(n int) error {
- name, err := tun.Name()
- if err != nil {
- return err
- }
- // open datagram socket
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return err
- }
- defer unix.Close(fd)
- var ifr [ifReqSize]byte
- copy(ifr[:], name)
- *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCSIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return errno
- }
- return nil
- }
- func (tun *NativeTun) routineNetlinkRead() {
- defer func() {
- unix.Close(tun.netlinkSock)
- tun.hackListenerClosed.Lock()
- close(tun.events)
- tun.netlinkCancel.Close()
- }()
- for msg := make([]byte, 1<<16); ; {
- var err error
- var msgn int
- for {
- msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
- if err == nil || !rwcancel.RetryAfterError(err) {
- break
- }
- if !tun.netlinkCancel.ReadyRead() {
- tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
- return
- }
- }
- if err != nil {
- tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
- return
- }
- wasEverUp := false
- for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
- hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
- if int(hdr.Len) > len(remain) {
- break
- }
- switch hdr.Type {
- case unix.NLMSG_DONE:
- remain = []byte{}
- case unix.RTM_NEWLINK:
- info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
- remain = remain[hdr.Len:]
- if info.Index != tun.index {
- continue
- }
- if info.Flags&unix.IFF_RUNNING != 0 {
- tun.events <- EventUp
- wasEverUp = true
- }
- if info.Flags&unix.IFF_RUNNING == 0 {
- if wasEverUp {
- tun.events <- EventDown
- }
- }
- tun.events <- EventMTUUpdate
- default:
- remain = remain[hdr.Len:]
- }
- }
- }
- }
- func (tun *NativeTun) routineNetlink() {
- var err error
- tun.netlinkSock, err = createNetlinkSocket()
- if err != nil {
- tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
- return
- }
- tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
- if err != nil {
- tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
- return
- }
- go tun.routineNetlinkListener()
- }
- func (tun *NativeTun) Close() error {
- var err1, err2 error
- tun.closeOnce.Do(func() {
- if tun.statusListenersShutdown != nil {
- close(tun.statusListenersShutdown)
- if tun.netlinkCancel != nil {
- err1 = tun.netlinkCancel.Cancel()
- }
- } else if tun.events != nil {
- close(tun.events)
- }
- err2 = tun.tunFile.Close()
- })
- if err1 != nil {
- return err1
- }
- return err2
- }
- func (tun *NativeTun) BatchSize() int {
- return tun.batchSize
- }
- const (
- // TODO: support TSO with ECN bits
- tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
- )
- func (tun *NativeTun) initFromFlags(name string) error {
- sc, err := tun.tunFile.SyscallConn()
- if err != nil {
- return err
- }
- if e := sc.Control(func(fd uintptr) {
- var (
- ifr *unix.Ifreq
- )
- ifr, err = unix.NewIfreq(name)
- if err != nil {
- return
- }
- err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
- if err != nil {
- return
- }
- got := ifr.Uint16()
- if got&unix.IFF_VNET_HDR != 0 {
- err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
- if err != nil {
- return
- }
- tun.vnetHdr = true
- tun.batchSize = wgconn.IdealBatchSize
- } else {
- tun.batchSize = 1
- }
- }); e != nil {
- return e
- }
- return err
- }
- // CreateTUN creates a Device with the provided name and MTU.
- func CreateTUN(name string, mtu int) (Device, error) {
- nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
- if err != nil {
- return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
- }
- fd := os.NewFile(uintptr(nfd), cloneDevicePath)
- tun, err := CreateTUNFromFile(fd, mtu)
- if err != nil {
- return nil, err
- }
- if name != "tun" {
- if err := tun.(*NativeTun).initFromFlags(name); err != nil {
- tun.Close()
- return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
- }
- }
- return tun, nil
- }
- // CreateTUNFromFile creates a Device from an os.File with the provided MTU.
- func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
- tun := &NativeTun{
- tunFile: file,
- errors: make(chan error, 5),
- events: make(chan Event, 5),
- }
- name, err := tun.Name()
- if err != nil {
- return nil, fmt.Errorf("failed to determine TUN name: %w", err)
- }
- if err := tun.initFromFlags(name); err != nil {
- return nil, fmt.Errorf("failed to query TUN flags: %w", err)
- }
- if tun.batchSize == 0 {
- tun.batchSize = 1
- }
- tun.index, err = getIFIndex(name)
- if err != nil {
- return nil, fmt.Errorf("failed to get TUN index: %w", err)
- }
- if err = tun.setMTU(mtu); err != nil {
- return nil, fmt.Errorf("failed to set MTU: %w", err)
- }
- tun.statusListenersShutdown = make(chan struct{})
- go tun.routineNetlink()
- if tun.batchSize == 0 {
- tun.batchSize = 1
- }
- tun.tcp4GROTable = newTCPGROTable()
- tun.tcp6GROTable = newTCPGROTable()
- return tun, nil
- }
- func (tun *NativeTun) Name() (string, error) {
- tun.nameOnce.Do(tun.initNameCache)
- return tun.nameCache, tun.nameErr
- }
- func (tun *NativeTun) initNameCache() {
- sysconn, err := tun.tunFile.SyscallConn()
- if err != nil {
- tun.nameErr = err
- return
- }
- err = sysconn.Control(func(fd uintptr) {
- var ifr [ifReqSize]byte
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- fd,
- uintptr(unix.TUNGETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- tun.nameErr = errno
- return
- }
- tun.nameCache = unix.ByteSliceToString(ifr[:])
- })
- if err != nil && tun.nameErr == nil {
- tun.nameErr = err
- }
- }
- func (tun *NativeTun) MTU() (int, error) {
- name, err := tun.Name()
- if err != nil {
- return 0, err
- }
- // open datagram socket
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return 0, err
- }
- defer unix.Close(fd)
- var ifr [ifReqSize]byte
- copy(ifr[:], name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return 0, errno
- }
- return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
- }
- func (tun *NativeTun) Events() <-chan Event {
- return tun.events
- }
- func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
- tun.writeOpMu.Lock()
- defer func() {
- tun.tcp4GROTable.reset()
- tun.tcp6GROTable.reset()
- tun.writeOpMu.Unlock()
- }()
- var (
- errs error
- total int
- )
- tun.toWrite = tun.toWrite[:0]
- if tun.vnetHdr {
- err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
- if err != nil {
- return 0, err
- }
- offset -= virtioNetHdrLen
- } else {
- for i := range bufs {
- tun.toWrite = append(tun.toWrite, i)
- }
- }
- for _, bufsI := range tun.toWrite {
- n, err := tun.tunFile.Write(bufs[bufsI][offset:])
- if errors.Is(err, syscall.EBADFD) {
- return total, os.ErrClosed
- }
- if err != nil {
- errs = errors.Join(errs, err)
- } else {
- total += n
- }
- }
- return total, errs
- }
- // handleVirtioRead splits in into bufs, leaving offset bytes at the front of
- // each buffer. It mutates sizes to reflect the size of each element of bufs,
- // and returns the number of packets read.
- func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
- var hdr virtioNetHdr
- if err := hdr.decode(in); err != nil {
- return 0, err
- }
- in = in[virtioNetHdrLen:]
- if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
- if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
- if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
- return 0, err
- }
- }
- if len(in) > len(bufs[0][offset:]) {
- return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
- }
- n := copy(bufs[0][offset:], in)
- sizes[0] = n
- return 1, nil
- }
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
- return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
- }
- ipVersion := in[0] >> 4
- switch ipVersion {
- case 4:
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
- return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
- }
- case 6:
- if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
- return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
- }
- default:
- return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
- }
- if len(in) <= int(hdr.csumStart+12) {
- return 0, errors.New("packet is too short")
- }
- tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
- if tcpHLen < 20 || tcpHLen > 60 {
- return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
- }
- hdr.hdrLen = hdr.csumStart + tcpHLen
- if len(in) < int(hdr.hdrLen) {
- return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
- }
- if hdr.hdrLen < hdr.csumStart {
- return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
- }
- cSumAt := int(hdr.csumStart + hdr.csumOffset)
- if cSumAt+1 >= len(in) {
- return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
- }
- return tcpTSO(in, hdr, bufs, sizes, offset)
- }
- func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
- tun.readOpMu.Lock()
- defer tun.readOpMu.Unlock()
- select {
- case err := <-tun.errors:
- return 0, err
- default:
- readInto := bufs[0][offset:]
- if tun.vnetHdr {
- readInto = tun.readBuff[:]
- }
- n, err := tun.tunFile.Read(readInto)
- if errors.Is(err, syscall.EBADFD) {
- err = os.ErrClosed
- }
- if err != nil {
- return 0, err
- }
- if tun.vnetHdr {
- return handleVirtioRead(readInto[:n], bufs, sizes, offset)
- }
- sizes[0] = n
- return 1, nil
- }
- }
|