Browse Source

:gear: Add context to services

Ettore Di Giacinto 3 years ago
parent
commit
6d81d85b3a
5 changed files with 110 additions and 103 deletions
  1. 1 1
      README.md
  2. 2 2
      cmd/file.go
  3. 2 2
      cmd/service.go
  4. 53 52
      pkg/services/files.go
  5. 52 46
      pkg/services/services.go

+ 1 - 1
README.md

@@ -273,7 +273,7 @@ EdgeVPN can be used as a library. It is very portable and offers a functional in
 ```golang
 
 import (
-    edgevpn "github.com/mudler/edgevpn/pkg/edgevpn"
+    edgevpn "github.com/mudler/edgevpn/pkg/node"
 )
 
 e := edgevpn.New(edgevpn.Logger(l),

+ 2 - 2
cmd/file.go

@@ -80,7 +80,7 @@ This is also the ID used to refer when receiving it.`,
 				return err
 			}
 
-			services.SendFile(ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, path)
+			services.SendFile(context.Background(), ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, path)
 
 			// Start the node to the network, using our ledger
 			if err := e.Start(context.Background()); err != nil {
@@ -127,7 +127,7 @@ func FileReceive() cli.Command {
 
 			ledger, _ := e.Ledger()
 
-			return services.ReceiveFile(ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, path)
+			return services.ReceiveFile(context.Background(), ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, path)
 		},
 	}
 }

+ 2 - 2
cmd/service.go

@@ -79,7 +79,7 @@ For example, '192.168.1.1:80', or '127.0.0.1:22'.`,
 				return err
 			}
 
-			services.ExposeService(ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, address)
+			services.ExposeService(context.Background(), ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, address)
 
 			// Join the node to the network, using our ledger
 			if err := e.Start(context.Background()); err != nil {
@@ -128,7 +128,7 @@ to the service over the network`,
 			}
 
 			ledger, _ := e.Ledger()
-			return services.ConnectToService(ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, address)
+			return services.ConnectToService(context.Background(), ledger, e, e.Logger(), time.Duration(c.Int("ledger-announce-interval"))*time.Second, name, address)
 		},
 	}
 }

+ 53 - 52
pkg/services/files.go

@@ -31,13 +31,13 @@ import (
 	"github.com/pkg/errors"
 )
 
