tun_wintun_windows.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package overlay
  2. import (
  3. "crypto"
  4. "fmt"
  5. "io"
  6. "net"
  7. "unsafe"
  8. "github.com/slackhq/nebula/wintun"
  9. "golang.org/x/sys/windows"
  10. "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
  11. )
  12. const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
  13. type winTun struct {
  14. Device string
  15. Cidr *net.IPNet
  16. MTU int
  17. UnsafeRoutes []Route
  18. tun *wintun.NativeTun
  19. }
  20. func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
  21. // GUID is 128 bit
  22. hash := crypto.MD5.New()
  23. _, err := hash.Write([]byte(tunGUIDLabel))
  24. if err != nil {
  25. return nil, err
  26. }
  27. _, err = hash.Write([]byte(name))
  28. if err != nil {
  29. return nil, err
  30. }
  31. sum := hash.Sum(nil)
  32. return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
  33. }
  34. func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes []Route) (*winTun, error) {
  35. guid, err := generateGUIDByDeviceName(deviceName)
  36. if err != nil {
  37. return nil, fmt.Errorf("generate GUID failed: %w", err)
  38. }
  39. tunDevice, err := wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
  40. if err != nil {
  41. return nil, fmt.Errorf("create TUN device failed: %w", err)
  42. }
  43. return &winTun{
  44. Device: deviceName,
  45. Cidr: cidr,
  46. MTU: defaultMTU,
  47. UnsafeRoutes: unsafeRoutes,
  48. tun: tunDevice.(*wintun.NativeTun),
  49. }, nil
  50. }
  51. func (t *winTun) Activate() error {
  52. luid := winipcfg.LUID(t.tun.LUID())
  53. if err := luid.SetIPAddresses([]net.IPNet{*t.Cidr}); err != nil {
  54. return fmt.Errorf("failed to set address: %w", err)
  55. }
  56. foundDefault4 := false
  57. routes := make([]*winipcfg.RouteData, 0, len(t.UnsafeRoutes)+1)
  58. for _, r := range t.UnsafeRoutes {
  59. if !foundDefault4 {
  60. if cidr, bits := r.Cidr.Mask.Size(); cidr == 0 && bits != 0 {
  61. foundDefault4 = true
  62. }
  63. }
  64. // Add our unsafe route
  65. routes = append(routes, &winipcfg.RouteData{
  66. Destination: *r.Cidr,
  67. NextHop: *r.Via,
  68. Metric: uint32(r.Metric),
  69. })
  70. }
  71. if err := luid.AddRoutes(routes); err != nil {
  72. return fmt.Errorf("failed to add routes: %w", err)
  73. }
  74. ipif, err := luid.IPInterface(windows.AF_INET)
  75. if err != nil {
  76. return fmt.Errorf("failed to get ip interface: %w", err)
  77. }
  78. ipif.NLMTU = uint32(t.MTU)
  79. if foundDefault4 {
  80. ipif.UseAutomaticMetric = false
  81. ipif.Metric = 0
  82. }
  83. if err := ipif.Set(); err != nil {
  84. return fmt.Errorf("failed to set ip interface: %w", err)
  85. }
  86. return nil
  87. }
  88. func (t *winTun) CidrNet() *net.IPNet {
  89. return t.Cidr
  90. }
  91. func (t *winTun) DeviceName() string {
  92. return t.Device
  93. }
  94. func (t *winTun) Read(b []byte) (int, error) {
  95. return t.tun.Read(b, 0)
  96. }
  97. func (t *winTun) Write(b []byte) (int, error) {
  98. return t.tun.Write(b, 0)
  99. }
  100. func (t *winTun) WriteRaw(b []byte) error {
  101. _, err := t.Write(b)
  102. return err
  103. }
  104. func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
  105. return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
  106. }
  107. func (t *winTun) Close() error {
  108. // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
  109. // so to be certain, just remove everything before destroying.
  110. luid := winipcfg.LUID(t.tun.LUID())
  111. _ = luid.FlushRoutes(windows.AF_INET)
  112. _ = luid.FlushIPAddresses(windows.AF_INET)
  113. /* We don't support IPV6 yet
  114. _ = luid.FlushRoutes(windows.AF_INET6)
  115. _ = luid.FlushIPAddresses(windows.AF_INET6)
  116. */
  117. _ = luid.FlushDNS(windows.AF_INET)
  118. return t.tun.Close()
  119. }