浏览代码

:gear: Construct from client the socket dialer

Ettore Di Giacinto 3 年之前
父节点
当前提交
c27cda6965
共有 2 个文件被更改,包括 15 次插入9 次删除
  1. 1 9
      api/api_test.go
  2. 14 0
      api/client/client.go

+ 1 - 9
api/api_test.go

@@ -19,8 +19,6 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
-	"net"
-	"net/http"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"time"
 	"time"
@@ -44,13 +42,7 @@ var _ = Describe("API", func() {
 			os.MkdirAll(d, os.ModePerm)
 			os.MkdirAll(d, os.ModePerm)
 			socket := filepath.Join(d, "socket")
 			socket := filepath.Join(d, "socket")
 
 
-			c := client.NewClient(client.WithHost("http://unix"), client.WithHTTPClient(&http.Client{
-				Transport: &http.Transport{
-					DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
-						return net.Dial("unix", socket)
-					},
-				},
-			}))
+			c := client.NewClient(client.WithHost("unix://" + socket))
 
 
 			token := node.GenerateNewConnectionData().Base64()
 			token := node.GenerateNewConnectionData().Base64()
 			ctx, cancel := context.WithCancel(context.Background())
 			ctx, cancel := context.WithCancel(context.Background())

+ 14 - 0
api/client/client.go

@@ -16,11 +16,14 @@
 package client
 package client
 
 
 import (
 import (
+	"context"
 	"encoding/base64"
 	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"net/http"
+	"strings"
 	"time"
 	"time"
 
 
 	"github.com/mudler/edgevpn/pkg/blockchain"
 	"github.com/mudler/edgevpn/pkg/blockchain"
@@ -47,6 +50,17 @@ const (
 func WithHost(host string) func(c *Client) error {
 func WithHost(host string) func(c *Client) error {
 	return func(c *Client) error {
 	return func(c *Client) error {
 		c.host = host
 		c.host = host
+		if strings.HasPrefix(host, "unix://") {
+			socket := strings.ReplaceAll(host, "unix://", "")
+			c.host = "http://unix"
+			c.httpClient = &http.Client{
+				Transport: &http.Transport{
+					DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+						return net.Dial("unix", socket)
+					},
+				},
+			}
+		}
 		return nil
 		return nil
 	}
 	}
 }
 }