tun_windows.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. //go:build !e2e_testing
  2. // +build !e2e_testing
  3. package overlay
  import (
  	"crypto"
  	// Blank import registers the MD5 implementation with the crypto package;
  	// crypto.MD5.New (used by generateGUIDByDeviceName) panics if it is not linked.
  	_ "crypto/md5"
  	"fmt"
  	"io"
  	"net/netip"
  	"os"
  	"path/filepath"
  	"runtime"
  	"sync/atomic"
  	"syscall"
  	"unsafe"

  	"github.com/gaissmai/bart"
  	"github.com/sirupsen/logrus"
  	"github.com/slackhq/nebula/config"
  	"github.com/slackhq/nebula/util"
  	"github.com/slackhq/nebula/wintun"
  	"golang.org/x/sys/windows"
  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
  )
  // tunGUIDLabel is a fixed prefix written into the hash ahead of the device
  // name when generateGUIDByDeviceName derives an adapter GUID. Changing it
  // changes every derived GUID.
  const (
  	tunGUIDLabel = "Fixed Nebula Windows GUID v1"
  )
  24. type winTun struct {
  25. Device string
  26. vpnNetworks []netip.Prefix
  27. MTU int
  28. Routes atomic.Pointer[[]Route]
  29. routeTree atomic.Pointer[bart.Table[netip.Addr]]
  30. l *logrus.Logger
  31. tun *wintun.NativeTun
  32. }
  33. func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
  34. return nil, fmt.Errorf("newTunFromFd not supported in Windows")
  35. }
  36. func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
  37. err := checkWinTunExists()
  38. if err != nil {
  39. return nil, fmt.Errorf("can not load the wintun driver: %w", err)
  40. }
  41. deviceName := c.GetString("tun.dev", "")
  42. guid, err := generateGUIDByDeviceName(deviceName)
  43. if err != nil {
  44. return nil, fmt.Errorf("generate GUID failed: %w", err)
  45. }
  46. t := &winTun{
  47. Device: deviceName,
  48. vpnNetworks: vpnNetworks,
  49. MTU: c.GetInt("tun.mtu", DefaultMTU),
  50. l: l,
  51. }
  52. err = t.reload(c, true)
  53. if err != nil {
  54. return nil, err
  55. }
  56. var tunDevice wintun.Device
  57. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  58. if err != nil {
  59. // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
  60. // Trying a second time resolves the issue.
  61. l.WithError(err).Debug("Failed to create wintun device, retrying")
  62. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  63. if err != nil {
  64. return nil, fmt.Errorf("create TUN device failed: %w", err)
  65. }
  66. }
  67. t.tun = tunDevice.(*wintun.NativeTun)
  68. c.RegisterReloadCallback(func(c *config.C) {
  69. err := t.reload(c, false)
  70. if err != nil {
  71. util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
  72. }
  73. })
  74. return t, nil
  75. }
  76. func (t *winTun) reload(c *config.C, initial bool) error {
  77. change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
  78. if err != nil {
  79. return err
  80. }
  81. if !initial && !change {
  82. return nil
  83. }
  84. routeTree, err := makeRouteTree(t.l, routes, false)
  85. if err != nil {
  86. return err
  87. }
  88. // Teach nebula how to handle the routes before establishing them in the system table
  89. oldRoutes := t.Routes.Swap(&routes)
  90. t.routeTree.Store(routeTree)
  91. if !initial {
  92. // Remove first, if the system removes a wanted route hopefully it will be re-added next
  93. err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
  94. if err != nil {
  95. util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
  96. }
  97. // Ensure any routes we actually want are installed
  98. err = t.addRoutes(true)
  99. if err != nil {
  100. // Catch any stray logs
  101. util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
  102. }
  103. }
  104. return nil
  105. }
  106. func (t *winTun) Activate() error {
  107. luid := winipcfg.LUID(t.tun.LUID())
  108. err := luid.SetIPAddresses(t.vpnNetworks)
  109. if err != nil {
  110. return fmt.Errorf("failed to set address: %w", err)
  111. }
  112. err = t.addRoutes(false)
  113. if err != nil {
  114. return err
  115. }
  116. return nil
  117. }
  118. func (t *winTun) addRoutes(logErrors bool) error {
  119. luid := winipcfg.LUID(t.tun.LUID())
  120. routes := *t.Routes.Load()
  121. foundDefault4 := false
  122. for _, r := range routes {
  123. if !r.Via.IsValid() || !r.Install {
  124. // We don't allow route MTUs so only install routes with a via
  125. continue
  126. }
  127. // Add our unsafe route
  128. err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
  129. if err != nil {
  130. retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
  131. if logErrors {
  132. retErr.Log(t.l)
  133. continue
  134. } else {
  135. return retErr
  136. }
  137. } else {
  138. t.l.WithField("route", r).Info("Added route")
  139. }
  140. if !foundDefault4 {
  141. if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
  142. foundDefault4 = true
  143. }
  144. }
  145. }
  146. ipif, err := luid.IPInterface(windows.AF_INET)
  147. if err != nil {
  148. return fmt.Errorf("failed to get ip interface: %w", err)
  149. }
  150. ipif.NLMTU = uint32(t.MTU)
  151. if foundDefault4 {
  152. ipif.UseAutomaticMetric = false
  153. ipif.Metric = 0
  154. }
  155. if err := ipif.Set(); err != nil {
  156. return fmt.Errorf("failed to set ip interface: %w", err)
  157. }
  158. return nil
  159. }
  160. func (t *winTun) removeRoutes(routes []Route) error {
  161. luid := winipcfg.LUID(t.tun.LUID())
  162. for _, r := range routes {
  163. if !r.Install {
  164. continue
  165. }
  166. err := luid.DeleteRoute(r.Cidr, r.Via)
  167. if err != nil {
  168. t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
  169. } else {
  170. t.l.WithField("route", r).Info("Removed route")
  171. }
  172. }
  173. return nil
  174. }
  175. func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
  176. r, _ := t.routeTree.Load().Lookup(ip)
  177. return r
  178. }
  179. func (t *winTun) Networks() []netip.Prefix {
  180. return t.vpnNetworks
  181. }
  182. func (t *winTun) Name() string {
  183. return t.Device
  184. }
  185. func (t *winTun) Read(b []byte) (int, error) {
  186. return t.tun.Read(b, 0)
  187. }
  188. func (t *winTun) Write(b []byte) (int, error) {
  189. return t.tun.Write(b, 0)
  190. }
  191. func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  192. return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
  193. }
  194. func (t *winTun) Close() error {
  195. // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
  196. // so to be certain, just remove everything before destroying.
  197. luid := winipcfg.LUID(t.tun.LUID())
  198. _ = luid.FlushRoutes(windows.AF_INET)
  199. _ = luid.FlushIPAddresses(windows.AF_INET)
  200. _ = luid.FlushRoutes(windows.AF_INET6)
  201. _ = luid.FlushIPAddresses(windows.AF_INET6)
  202. _ = luid.FlushDNS(windows.AF_INET)
  203. _ = luid.FlushDNS(windows.AF_INET6)
  204. return t.tun.Close()
  205. }
  206. func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
  207. // GUID is 128 bit
  208. hash := crypto.MD5.New()
  209. _, err := hash.Write([]byte(tunGUIDLabel))
  210. if err != nil {
  211. return nil, err
  212. }
  213. _, err = hash.Write([]byte(name))
  214. if err != nil {
  215. return nil, err
  216. }
  217. sum := hash.Sum(nil)
  218. return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
  219. }
  220. func checkWinTunExists() error {
  221. myPath, err := os.Executable()
  222. if err != nil {
  223. return err
  224. }
  225. arch := runtime.GOARCH
  226. switch arch {
  227. case "386":
  228. //NOTE: wintun bundles 386 as x86
  229. arch = "x86"
  230. }
  231. _, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
  232. return err
  233. }