tun_linux.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. //go:build !android && !e2e_testing
  2. // +build !android,!e2e_testing
  3. package overlay
  4. import (
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/netip"
  9. "os"
  10. "strings"
  11. "sync/atomic"
  12. "time"
  13. "unsafe"
  14. "github.com/gaissmai/bart"
  15. "github.com/sirupsen/logrus"
  16. "github.com/slackhq/nebula/config"
  17. "github.com/slackhq/nebula/util"
  18. "github.com/vishvananda/netlink"
  19. "golang.org/x/sys/unix"
  20. )
  21. type tun struct {
  22. io.ReadWriteCloser
  23. fd int
  24. Device string
  25. vpnNetworks []netip.Prefix
  26. MaxMTU int
  27. DefaultMTU int
  28. TXQueueLen int
  29. deviceIndex int
  30. ioctlFd uintptr
  31. Routes atomic.Pointer[[]Route]
  32. routeTree atomic.Pointer[bart.Table[netip.Addr]]
  33. routeChan chan struct{}
  34. useSystemRoutes bool
  35. l *logrus.Logger
  36. }
  37. func (t *tun) Networks() []netip.Prefix {
  38. return t.vpnNetworks
  39. }
  40. type ifReq struct {
  41. Name [16]byte
  42. Flags uint16
  43. pad [8]byte
  44. }
  45. type ifreqMTU struct {
  46. Name [16]byte
  47. MTU int32
  48. pad [8]byte
  49. }
  50. type ifreqQLEN struct {
  51. Name [16]byte
  52. Value int32
  53. pad [8]byte
  54. }
  55. func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
  56. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
  57. t, err := newTunGeneric(c, l, file, vpnNetworks)
  58. if err != nil {
  59. return nil, err
  60. }
  61. t.Device = "tun0"
  62. return t, nil
  63. }
  64. func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
  65. fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
  66. if err != nil {
  67. // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
  68. if os.IsNotExist(err) {
  69. err = os.MkdirAll("/dev/net", 0755)
  70. if err != nil {
  71. return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
  72. }
  73. err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
  74. if err != nil {
  75. return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
  76. }
  77. fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
  78. if err != nil {
  79. return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
  80. }
  81. } else {
  82. return nil, err
  83. }
  84. }
  85. var req ifReq
  86. req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
  87. if multiqueue {
  88. req.Flags |= unix.IFF_MULTI_QUEUE
  89. }
  90. copy(req.Name[:], c.GetString("tun.dev", ""))
  91. if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
  92. return nil, err
  93. }
  94. name := strings.Trim(string(req.Name[:]), "\x00")
  95. file := os.NewFile(uintptr(fd), "/dev/net/tun")
  96. t, err := newTunGeneric(c, l, file, vpnNetworks)
  97. if err != nil {
  98. return nil, err
  99. }
  100. t.Device = name
  101. return t, nil
  102. }
  103. func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
  104. t := &tun{
  105. ReadWriteCloser: file,
  106. fd: int(file.Fd()),
  107. vpnNetworks: vpnNetworks,
  108. TXQueueLen: c.GetInt("tun.tx_queue", 500),
  109. useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
  110. l: l,
  111. }
  112. err := t.reload(c, true)
  113. if err != nil {
  114. return nil, err
  115. }
  116. c.RegisterReloadCallback(func(c *config.C) {
  117. err := t.reload(c, false)
  118. if err != nil {
  119. util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
  120. }
  121. })
  122. return t, nil
  123. }
  124. func (t *tun) reload(c *config.C, initial bool) error {
  125. routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
  126. if err != nil {
  127. return err
  128. }
  129. if !initial && !routeChange && !c.HasChanged("tun.mtu") {
  130. return nil
  131. }
  132. routeTree, err := makeRouteTree(t.l, routes, true)
  133. if err != nil {
  134. return err
  135. }
  136. oldDefaultMTU := t.DefaultMTU
  137. oldMaxMTU := t.MaxMTU
  138. newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
  139. newMaxMTU := newDefaultMTU
  140. for i, r := range routes {
  141. if r.MTU == 0 {
  142. routes[i].MTU = newDefaultMTU
  143. }
  144. if r.MTU > t.MaxMTU {
  145. newMaxMTU = r.MTU
  146. }
  147. }
  148. t.MaxMTU = newMaxMTU
  149. t.DefaultMTU = newDefaultMTU
  150. // Teach nebula how to handle the routes before establishing them in the system table
  151. oldRoutes := t.Routes.Swap(&routes)
  152. t.routeTree.Store(routeTree)
  153. if !initial {
  154. if oldMaxMTU != newMaxMTU {
  155. t.setMTU()
  156. t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
  157. }
  158. if oldDefaultMTU != newDefaultMTU {
  159. for i := range t.vpnNetworks {
  160. err := t.setDefaultRoute(t.vpnNetworks[i])
  161. if err != nil {
  162. t.l.Warn(err)
  163. } else {
  164. t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
  165. }
  166. }
  167. }
  168. // Remove first, if the system removes a wanted route hopefully it will be re-added next
  169. t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
  170. // Ensure any routes we actually want are installed
  171. err = t.addRoutes(true)
  172. if err != nil {
  173. // This should never be called since addRoutes should log its own errors in a reload condition
  174. util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
  175. }
  176. }
  177. return nil
  178. }
  179. func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  180. fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
  181. if err != nil {
  182. return nil, err
  183. }
  184. var req ifReq
  185. req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
  186. copy(req.Name[:], t.Device)
  187. if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
  188. return nil, err
  189. }
  190. file := os.NewFile(uintptr(fd), "/dev/net/tun")
  191. return file, nil
  192. }
  193. func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
  194. r, _ := t.routeTree.Load().Lookup(ip)
  195. return r
  196. }
  197. func (t *tun) Write(b []byte) (int, error) {
  198. var nn int
  199. maximum := len(b)
  200. for {
  201. n, err := unix.Write(t.fd, b[nn:maximum])
  202. if n > 0 {
  203. nn += n
  204. }
  205. if nn == len(b) {
  206. return nn, err
  207. }
  208. if err != nil {
  209. return nn, err
  210. }
  211. if n == 0 {
  212. return nn, io.ErrUnexpectedEOF
  213. }
  214. }
  215. }
  216. func (t *tun) deviceBytes() (o [16]byte) {
  217. for i, c := range t.Device {
  218. o[i] = byte(c)
  219. }
  220. return
  221. }
  222. func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
  223. for i := range al {
  224. if al[i].Equal(x) {
  225. return true
  226. }
  227. }
  228. return false
  229. }
  230. // addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
  231. func (t *tun) addIPs(link netlink.Link) error {
  232. newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
  233. for i := range t.vpnNetworks {
  234. newAddrs[i] = &netlink.Addr{
  235. IPNet: &net.IPNet{
  236. IP: t.vpnNetworks[i].Addr().AsSlice(),
  237. Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
  238. },
  239. Label: t.vpnNetworks[i].Addr().Zone(),
  240. }
  241. }
  242. //add all new addresses
  243. for i := range newAddrs {
  244. //TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
  245. //AddrReplace still adds new IPs, but if their properties change it will change them as well
  246. if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
  247. return err
  248. }
  249. }
  250. //iterate over remainder, remove whoever shouldn't be there
  251. al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
  252. if err != nil {
  253. return fmt.Errorf("failed to get tun address list: %s", err)
  254. }
  255. for i := range al {
  256. if hasNetlinkAddr(newAddrs, al[i]) {
  257. continue
  258. }
  259. err = netlink.AddrDel(link, &al[i])
  260. if err != nil {
  261. t.l.WithError(err).Error("failed to remove address from tun address list")
  262. } else {
  263. t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
  264. }
  265. }
  266. return nil
  267. }
  268. func (t *tun) Activate() error {
  269. devName := t.deviceBytes()
  270. if t.useSystemRoutes {
  271. t.watchRoutes()
  272. }
  273. s, err := unix.Socket(
  274. unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
  275. unix.SOCK_DGRAM,
  276. unix.IPPROTO_IP,
  277. )
  278. if err != nil {
  279. return err
  280. }
  281. t.ioctlFd = uintptr(s)
  282. // Set the device name
  283. ifrf := ifReq{Name: devName}
  284. if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
  285. return fmt.Errorf("failed to set tun device name: %s", err)
  286. }
  287. link, err := netlink.LinkByName(t.Device)
  288. if err != nil {
  289. return fmt.Errorf("failed to get tun device link: %s", err)
  290. }
  291. t.deviceIndex = link.Attrs().Index
  292. // Setup our default MTU
  293. t.setMTU()
  294. // Set the transmit queue length
  295. ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
  296. if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
  297. // If we can't set the queue length nebula will still work but it may lead to packet loss
  298. t.l.WithError(err).Error("Failed to set tun tx queue length")
  299. }
  300. if err = t.addIPs(link); err != nil {
  301. return err
  302. }
  303. // Bring up the interface
  304. ifrf.Flags = ifrf.Flags | unix.IFF_UP
  305. if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
  306. return fmt.Errorf("failed to bring the tun device up: %s", err)
  307. }
  308. //set route MTU
  309. for i := range t.vpnNetworks {
  310. if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
  311. return fmt.Errorf("failed to set default route MTU: %w", err)
  312. }
  313. }
  314. // Set the routes
  315. if err = t.addRoutes(false); err != nil {
  316. return err
  317. }
  318. // Run the interface
  319. ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
  320. if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
  321. return fmt.Errorf("failed to run tun device: %s", err)
  322. }
  323. return nil
  324. }
  325. func (t *tun) setMTU() {
  326. // Set the MTU on the device
  327. ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
  328. if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
  329. // This is currently a non fatal condition because the route table must have the MTU set appropriately as well
  330. t.l.WithError(err).Error("Failed to set tun mtu")
  331. }
  332. }
  333. func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
  334. dr := &net.IPNet{
  335. IP: cidr.Masked().Addr().AsSlice(),
  336. Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
  337. }
  338. nr := netlink.Route{
  339. LinkIndex: t.deviceIndex,
  340. Dst: dr,
  341. MTU: t.DefaultMTU,
  342. AdvMSS: t.advMSS(Route{}),
  343. Scope: unix.RT_SCOPE_LINK,
  344. Src: net.IP(cidr.Addr().AsSlice()),
  345. Protocol: unix.RTPROT_KERNEL,
  346. Table: unix.RT_TABLE_MAIN,
  347. Type: unix.RTN_UNICAST,
  348. }
  349. err := netlink.RouteReplace(&nr)
  350. if err != nil {
  351. t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
  352. //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
  353. for i := 0; i < 2; i++ {
  354. time.Sleep(100 * time.Millisecond)
  355. err = netlink.RouteReplace(&nr)
  356. if err == nil {
  357. break
  358. } else {
  359. t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
  360. }
  361. }
  362. if err != nil {
  363. return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
  364. }
  365. }
  366. return nil
  367. }
  368. func (t *tun) addRoutes(logErrors bool) error {
  369. // Path routes
  370. routes := *t.Routes.Load()
  371. for _, r := range routes {
  372. if !r.Install {
  373. continue
  374. }
  375. dr := &net.IPNet{
  376. IP: r.Cidr.Masked().Addr().AsSlice(),
  377. Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
  378. }
  379. nr := netlink.Route{
  380. LinkIndex: t.deviceIndex,
  381. Dst: dr,
  382. MTU: r.MTU,
  383. AdvMSS: t.advMSS(r),
  384. Scope: unix.RT_SCOPE_LINK,
  385. }
  386. if r.Metric > 0 {
  387. nr.Priority = r.Metric
  388. }
  389. err := netlink.RouteReplace(&nr)
  390. if err != nil {
  391. retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
  392. if logErrors {
  393. retErr.Log(t.l)
  394. } else {
  395. return retErr
  396. }
  397. } else {
  398. t.l.WithField("route", r).Info("Added route")
  399. }
  400. }
  401. return nil
  402. }
  403. func (t *tun) removeRoutes(routes []Route) {
  404. for _, r := range routes {
  405. if !r.Install {
  406. continue
  407. }
  408. dr := &net.IPNet{
  409. IP: r.Cidr.Masked().Addr().AsSlice(),
  410. Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
  411. }
  412. nr := netlink.Route{
  413. LinkIndex: t.deviceIndex,
  414. Dst: dr,
  415. MTU: r.MTU,
  416. AdvMSS: t.advMSS(r),
  417. Scope: unix.RT_SCOPE_LINK,
  418. }
  419. if r.Metric > 0 {
  420. nr.Priority = r.Metric
  421. }
  422. err := netlink.RouteDel(&nr)
  423. if err != nil {
  424. t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
  425. } else {
  426. t.l.WithField("route", r).Info("Removed route")
  427. }
  428. }
  429. }
  430. func (t *tun) Name() string {
  431. return t.Device
  432. }
  433. func (t *tun) advMSS(r Route) int {
  434. mtu := r.MTU
  435. if r.MTU == 0 {
  436. mtu = t.DefaultMTU
  437. }
  438. // We only need to set advmss if the route MTU does not match the device MTU
  439. if mtu != t.MaxMTU {
  440. return mtu - 40
  441. }
  442. return 0
  443. }
  444. func (t *tun) watchRoutes() {
  445. rch := make(chan netlink.RouteUpdate)
  446. doneChan := make(chan struct{})
  447. if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
  448. t.l.WithError(err).Errorf("failed to subscribe to system route changes")
  449. return
  450. }
  451. t.routeChan = doneChan
  452. go func() {
  453. for {
  454. select {
  455. case r := <-rch:
  456. t.updateRoutes(r)
  457. case <-doneChan:
  458. // netlink.RouteSubscriber will close the rch for us
  459. return
  460. }
  461. }
  462. }()
  463. }
  464. func (t *tun) updateRoutes(r netlink.RouteUpdate) {
  465. if r.Gw == nil {
  466. // Not a gateway route, ignore
  467. t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
  468. return
  469. }
  470. gwAddr, ok := netip.AddrFromSlice(r.Gw)
  471. if !ok {
  472. t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
  473. return
  474. }
  475. gwAddr = gwAddr.Unmap()
  476. withinNetworks := false
  477. for i := range t.vpnNetworks {
  478. if t.vpnNetworks[i].Contains(gwAddr) {
  479. withinNetworks = true
  480. break
  481. }
  482. }
  483. if !withinNetworks {
  484. // Gateway isn't in our overlay network, ignore
  485. t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
  486. return
  487. }
  488. dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
  489. if !ok {
  490. t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
  491. return
  492. }
  493. ones, _ := r.Dst.Mask.Size()
  494. dst := netip.PrefixFrom(dstAddr, ones)
  495. newTree := t.routeTree.Load().Clone()
  496. if r.Type == unix.RTM_NEWROUTE {
  497. t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
  498. newTree.Insert(dst, gwAddr)
  499. } else {
  500. newTree.Delete(dst)
  501. t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
  502. }
  503. t.routeTree.Store(newTree)
  504. }
  505. func (t *tun) Close() error {
  506. if t.routeChan != nil {
  507. close(t.routeChan)
  508. }
  509. if t.ReadWriteCloser != nil {
  510. _ = t.ReadWriteCloser.Close()
  511. }
  512. if t.ioctlFd > 0 {
  513. _ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
  514. }
  515. return nil
  516. }