tun_wintun_windows.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package overlay
import (
	"crypto"
	// Blank import guarantees the MD5 implementation is registered, so the
	// crypto.MD5 hash used by generateGUIDByDeviceName is always available.
	_ "crypto/md5"
	"fmt"
	"io"
	"net/netip"
	"sync/atomic"
	"unsafe"

	"github.com/gaissmai/bart"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/util"
	"github.com/slackhq/nebula/wintun"
	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
  17. const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
  18. type winTun struct {
  19. Device string
  20. cidr netip.Prefix
  21. MTU int
  22. Routes atomic.Pointer[[]Route]
  23. routeTree atomic.Pointer[bart.Table[netip.Addr]]
  24. l *logrus.Logger
  25. tun *wintun.NativeTun
  26. }
  27. func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
  28. // GUID is 128 bit
  29. hash := crypto.MD5.New()
  30. _, err := hash.Write([]byte(tunGUIDLabel))
  31. if err != nil {
  32. return nil, err
  33. }
  34. _, err = hash.Write([]byte(name))
  35. if err != nil {
  36. return nil, err
  37. }
  38. sum := hash.Sum(nil)
  39. return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
  40. }
  41. func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
  42. deviceName := c.GetString("tun.dev", "")
  43. guid, err := generateGUIDByDeviceName(deviceName)
  44. if err != nil {
  45. return nil, fmt.Errorf("generate GUID failed: %w", err)
  46. }
  47. t := &winTun{
  48. Device: deviceName,
  49. cidr: cidr,
  50. MTU: c.GetInt("tun.mtu", DefaultMTU),
  51. l: l,
  52. }
  53. err = t.reload(c, true)
  54. if err != nil {
  55. return nil, err
  56. }
  57. var tunDevice wintun.Device
  58. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  59. if err != nil {
  60. // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
  61. // Trying a second time resolves the issue.
  62. l.WithError(err).Debug("Failed to create wintun device, retrying")
  63. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  64. if err != nil {
  65. return nil, fmt.Errorf("create TUN device failed: %w", err)
  66. }
  67. }
  68. t.tun = tunDevice.(*wintun.NativeTun)
  69. c.RegisterReloadCallback(func(c *config.C) {
  70. err := t.reload(c, false)
  71. if err != nil {
  72. util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
  73. }
  74. })
  75. return t, nil
  76. }
  77. func (t *winTun) reload(c *config.C, initial bool) error {
  78. change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
  79. if err != nil {
  80. return err
  81. }
  82. if !initial && !change {
  83. return nil
  84. }
  85. routeTree, err := makeRouteTree(t.l, routes, false)
  86. if err != nil {
  87. return err
  88. }
  89. // Teach nebula how to handle the routes before establishing them in the system table
  90. oldRoutes := t.Routes.Swap(&routes)
  91. t.routeTree.Store(routeTree)
  92. if !initial {
  93. // Remove first, if the system removes a wanted route hopefully it will be re-added next
  94. err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
  95. if err != nil {
  96. util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
  97. }
  98. // Ensure any routes we actually want are installed
  99. err = t.addRoutes(true)
  100. if err != nil {
  101. // Catch any stray logs
  102. util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
  103. }
  104. }
  105. return nil
  106. }
  107. func (t *winTun) Activate() error {
  108. luid := winipcfg.LUID(t.tun.LUID())
  109. err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
  110. if err != nil {
  111. return fmt.Errorf("failed to set address: %w", err)
  112. }
  113. err = t.addRoutes(false)
  114. if err != nil {
  115. return err
  116. }
  117. return nil
  118. }
  119. func (t *winTun) addRoutes(logErrors bool) error {
  120. luid := winipcfg.LUID(t.tun.LUID())
  121. routes := *t.Routes.Load()
  122. foundDefault4 := false
  123. for _, r := range routes {
  124. if !r.Via.IsValid() || !r.Install {
  125. // We don't allow route MTUs so only install routes with a via
  126. continue
  127. }
  128. // Add our unsafe route
  129. err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
  130. if err != nil {
  131. retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
  132. if logErrors {
  133. retErr.Log(t.l)
  134. continue
  135. } else {
  136. return retErr
  137. }
  138. } else {
  139. t.l.WithField("route", r).Info("Added route")
  140. }
  141. if !foundDefault4 {
  142. if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
  143. foundDefault4 = true
  144. }
  145. }
  146. }
  147. ipif, err := luid.IPInterface(windows.AF_INET)
  148. if err != nil {
  149. return fmt.Errorf("failed to get ip interface: %w", err)
  150. }
  151. ipif.NLMTU = uint32(t.MTU)
  152. if foundDefault4 {
  153. ipif.UseAutomaticMetric = false
  154. ipif.Metric = 0
  155. }
  156. if err := ipif.Set(); err != nil {
  157. return fmt.Errorf("failed to set ip interface: %w", err)
  158. }
  159. return nil
  160. }
  161. func (t *winTun) removeRoutes(routes []Route) error {
  162. luid := winipcfg.LUID(t.tun.LUID())
  163. for _, r := range routes {
  164. if !r.Install {
  165. continue
  166. }
  167. err := luid.DeleteRoute(r.Cidr, r.Via)
  168. if err != nil {
  169. t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
  170. } else {
  171. t.l.WithField("route", r).Info("Removed route")
  172. }
  173. }
  174. return nil
  175. }
  176. func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
  177. r, _ := t.routeTree.Load().Lookup(ip)
  178. return r
  179. }
  180. func (t *winTun) Cidr() netip.Prefix {
  181. return t.cidr
  182. }
  183. func (t *winTun) Name() string {
  184. return t.Device
  185. }
  186. func (t *winTun) Read(b []byte) (int, error) {
  187. return t.tun.Read(b, 0)
  188. }
  189. func (t *winTun) Write(b []byte) (int, error) {
  190. return t.tun.Write(b, 0)
  191. }
  192. func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  193. return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
  194. }
  195. func (t *winTun) Close() error {
  196. // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
  197. // so to be certain, just remove everything before destroying.
  198. luid := winipcfg.LUID(t.tun.LUID())
  199. _ = luid.FlushRoutes(windows.AF_INET)
  200. _ = luid.FlushIPAddresses(windows.AF_INET)
  201. /* We don't support IPV6 yet
  202. _ = luid.FlushRoutes(windows.AF_INET6)
  203. _ = luid.FlushIPAddresses(windows.AF_INET6)
  204. */
  205. _ = luid.FlushDNS(windows.AF_INET)
  206. return t.tun.Close()
  207. }