123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- package service
- import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "log"
- "math"
- "net"
- "os"
- "strings"
- "sync"
- "github.com/sirupsen/logrus"
- "github.com/slackhq/nebula"
- "github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/overlay"
- "golang.org/x/sync/errgroup"
- "gvisor.dev/gvisor/pkg/buffer"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
- )
// nicID is the identifier of the single NIC created on the gVisor stack.
// All routes and dial addresses in this package refer to it.
const nicID = 1
- type Service struct {
- eg *errgroup.Group
- control *nebula.Control
- ipstack *stack.Stack
- mu struct {
- sync.Mutex
- listeners map[uint16]*tcpListener
- }
- }
- func New(config *config.C) (*Service, error) {
- logger := logrus.New()
- logger.Out = os.Stdout
- control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
- if err != nil {
- return nil, err
- }
- control.Start()
- ctx := control.Context()
- eg, ctx := errgroup.WithContext(ctx)
- s := Service{
- eg: eg,
- control: control,
- }
- s.mu.listeners = map[uint16]*tcpListener{}
- device, ok := control.Device().(*overlay.UserDevice)
- if !ok {
- return nil, errors.New("must be using user device")
- }
- s.ipstack = stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
- })
- sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
- tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
- if tcpipErr != nil {
- return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
- }
- linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
- if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
- return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
- }
- ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
- s.ipstack.SetRouteTable([]tcpip.Route{
- {
- Destination: ipv4Subnet,
- NIC: nicID,
- },
- })
- ipNet := device.Cidr()
- pa := tcpip.ProtocolAddress{
- AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
- Protocol: ipv4.ProtocolNumber,
- }
- if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
- PEB: stack.CanBePrimaryEndpoint, // zero value default
- ConfigType: stack.AddressConfigStatic, // zero value default
- }); err != nil {
- return nil, fmt.Errorf("error creating IP: %s", err)
- }
- const tcpReceiveBufferSize = 0
- const maxInFlightConnectionAttempts = 1024
- tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
- s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
- reader, writer := device.Pipe()
- go func() {
- <-ctx.Done()
- reader.Close()
- writer.Close()
- }()
- // create Goroutines to forward packets between Nebula and Gvisor
- eg.Go(func() error {
- buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
- for {
- // this will read exactly one packet
- n, err := reader.Read(buf)
- if err != nil {
- return err
- }
- packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
- })
- linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
- if err := ctx.Err(); err != nil {
- return err
- }
- }
- })
- eg.Go(func() error {
- for {
- packet := linkEP.ReadContext(ctx)
- if packet == nil {
- if err := ctx.Err(); err != nil {
- return err
- }
- continue
- }
- bufView := packet.ToView()
- if _, err := bufView.WriteTo(writer); err != nil {
- return err
- }
- bufView.Release()
- }
- })
- return &s, nil
- }
- // DialContext dials the provided address. Currently only TCP is supported.
- func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
- if network != "tcp" && network != "tcp4" {
- return nil, errors.New("only tcp is supported")
- }
- addr, err := net.ResolveTCPAddr(network, address)
- if err != nil {
- return nil, err
- }
- fullAddr := tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.AddrFromSlice(addr.IP),
- Port: uint16(addr.Port),
- }
- return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
- }
- // Listen listens on the provided address. Currently only TCP with wildcard
- // addresses are supported.
- func (s *Service) Listen(network, address string) (net.Listener, error) {
- if network != "tcp" && network != "tcp4" {
- return nil, errors.New("only tcp is supported")
- }
- addr, err := net.ResolveTCPAddr(network, address)
- if err != nil {
- return nil, err
- }
- if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
- return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
- }
- if addr.Port == 0 {
- return nil, errors.New("specific port required, got 0")
- }
- if addr.Port < 0 || addr.Port >= math.MaxUint16 {
- return nil, fmt.Errorf("invalid port %d", addr.Port)
- }
- port := uint16(addr.Port)
- l := &tcpListener{
- port: port,
- s: s,
- addr: addr,
- accept: make(chan net.Conn),
- }
- s.mu.Lock()
- defer s.mu.Unlock()
- if _, ok := s.mu.listeners[port]; ok {
- return nil, fmt.Errorf("already listening on port %d", port)
- }
- s.mu.listeners[port] = l
- return l, nil
- }
- func (s *Service) Wait() error {
- return s.eg.Wait()
- }
- func (s *Service) Close() error {
- s.control.Stop()
- return nil
- }
- func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
- endpointID := r.ID()
- s.mu.Lock()
- defer s.mu.Unlock()
- l, ok := s.mu.listeners[endpointID.LocalPort]
- if !ok {
- r.Complete(true)
- return
- }
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- log.Printf("got error creating endpoint %q", err)
- r.Complete(true)
- return
- }
- r.Complete(false)
- ep.SocketOptions().SetKeepAlive(true)
- conn := gonet.NewTCPConn(&wq, ep)
- l.accept <- conn
- }
|