service.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package service
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "net/netip"
  11. "strings"
  12. "sync"
  13. "github.com/slackhq/nebula"
  14. "github.com/slackhq/nebula/overlay"
  15. "golang.org/x/sync/errgroup"
  16. "gvisor.dev/gvisor/pkg/buffer"
  17. "gvisor.dev/gvisor/pkg/tcpip"
  18. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  19. "gvisor.dev/gvisor/pkg/tcpip/header"
  20. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  21. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  22. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  23. "gvisor.dev/gvisor/pkg/tcpip/stack"
  24. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  25. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  26. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  27. "gvisor.dev/gvisor/pkg/waiter"
  28. )
  29. const nicID = 1
  30. type Service struct {
  31. eg *errgroup.Group
  32. control *nebula.Control
  33. ipstack *stack.Stack
  34. mu struct {
  35. sync.Mutex
  36. listeners map[uint16]*tcpListener
  37. }
  38. }
  39. func New(control *nebula.Control) (*Service, error) {
  40. control.Start()
  41. ctx := control.Context()
  42. eg, ctx := errgroup.WithContext(ctx)
  43. s := Service{
  44. eg: eg,
  45. control: control,
  46. }
  47. s.mu.listeners = map[uint16]*tcpListener{}
  48. device, ok := control.Device().(*overlay.UserDevice)
  49. if !ok {
  50. return nil, errors.New("must be using user device")
  51. }
  52. s.ipstack = stack.New(stack.Options{
  53. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  54. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  55. })
  56. sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
  57. tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
  58. if tcpipErr != nil {
  59. return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
  60. }
  61. linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
  62. if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
  63. return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
  64. }
  65. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
  66. s.ipstack.SetRouteTable([]tcpip.Route{
  67. {
  68. Destination: ipv4Subnet,
  69. NIC: nicID,
  70. },
  71. })
  72. ipNet := device.Networks()
  73. pa := tcpip.ProtocolAddress{
  74. AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
  75. Protocol: ipv4.ProtocolNumber,
  76. }
  77. if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
  78. PEB: stack.CanBePrimaryEndpoint, // zero value default
  79. ConfigType: stack.AddressConfigStatic, // zero value default
  80. }); err != nil {
  81. return nil, fmt.Errorf("error creating IP: %s", err)
  82. }
  83. const tcpReceiveBufferSize = 0
  84. const maxInFlightConnectionAttempts = 1024
  85. tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
  86. s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
  87. reader, writer := device.Pipe()
  88. go func() {
  89. <-ctx.Done()
  90. reader.Close()
  91. writer.Close()
  92. }()
  93. // create Goroutines to forward packets between Nebula and Gvisor
  94. eg.Go(func() error {
  95. buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
  96. for {
  97. // this will read exactly one packet
  98. n, err := reader.Read(buf)
  99. if err != nil {
  100. return err
  101. }
  102. packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
  103. Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
  104. })
  105. linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
  106. if err := ctx.Err(); err != nil {
  107. return err
  108. }
  109. }
  110. })
  111. eg.Go(func() error {
  112. for {
  113. packet := linkEP.ReadContext(ctx)
  114. if packet == nil {
  115. if err := ctx.Err(); err != nil {
  116. return err
  117. }
  118. continue
  119. }
  120. bufView := packet.ToView()
  121. if _, err := bufView.WriteTo(writer); err != nil {
  122. return err
  123. }
  124. bufView.Release()
  125. }
  126. })
  127. return &s, nil
  128. }
  129. func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
  130. if addr.Is6() {
  131. return ipv6.ProtocolNumber
  132. }
  133. return ipv4.ProtocolNumber
  134. }
  135. // DialContext dials the provided address.
  136. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  137. switch network {
  138. case "udp", "udp4", "udp6":
  139. addr, err := net.ResolveUDPAddr(network, address)
  140. if err != nil {
  141. return nil, err
  142. }
  143. fullAddr := tcpip.FullAddress{
  144. NIC: nicID,
  145. Addr: tcpip.AddrFromSlice(addr.IP),
  146. Port: uint16(addr.Port),
  147. }
  148. num := getProtocolNumber(addr.AddrPort().Addr())
  149. return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
  150. case "tcp", "tcp4", "tcp6":
  151. addr, err := net.ResolveTCPAddr(network, address)
  152. if err != nil {
  153. return nil, err
  154. }
  155. fullAddr := tcpip.FullAddress{
  156. NIC: nicID,
  157. Addr: tcpip.AddrFromSlice(addr.IP),
  158. Port: uint16(addr.Port),
  159. }
  160. num := getProtocolNumber(addr.AddrPort().Addr())
  161. return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
  162. default:
  163. return nil, fmt.Errorf("unknown network type: %s", network)
  164. }
  165. }
  166. // Dial dials the provided address
  167. func (s *Service) Dial(network, address string) (net.Conn, error) {
  168. return s.DialContext(context.Background(), network, address)
  169. }
  170. // Listen listens on the provided address. Currently only TCP with wildcard
  171. // addresses are supported.
  172. func (s *Service) Listen(network, address string) (net.Listener, error) {
  173. if network != "tcp" && network != "tcp4" {
  174. return nil, errors.New("only tcp is supported")
  175. }
  176. addr, err := net.ResolveTCPAddr(network, address)
  177. if err != nil {
  178. return nil, err
  179. }
  180. if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
  181. return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
  182. }
  183. if addr.Port == 0 {
  184. return nil, errors.New("specific port required, got 0")
  185. }
  186. if addr.Port < 0 || addr.Port >= math.MaxUint16 {
  187. return nil, fmt.Errorf("invalid port %d", addr.Port)
  188. }
  189. port := uint16(addr.Port)
  190. l := &tcpListener{
  191. port: port,
  192. s: s,
  193. addr: addr,
  194. accept: make(chan net.Conn),
  195. }
  196. s.mu.Lock()
  197. defer s.mu.Unlock()
  198. if _, ok := s.mu.listeners[port]; ok {
  199. return nil, fmt.Errorf("already listening on port %d", port)
  200. }
  201. s.mu.listeners[port] = l
  202. return l, nil
  203. }
  204. func (s *Service) Wait() error {
  205. return s.eg.Wait()
  206. }
  207. func (s *Service) Close() error {
  208. s.control.Stop()
  209. return nil
  210. }
  211. func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
  212. endpointID := r.ID()
  213. s.mu.Lock()
  214. defer s.mu.Unlock()
  215. l, ok := s.mu.listeners[endpointID.LocalPort]
  216. if !ok {
  217. r.Complete(true)
  218. return
  219. }
  220. var wq waiter.Queue
  221. ep, err := r.CreateEndpoint(&wq)
  222. if err != nil {
  223. log.Printf("got error creating endpoint %q", err)
  224. r.Complete(true)
  225. return
  226. }
  227. r.Complete(false)
  228. ep.SocketOptions().SetKeepAlive(true)
  229. conn := gonet.NewTCPConn(&wq, ep)
  230. l.accept <- conn
  231. }