tun_wintun_windows.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package overlay
  import (
	"crypto"
	_ "crypto/md5" // registers MD5 so crypto.MD5.New() in generateGUIDByDeviceName cannot panic
	"fmt"
	"io"
	"net"
	"net/netip"
	"sync/atomic"
	"unsafe"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cidr"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/iputil"
	"github.com/slackhq/nebula/util"
	"github.com/slackhq/nebula/wintun"
	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
  const (
	// tunGUIDLabel is hashed together with the device name to derive the
	// adapter GUID. Changing this string changes every generated GUID.
	tunGUIDLabel = "Fixed Nebula Windows GUID v1"
)
  20. type winTun struct {
  21. Device string
  22. cidr *net.IPNet
  23. prefix netip.Prefix
  24. MTU int
  25. Routes atomic.Pointer[[]Route]
  26. routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
  27. l *logrus.Logger
  28. tun *wintun.NativeTun
  29. }
  30. func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
  31. // GUID is 128 bit
  32. hash := crypto.MD5.New()
  33. _, err := hash.Write([]byte(tunGUIDLabel))
  34. if err != nil {
  35. return nil, err
  36. }
  37. _, err = hash.Write([]byte(name))
  38. if err != nil {
  39. return nil, err
  40. }
  41. sum := hash.Sum(nil)
  42. return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
  43. }
  44. func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
  45. deviceName := c.GetString("tun.dev", "")
  46. guid, err := generateGUIDByDeviceName(deviceName)
  47. if err != nil {
  48. return nil, fmt.Errorf("generate GUID failed: %w", err)
  49. }
  50. prefix, err := iputil.ToNetIpPrefix(*cidr)
  51. if err != nil {
  52. return nil, err
  53. }
  54. t := &winTun{
  55. Device: deviceName,
  56. cidr: cidr,
  57. prefix: prefix,
  58. MTU: c.GetInt("tun.mtu", DefaultMTU),
  59. l: l,
  60. }
  61. err = t.reload(c, true)
  62. if err != nil {
  63. return nil, err
  64. }
  65. var tunDevice wintun.Device
  66. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  67. if err != nil {
  68. // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
  69. // Trying a second time resolves the issue.
  70. l.WithError(err).Debug("Failed to create wintun device, retrying")
  71. tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
  72. if err != nil {
  73. return nil, fmt.Errorf("create TUN device failed: %w", err)
  74. }
  75. }
  76. t.tun = tunDevice.(*wintun.NativeTun)
  77. c.RegisterReloadCallback(func(c *config.C) {
  78. err := t.reload(c, false)
  79. if err != nil {
  80. util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
  81. }
  82. })
  83. return t, nil
  84. }
  85. func (t *winTun) reload(c *config.C, initial bool) error {
  86. change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
  87. if err != nil {
  88. return err
  89. }
  90. if !initial && !change {
  91. return nil
  92. }
  93. routeTree, err := makeRouteTree(t.l, routes, false)
  94. if err != nil {
  95. return err
  96. }
  97. // Teach nebula how to handle the routes before establishing them in the system table
  98. oldRoutes := t.Routes.Swap(&routes)
  99. t.routeTree.Store(routeTree)
  100. if !initial {
  101. // Remove first, if the system removes a wanted route hopefully it will be re-added next
  102. err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
  103. if err != nil {
  104. util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
  105. }
  106. // Ensure any routes we actually want are installed
  107. err = t.addRoutes(true)
  108. if err != nil {
  109. // Catch any stray logs
  110. util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
  111. }
  112. }
  113. return nil
  114. }
  115. func (t *winTun) Activate() error {
  116. luid := winipcfg.LUID(t.tun.LUID())
  117. err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
  118. if err != nil {
  119. return fmt.Errorf("failed to set address: %w", err)
  120. }
  121. err = t.addRoutes(false)
  122. if err != nil {
  123. return err
  124. }
  125. return nil
  126. }
  127. func (t *winTun) addRoutes(logErrors bool) error {
  128. luid := winipcfg.LUID(t.tun.LUID())
  129. routes := *t.Routes.Load()
  130. foundDefault4 := false
  131. for _, r := range routes {
  132. if r.Via == nil || !r.Install {
  133. // We don't allow route MTUs so only install routes with a via
  134. continue
  135. }
  136. prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
  137. if err != nil {
  138. retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
  139. if logErrors {
  140. retErr.Log(t.l)
  141. continue
  142. } else {
  143. return retErr
  144. }
  145. }
  146. // Add our unsafe route
  147. err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
  148. if err != nil {
  149. retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
  150. if logErrors {
  151. retErr.Log(t.l)
  152. continue
  153. } else {
  154. return retErr
  155. }
  156. } else {
  157. t.l.WithField("route", r).Info("Added route")
  158. }
  159. if !foundDefault4 {
  160. if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
  161. foundDefault4 = true
  162. }
  163. }
  164. }
  165. ipif, err := luid.IPInterface(windows.AF_INET)
  166. if err != nil {
  167. return fmt.Errorf("failed to get ip interface: %w", err)
  168. }
  169. ipif.NLMTU = uint32(t.MTU)
  170. if foundDefault4 {
  171. ipif.UseAutomaticMetric = false
  172. ipif.Metric = 0
  173. }
  174. if err := ipif.Set(); err != nil {
  175. return fmt.Errorf("failed to set ip interface: %w", err)
  176. }
  177. return nil
  178. }
  179. func (t *winTun) removeRoutes(routes []Route) error {
  180. luid := winipcfg.LUID(t.tun.LUID())
  181. for _, r := range routes {
  182. if !r.Install {
  183. continue
  184. }
  185. prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
  186. if err != nil {
  187. t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
  188. continue
  189. }
  190. err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
  191. if err != nil {
  192. t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
  193. } else {
  194. t.l.WithField("route", r).Info("Removed route")
  195. }
  196. }
  197. return nil
  198. }
  199. func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
  200. _, r := t.routeTree.Load().MostSpecificContains(ip)
  201. return r
  202. }
  203. func (t *winTun) Cidr() *net.IPNet {
  204. return t.cidr
  205. }
  206. func (t *winTun) Name() string {
  207. return t.Device
  208. }
  209. func (t *winTun) Read(b []byte) (int, error) {
  210. return t.tun.Read(b, 0)
  211. }
  212. func (t *winTun) Write(b []byte) (int, error) {
  213. return t.tun.Write(b, 0)
  214. }
  215. func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  216. return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
  217. }
  218. func (t *winTun) Close() error {
  219. // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
  220. // so to be certain, just remove everything before destroying.
  221. luid := winipcfg.LUID(t.tun.LUID())
  222. _ = luid.FlushRoutes(windows.AF_INET)
  223. _ = luid.FlushIPAddresses(windows.AF_INET)
  224. /* We don't support IPV6 yet
  225. _ = luid.FlushRoutes(windows.AF_INET6)
  226. _ = luid.FlushIPAddresses(windows.AF_INET6)
  227. */
  228. _ = luid.FlushDNS(windows.AF_INET)
  229. return t.tun.Close()
  230. }