service.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. package service
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "net/netip"
  11. "os"
  12. "strings"
  13. "sync"
  14. "github.com/sirupsen/logrus"
  15. "github.com/slackhq/nebula"
  16. "github.com/slackhq/nebula/config"
  17. "github.com/slackhq/nebula/overlay"
  18. "golang.org/x/sync/errgroup"
  19. "gvisor.dev/gvisor/pkg/buffer"
  20. "gvisor.dev/gvisor/pkg/tcpip"
  21. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  22. "gvisor.dev/gvisor/pkg/tcpip/header"
  23. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  24. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  25. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  26. "gvisor.dev/gvisor/pkg/tcpip/stack"
  27. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  28. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  29. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  30. "gvisor.dev/gvisor/pkg/waiter"
  31. )
  32. const nicID = 1
  33. type Service struct {
  34. eg *errgroup.Group
  35. control *nebula.Control
  36. ipstack *stack.Stack
  37. mu struct {
  38. sync.Mutex
  39. listeners map[uint16]*tcpListener
  40. }
  41. }
  42. func New(config *config.C) (*Service, error) {
  43. logger := logrus.New()
  44. logger.Out = os.Stdout
  45. control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
  46. if err != nil {
  47. return nil, err
  48. }
  49. control.Start()
  50. ctx := control.Context()
  51. eg, ctx := errgroup.WithContext(ctx)
  52. s := Service{
  53. eg: eg,
  54. control: control,
  55. }
  56. s.mu.listeners = map[uint16]*tcpListener{}
  57. device, ok := control.Device().(*overlay.UserDevice)
  58. if !ok {
  59. return nil, errors.New("must be using user device")
  60. }
  61. s.ipstack = stack.New(stack.Options{
  62. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  63. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  64. })
  65. sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
  66. tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
  67. if tcpipErr != nil {
  68. return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
  69. }
  70. linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
  71. if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
  72. return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
  73. }
  74. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
  75. s.ipstack.SetRouteTable([]tcpip.Route{
  76. {
  77. Destination: ipv4Subnet,
  78. NIC: nicID,
  79. },
  80. })
  81. ipNet := device.Networks()
  82. pa := tcpip.ProtocolAddress{
  83. AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
  84. Protocol: ipv4.ProtocolNumber,
  85. }
  86. if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
  87. PEB: stack.CanBePrimaryEndpoint, // zero value default
  88. ConfigType: stack.AddressConfigStatic, // zero value default
  89. }); err != nil {
  90. return nil, fmt.Errorf("error creating IP: %s", err)
  91. }
  92. const tcpReceiveBufferSize = 0
  93. const maxInFlightConnectionAttempts = 1024
  94. tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
  95. s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
  96. reader, writer := device.Pipe()
  97. go func() {
  98. <-ctx.Done()
  99. reader.Close()
  100. writer.Close()
  101. }()
  102. // create Goroutines to forward packets between Nebula and Gvisor
  103. eg.Go(func() error {
  104. buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
  105. for {
  106. // this will read exactly one packet
  107. n, err := reader.Read(buf)
  108. if err != nil {
  109. return err
  110. }
  111. packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
  112. Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
  113. })
  114. linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
  115. if err := ctx.Err(); err != nil {
  116. return err
  117. }
  118. }
  119. })
  120. eg.Go(func() error {
  121. for {
  122. packet := linkEP.ReadContext(ctx)
  123. if packet == nil {
  124. if err := ctx.Err(); err != nil {
  125. return err
  126. }
  127. continue
  128. }
  129. bufView := packet.ToView()
  130. if _, err := bufView.WriteTo(writer); err != nil {
  131. return err
  132. }
  133. bufView.Release()
  134. }
  135. })
  136. return &s, nil
  137. }
  138. func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
  139. if addr.Is6() {
  140. return ipv6.ProtocolNumber
  141. }
  142. return ipv4.ProtocolNumber
  143. }
  144. // DialContext dials the provided address.
  145. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  146. switch network {
  147. case "udp", "udp4", "udp6":
  148. addr, err := net.ResolveUDPAddr(network, address)
  149. if err != nil {
  150. return nil, err
  151. }
  152. fullAddr := tcpip.FullAddress{
  153. NIC: nicID,
  154. Addr: tcpip.AddrFromSlice(addr.IP),
  155. Port: uint16(addr.Port),
  156. }
  157. num := getProtocolNumber(addr.AddrPort().Addr())
  158. return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
  159. case "tcp", "tcp4", "tcp6":
  160. addr, err := net.ResolveTCPAddr(network, address)
  161. if err != nil {
  162. return nil, err
  163. }
  164. fullAddr := tcpip.FullAddress{
  165. NIC: nicID,
  166. Addr: tcpip.AddrFromSlice(addr.IP),
  167. Port: uint16(addr.Port),
  168. }
  169. num := getProtocolNumber(addr.AddrPort().Addr())
  170. return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
  171. default:
  172. return nil, fmt.Errorf("unknown network type: %s", network)
  173. }
  174. }
  175. // Dial dials the provided address
  176. func (s *Service) Dial(network, address string) (net.Conn, error) {
  177. return s.DialContext(context.Background(), network, address)
  178. }
  179. // Listen listens on the provided address. Currently only TCP with wildcard
  180. // addresses are supported.
  181. func (s *Service) Listen(network, address string) (net.Listener, error) {
  182. if network != "tcp" && network != "tcp4" {
  183. return nil, errors.New("only tcp is supported")
  184. }
  185. addr, err := net.ResolveTCPAddr(network, address)
  186. if err != nil {
  187. return nil, err
  188. }
  189. if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
  190. return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
  191. }
  192. if addr.Port == 0 {
  193. return nil, errors.New("specific port required, got 0")
  194. }
  195. if addr.Port < 0 || addr.Port >= math.MaxUint16 {
  196. return nil, fmt.Errorf("invalid port %d", addr.Port)
  197. }
  198. port := uint16(addr.Port)
  199. l := &tcpListener{
  200. port: port,
  201. s: s,
  202. addr: addr,
  203. accept: make(chan net.Conn),
  204. }
  205. s.mu.Lock()
  206. defer s.mu.Unlock()
  207. if _, ok := s.mu.listeners[port]; ok {
  208. return nil, fmt.Errorf("already listening on port %d", port)
  209. }
  210. s.mu.listeners[port] = l
  211. return l, nil
  212. }
  213. func (s *Service) Wait() error {
  214. return s.eg.Wait()
  215. }
  216. func (s *Service) Close() error {
  217. s.control.Stop()
  218. return nil
  219. }
  220. func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
  221. endpointID := r.ID()
  222. s.mu.Lock()
  223. defer s.mu.Unlock()
  224. l, ok := s.mu.listeners[endpointID.LocalPort]
  225. if !ok {
  226. r.Complete(true)
  227. return
  228. }
  229. var wq waiter.Queue
  230. ep, err := r.CreateEndpoint(&wq)
  231. if err != nil {
  232. log.Printf("got error creating endpoint %q", err)
  233. r.Complete(true)
  234. return
  235. }
  236. r.Complete(false)
  237. ep.SocketOptions().SetKeepAlive(true)
  238. conn := gonet.NewTCPConn(&wq, ep)
  239. l.accept <- conn
  240. }