2
0
Эх сурвалжийг харах

Support UDP dialling with gvisor (#1181)

Jack Doan 11 сар өмнө
parent
commit
3dc56e1184

+ 18 - 9
examples/go_service/main.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"fmt"
 	"log"
+	"net"
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/service"
@@ -54,16 +55,16 @@ pki:
   cert: /home/rice/Developer/nebula-config/app.crt
   key: /home/rice/Developer/nebula-config/app.key
 `
-	var config config.C
-	if err := config.LoadString(configStr); err != nil {
+	var cfg config.C
+	if err := cfg.LoadString(configStr); err != nil {
 		return err
 	}
-	service, err := service.New(&config)
+	svc, err := service.New(&cfg)
 	if err != nil {
 		return err
 	}
 
-	ln, err := service.Listen("tcp", ":1234")
+	ln, err := svc.Listen("tcp", ":1234")
 	if err != nil {
 		return err
 	}
@@ -73,16 +74,24 @@ pki:
 			log.Printf("accept error: %s", err)
 			break
 		}
-		defer conn.Close()
+		defer func(conn net.Conn) {
+			_ = conn.Close()
+		}(conn)
 
 		log.Printf("got connection")
 
-		conn.Write([]byte("hello world\n"))
+		_, err = conn.Write([]byte("hello world\n"))
+		if err != nil {
+			log.Printf("write error: %s", err)
+		}
 
 		scanner := bufio.NewScanner(conn)
 		for scanner.Scan() {
 			message := scanner.Text()
-			fmt.Fprintf(conn, "echo: %q\n", message)
+			_, err = fmt.Fprintf(conn, "echo: %q\n", message)
+			if err != nil {
+				log.Printf("write error: %s", err)
+			}
 			log.Printf("got message %q", message)
 		}
 
@@ -92,8 +101,8 @@ pki:
 		}
 	}
 
-	service.Close()
-	if err := service.Wait(); err != nil {
+	_ = svc.Close()
+	if err := svc.Wait(); err != nil {
 		return err
 	}
 	return nil

+ 39 - 14
service/service.go

@@ -8,6 +8,7 @@ import (
 	"log"
 	"math"
 	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"sync"
@@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) {
 	return &s, nil
 }
 
-// DialContext dials the provided address. Currently only TCP is supported.
-func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
-	if network != "tcp" && network != "tcp4" {
-		return nil, errors.New("only tcp is supported")
-	}
-
-	addr, err := net.ResolveTCPAddr(network, address)
-	if err != nil {
-		return nil, err
+func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
+	if addr.Is6() {
+		return ipv6.ProtocolNumber
 	}
+	return ipv4.ProtocolNumber
+}
 
-	fullAddr := tcpip.FullAddress{
-		NIC:  nicID,
-		Addr: tcpip.AddrFromSlice(addr.IP),
-		Port: uint16(addr.Port),
+// DialContext dials the provided address.
+func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	switch network {
+	case "udp", "udp4", "udp6":
+		addr, err := net.ResolveUDPAddr(network, address)
+		if err != nil {
+			return nil, err
+		}
+		fullAddr := tcpip.FullAddress{
+			NIC:  nicID,
+			Addr: tcpip.AddrFromSlice(addr.IP),
+			Port: uint16(addr.Port),
+		}
+		num := getProtocolNumber(addr.AddrPort().Addr())
+		return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
+	case "tcp", "tcp4", "tcp6":
+		addr, err := net.ResolveTCPAddr(network, address)
+		if err != nil {
+			return nil, err
+		}
+		fullAddr := tcpip.FullAddress{
+			NIC:  nicID,
+			Addr: tcpip.AddrFromSlice(addr.IP),
+			Port: uint16(addr.Port),
+		}
+		num := getProtocolNumber(addr.AddrPort().Addr())
+		return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
+	default:
+		return nil, fmt.Errorf("unknown network type: %s", network)
 	}
+}
 
-	return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
+// Dial dials the provided address
+func (s *Service) Dial(network, address string) (net.Conn, error) {
+	return s.DialContext(context.Background(), network, address)
 }
 
 // Listen listens on the provided address. Currently only TCP with wildcard