-func SendFile(ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, fileID, filepath string) error {
+func SendFile(ctx context.Context, ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, fileID, filepath string) error {
 
 	l.Infof("Serving '%s' as '%s'", filepath, fileID)
 
 	// By announcing periodically our service to the blockchain
 	ledger.Announce(
-		context.Background(),
+		ctx,
 		announcetime,
 		func() {
 			// Retrieve current ID for ip in the blockchain
@@ -88,11 +88,10 @@ func SendFile(ledger *blockchain.Ledger, node types.Node, l log.StandardLogger,
 	return nil
 }
 
-func ReceiveFile(ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, fileID string, path string) error {
-
+func ReceiveFile(ctx context.Context, ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, fileID string, path string) error {
 	// Announce ourselves so nodes accepts our connection
 	ledger.Announce(
-		context.Background(),
+		ctx,
 		announcetime,
 		func() {
 			// Retrieve current ID for ip in the blockchain
@@ -108,60 +107,62 @@ func ReceiveFile(ledger *blockchain.Ledger, node types.Node, l log.StandardLogge
 			}
 		},
 	)
-	for {
-		time.Sleep(5 * time.Second)
 
-		l.Debug("Attempting to find file in the blockchain")
-
-		_, found := ledger.GetKey(protocol.UsersLedgerKey, node.Host().ID().String())
-		if !found {
-			continue
-		}
-		existingValue, found := ledger.GetKey(protocol.FilesLedgerKey, fileID)
-		fi := &types.File{}
-		existingValue.Unmarshal(fi)
-		// If mismatch, update the blockchain
-		if !found {
-			l.Debug("file not found on blockchain, retrying in 5 seconds")
-			continue
-		} else {
-			break
-		}
-	}
-	// Listen for an incoming connection.
+	for {
+		select {
+		case <-ctx.Done():
+			return errors.New("context canceled")
+		default:
+			time.Sleep(5 * time.Second)
 
-	// Retrieve current ID for ip in the blockchain
-	existingValue, found := ledger.GetKey(protocol.FilesLedgerKey, fileID)
-	fi := &types.File{}
-	existingValue.Unmarshal(fi)
+			l.Debug("Attempting to find file in the blockchain")
 
-	// If mismatch, update the blockchain
-	if !found {
-		return errors.New("file not found")
-	}
+			_, found := ledger.GetKey(protocol.UsersLedgerKey, node.Host().ID().String())
+			if !found {
+				continue
+			}
+			existingValue, found := ledger.GetKey(protocol.FilesLedgerKey, fileID)
+			fi := &types.File{}
+			existingValue.Unmarshal(fi)
+			// If mismatch, update the blockchain
+			if !found {
+				l.Debug("file not found on blockchain, retrying in 5 seconds")
+				continue
+			} else {
+				// Retrieve current ID for ip in the blockchain
+				existingValue, found := ledger.GetKey(protocol.FilesLedgerKey, fileID)
+				fi := &types.File{}
+				existingValue.Unmarshal(fi)
+
+				// If mismatch, update the blockchain
+				if !found {
+					return errors.New("file not found")
+				}
 
-	// Decode the Peer
-	d, err := peer.Decode(fi.PeerID)
-	if err != nil {
-		return err
-	}
+				// Decode the Peer
+				d, err := peer.Decode(fi.PeerID)
+				if err != nil {
+					return err
+				}
 
-	// Open a stream
-	stream, err := node.Host().NewStream(context.Background(), d, protocol.FileProtocol.ID())
-	if err != nil {
-		return err
-	}
-	l.Infof("Saving file %s to %s", fileID, path)
+				// Open a stream
+				stream, err := node.Host().NewStream(context.Background(), d, protocol.FileProtocol.ID())
+				if err != nil {
+					return err
+				}
+				l.Infof("Saving file %s to %s", fileID, path)
 
-	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755)
-	if err != nil {
-		return err
-	}
+				f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755)
+				if err != nil {
+					return err
+				}
 
-	io.Copy(f, stream)
+				io.Copy(f, stream)
 
-	f.Close()
+				f.Close()
 
-	l.Infof("Received file %s to %s", fileID, path)
-	return nil
+				l.Infof("Received file %s to %s", fileID, path)
+			}
+		}
+	}
 }

+ 52 - 46
pkg/services/services.go

@@ -26,18 +26,19 @@ import (
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/mudler/edgevpn/pkg/blockchain"
 	protocol "github.com/mudler/edgevpn/pkg/protocol"
+	"github.com/pkg/errors"
 
 	"github.com/mudler/edgevpn/pkg/types"
 )
 
-func ExposeService(ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, serviceID, dstaddress string) {
+func ExposeService(ctx context.Context, ledger *blockchain.Ledger, node types.Node, l log.StandardLogger, announcetime time.Duration, serviceID, dstaddress string) {
 
 	l.Infof("Exposing service '%s' (%s)", serviceID, dstaddress)
 
 	// 1) Register the ServiceID <-> PeerID Association
 	// By announcing periodically our service to the blockchain
 	ledger.Announce(
-		context.Background(),
+		ctx,
 		announcetime,
 		func() {
 			// Retrieve current ID for ip in the blockchain
@@ -88,7 +89,7 @@ func ExposeService(ledger *blockchain.Ledger, node types.Node, l log.StandardLog
 	})
 }
 
-func ConnectToService(ledger *blockchain.Ledger, node types.Node, ll log.StandardLogger, announcetime time.Duration, serviceID string, srcaddr string) error {
+func ConnectToService(ctx context.Context, ledger *blockchain.Ledger, node types.Node, ll log.StandardLogger, announcetime time.Duration, serviceID string, srcaddr string) error {
 
 	// Open local port for listening
 	l, err := net.Listen("tcp", srcaddr)
@@ -99,7 +100,7 @@ func ConnectToService(ledger *blockchain.Ledger, node types.Node, ll log.Standar
 
 	// Announce ourselves so nodes accepts our connection
 	ledger.Announce(
-		context.Background(),
+		ctx,
 		announcetime,
 		func() {
 			// Retrieve current ID for ip in the blockchain
@@ -117,53 +118,58 @@ func ConnectToService(ledger *blockchain.Ledger, node types.Node, ll log.Standar
 	)
 	defer l.Close()
 	for {
-		// Listen for an incoming connection.
-		conn, err := l.Accept()
-		if err != nil {
-			ll.Error("Error accepting: ", err.Error())
-			continue
-		}
-
-		ll.Info("New connection from", l.Addr().String())
-		// Handle connections in a new goroutine, forwarding to the p2p service
-		go func() {
-			// Retrieve current ID for ip in the blockchain
-			existingValue, found := ledger.GetKey(protocol.ServicesLedgerKey, serviceID)
-			service := &types.Service{}
-			existingValue.Unmarshal(service)
-			// If mismatch, update the blockchain
-			if !found {
-				conn.Close()
-				ll.Debugf("service '%s' not found on blockchain", serviceID)
-				return
-			}
-
-			// Decode the Peer
-			d, err := peer.Decode(service.PeerID)
+		select {
+		case <-ctx.Done():
+			return errors.New("context canceled")
+		default:
+			// Listen for an incoming connection.
+			conn, err := l.Accept()
 			if err != nil {
-				conn.Close()
-				ll.Debugf("could not decode peer '%s'", service.PeerID)
-				return
+				ll.Error("Error accepting: ", err.Error())
+				continue
 			}
 
-			// Open a stream
-			stream, err := node.Host().NewStream(context.Background(), d, protocol.ServiceProtocol.ID())
-			if err != nil {
-				conn.Close()
-				ll.Debugf("could not open stream '%s'", err.Error())
-				return
-			}
-			ll.Debugf("(service %s) Redirecting", serviceID, l.Addr().String())
+			ll.Info("New connection from", l.Addr().String())
+			// Handle connections in a new goroutine, forwarding to the p2p service
+			go func() {
+				// Retrieve current ID for ip in the blockchain
+				existingValue, found := ledger.GetKey(protocol.ServicesLedgerKey, serviceID)
+				service := &types.Service{}
+				existingValue.Unmarshal(service)
+				// If mismatch, update the blockchain
+				if !found {
+					conn.Close()
+					ll.Debugf("service '%s' not found on blockchain", serviceID)
+					return
+				}
 
-			closer := make(chan struct{}, 2)
-			go copyStream(closer, stream, conn)
-			go copyStream(closer, conn, stream)
-			<-closer
+				// Decode the Peer
+				d, err := peer.Decode(service.PeerID)
+				if err != nil {
+					conn.Close()
+					ll.Debugf("could not decode peer '%s'", service.PeerID)
+					return
+				}
 
-			stream.Close()
-			conn.Close()
-			ll.Infof("(service %s) Done handling %s", serviceID, l.Addr().String())
-		}()
+				// Open a stream
+				stream, err := node.Host().NewStream(context.Background(), d, protocol.ServiceProtocol.ID())
+				if err != nil {
+					conn.Close()
+					ll.Debugf("could not open stream '%s'", err.Error())
+					return
+				}
+				ll.Debugf("(service %s) Redirecting", serviceID, l.Addr().String())
+
+				closer := make(chan struct{}, 2)
+				go copyStream(closer, stream, conn)
+				go copyStream(closer, conn, stream)
+				<-closer
+
+				stream.Close()
+				conn.Close()
+				ll.Infof("(service %s) Done handling %s", serviceID, l.Addr().String())
+			}()
+		}
 	}
 }