service.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package service
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "os"
  11. "strings"
  12. "sync"
  13. "github.com/sirupsen/logrus"
  14. "github.com/slackhq/nebula"
  15. "github.com/slackhq/nebula/config"
  16. "github.com/slackhq/nebula/overlay"
  17. "golang.org/x/sync/errgroup"
  18. "gvisor.dev/gvisor/pkg/buffer"
  19. "gvisor.dev/gvisor/pkg/tcpip"
  20. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  21. "gvisor.dev/gvisor/pkg/tcpip/header"
  22. "gvisor.dev/gvisor/pkg/tcpip/link/channel"
  23. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  24. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  25. "gvisor.dev/gvisor/pkg/tcpip/stack"
  26. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  27. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  28. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  29. "gvisor.dev/gvisor/pkg/waiter"
  30. )
  // nicID is the ID of the only NIC on the gVisor stack: New creates it, and the
  // route table, the local address and outgoing dials all refer to it.
  const (
  	nicID = 1
  )
  32. type Service struct {
  33. eg *errgroup.Group
  34. control *nebula.Control
  35. ipstack *stack.Stack
  36. mu struct {
  37. sync.Mutex
  38. listeners map[uint16]*tcpListener
  39. }
  40. }
  41. func New(config *config.C) (*Service, error) {
  42. logger := logrus.New()
  43. logger.Out = os.Stdout
  44. control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
  45. if err != nil {
  46. return nil, err
  47. }
  48. control.Start()
  49. ctx := control.Context()
  50. eg, ctx := errgroup.WithContext(ctx)
  51. s := Service{
  52. eg: eg,
  53. control: control,
  54. }
  55. s.mu.listeners = map[uint16]*tcpListener{}
  56. device, ok := control.Device().(*overlay.UserDevice)
  57. if !ok {
  58. return nil, errors.New("must be using user device")
  59. }
  60. s.ipstack = stack.New(stack.Options{
  61. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  62. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  63. })
  64. sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
  65. tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
  66. if tcpipErr != nil {
  67. return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
  68. }
  69. linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
  70. if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
  71. return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
  72. }
  73. ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
  74. s.ipstack.SetRouteTable([]tcpip.Route{
  75. {
  76. Destination: ipv4Subnet,
  77. NIC: nicID,
  78. },
  79. })
  80. ipNet := device.Cidr()
  81. pa := tcpip.ProtocolAddress{
  82. AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
  83. Protocol: ipv4.ProtocolNumber,
  84. }
  85. if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
  86. PEB: stack.CanBePrimaryEndpoint, // zero value default
  87. ConfigType: stack.AddressConfigStatic, // zero value default
  88. }); err != nil {
  89. return nil, fmt.Errorf("error creating IP: %s", err)
  90. }
  91. const tcpReceiveBufferSize = 0
  92. const maxInFlightConnectionAttempts = 1024
  93. tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
  94. s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
  95. reader, writer := device.Pipe()
  96. go func() {
  97. <-ctx.Done()
  98. reader.Close()
  99. writer.Close()
  100. }()
  101. // create Goroutines to forward packets between Nebula and Gvisor
  102. eg.Go(func() error {
  103. buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
  104. for {
  105. // this will read exactly one packet
  106. n, err := reader.Read(buf)
  107. if err != nil {
  108. return err
  109. }
  110. packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
  111. Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
  112. })
  113. linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
  114. if err := ctx.Err(); err != nil {
  115. return err
  116. }
  117. }
  118. })
  119. eg.Go(func() error {
  120. for {
  121. packet := linkEP.ReadContext(ctx)
  122. if packet == nil {
  123. if err := ctx.Err(); err != nil {
  124. return err
  125. }
  126. continue
  127. }
  128. bufView := packet.ToView()
  129. if _, err := bufView.WriteTo(writer); err != nil {
  130. return err
  131. }
  132. bufView.Release()
  133. }
  134. })
  135. return &s, nil
  136. }
  137. // DialContext dials the provided address. Currently only TCP is supported.
  138. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  139. if network != "tcp" && network != "tcp4" {
  140. return nil, errors.New("only tcp is supported")
  141. }
  142. addr, err := net.ResolveTCPAddr(network, address)
  143. if err != nil {
  144. return nil, err
  145. }
  146. fullAddr := tcpip.FullAddress{
  147. NIC: nicID,
  148. Addr: tcpip.AddrFromSlice(addr.IP),
  149. Port: uint16(addr.Port),
  150. }
  151. return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
  152. }
  153. // Listen listens on the provided address. Currently only TCP with wildcard
  154. // addresses are supported.
  155. func (s *Service) Listen(network, address string) (net.Listener, error) {
  156. if network != "tcp" && network != "tcp4" {
  157. return nil, errors.New("only tcp is supported")
  158. }
  159. addr, err := net.ResolveTCPAddr(network, address)
  160. if err != nil {
  161. return nil, err
  162. }
  163. if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
  164. return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
  165. }
  166. if addr.Port == 0 {
  167. return nil, errors.New("specific port required, got 0")
  168. }
  169. if addr.Port < 0 || addr.Port >= math.MaxUint16 {
  170. return nil, fmt.Errorf("invalid port %d", addr.Port)
  171. }
  172. port := uint16(addr.Port)
  173. l := &tcpListener{
  174. port: port,
  175. s: s,
  176. addr: addr,
  177. accept: make(chan net.Conn),
  178. }
  179. s.mu.Lock()
  180. defer s.mu.Unlock()
  181. if _, ok := s.mu.listeners[port]; ok {
  182. return nil, fmt.Errorf("already listening on port %d", port)
  183. }
  184. s.mu.listeners[port] = l
  185. return l, nil
  186. }
  187. func (s *Service) Wait() error {
  188. return s.eg.Wait()
  189. }
  190. func (s *Service) Close() error {
  191. s.control.Stop()
  192. return nil
  193. }
  194. func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
  195. endpointID := r.ID()
  196. s.mu.Lock()
  197. defer s.mu.Unlock()
  198. l, ok := s.mu.listeners[endpointID.LocalPort]
  199. if !ok {
  200. r.Complete(true)
  201. return
  202. }
  203. var wq waiter.Queue
  204. ep, err := r.CreateEndpoint(&wq)
  205. if err != nil {
  206. log.Printf("got error creating endpoint %q", err)
  207. r.Complete(true)
  208. return
  209. }
  210. r.Complete(false)
  211. ep.SocketOptions().SetKeepAlive(true)
  212. conn := gonet.NewTCPConn(&wq, ep)
  213. l.accept <- conn
  214. }