Parcourir la source

:rocket: Allow to blacklist cidrs/peers

Ettore Di Giacinto il y a 3 ans
Parent
commit
f63c1da979
6 fichiers modifiés avec 76 ajouts et 11 suppressions
  1. 6 0
      cmd/util.go
  2. 2 0
      pkg/node/config.go
  3. 26 6
      pkg/node/connection.go
  4. 3 2
      pkg/node/node.go
  5. 32 3
      pkg/node/node_test.go
  6. 7 0
      pkg/node/options.go

+ 6 - 0
cmd/util.go

@@ -160,6 +160,11 @@ var CommonFlags []cli.Flag = []cli.Flag{
 		Usage:  "List of discovery peers to use",
 		Usage:  "List of discovery peers to use",
 		EnvVar: "EDGEVPNBOOTSTRAPPEERS",
 		EnvVar: "EDGEVPNBOOTSTRAPPEERS",
 	},
 	},
+	&cli.StringSliceFlag{
+		Name:   "blacklist",
+		Usage:  "List of peers/cidr to gate",
+		EnvVar: "EDGEVPNBLACKLIST",
+	},
 	&cli.StringFlag{
 	&cli.StringFlag{
 		Name:   "token",
 		Name:   "token",
 		Usage:  "Specify an edgevpn token in place of a config file",
 		Usage:  "Specify an edgevpn token in place of a config file",
@@ -216,6 +221,7 @@ func cliToOpts(c *cli.Context) ([]node.Option, []vpn.Option, *logger.Logger) {
 		node.WithLedgerInterval(time.Duration(c.Int("ledger-syncronization-interval")) * time.Second),
 		node.WithLedgerInterval(time.Duration(c.Int("ledger-syncronization-interval")) * time.Second),
 		node.Logger(llger),
 		node.Logger(llger),
 		node.WithDiscoveryBootstrapPeers(addrsList),
 		node.WithDiscoveryBootstrapPeers(addrsList),
+		node.WithBlacklist(c.StringSlice("blacklist")...),
 		node.LibP2PLogLevel(libp2plvl),
 		node.LibP2PLogLevel(libp2plvl),
 		node.WithInterfaceAddress(address),
 		node.WithInterfaceAddress(address),
 		node.FromBase64(mDNS, dht, token),
 		node.FromBase64(mDNS, dht, token),

+ 2 - 0
pkg/node/config.go

@@ -66,6 +66,8 @@ type Config struct {
 
 
 	DiscoveryInterval, LedgerSyncronizationTime, LedgerAnnounceTime time.Duration
 	DiscoveryInterval, LedgerSyncronizationTime, LedgerAnnounceTime time.Duration
 	DiscoveryBootstrapPeers                                         discovery.AddrList
 	DiscoveryBootstrapPeers                                         discovery.AddrList
+
+	Whitelist, Blacklist []string
 }
 }
 
 
 // NetworkService is a service running over the network. It takes a context, a node and a ledger
 // NetworkService is a service running over the network. It takes a context, a node and a ledger

+ 26 - 6
pkg/node/connection.go

@@ -25,6 +25,7 @@ import (
 	"github.com/libp2p/go-libp2p"
 	"github.com/libp2p/go-libp2p"
 	"github.com/libp2p/go-libp2p-core/crypto"
 	"github.com/libp2p/go-libp2p-core/crypto"
 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/host"
+	"github.com/libp2p/go-libp2p-core/peer"
 	conngater "github.com/libp2p/go-libp2p/p2p/net/conngater"
 	conngater "github.com/libp2p/go-libp2p/p2p/net/conngater"
 	hub "github.com/mudler/edgevpn/pkg/hub"
 	hub "github.com/mudler/edgevpn/pkg/hub"
 	multiaddr "github.com/multiformats/go-multiaddr"
 	multiaddr "github.com/multiformats/go-multiaddr"
@@ -36,6 +37,11 @@ func (e *Node) Host() host.Host {
 	return e.host
 	return e.host
 }
 }
 
 
+// ConnectionGater returns the underlying libp2p conngater
+func (e *Node) ConnectionGater() *conngater.BasicConnectionGater {
+	return e.cg
+}
+
 func (e *Node) genHost(ctx context.Context) (host.Host, error) {
 func (e *Node) genHost(ctx context.Context) (host.Host, error) {
 	var r io.Reader
 	var r io.Reader
 	if e.seed == 0 {
 	if e.seed == 0 {
@@ -51,23 +57,37 @@ func (e *Node) genHost(ctx context.Context) (host.Host, error) {
 
 
 	opts := e.config.Options
 	opts := e.config.Options
 
 
+	cg, err := conngater.NewBasicConnectionGater(nil)
+	if err != nil {
+		return nil, err
+	}
+
+	e.cg = cg
+
 	if e.config.InterfaceAddress != "" {
 	if e.config.InterfaceAddress != "" {
 		// Avoid to loopback traffic by trying to connect to nodes in via VPN
 		// Avoid to loopback traffic by trying to connect to nodes in via VPN
 		_, vpnNetwork, err := net.ParseCIDR(e.config.InterfaceAddress)
 		_, vpnNetwork, err := net.ParseCIDR(e.config.InterfaceAddress)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		cg, err := conngater.NewBasicConnectionGater(nil)
-		if err != nil {
-			return nil, err
-		}
+
 		if err := cg.BlockSubnet(vpnNetwork); err != nil {
 		if err := cg.BlockSubnet(vpnNetwork); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		opts = append(opts, libp2p.ConnectionGater(cg))
 	}
 	}
 
 
-	opts = append(opts, libp2p.Identity(prvKey))
+	for _, b := range e.config.Blacklist {
+		_, net, err := net.ParseCIDR(b)
+		if err != nil {
+			// Assume it's a peerID
+			cg.BlockPeer(peer.ID(b))
+		}
+		if net != nil {
+			cg.BlockSubnet(net)
+		}
+	}
+
+	opts = append(opts, libp2p.ConnectionGater(cg), libp2p.Identity(prvKey))
 
 
 	addrs := []multiaddr.Multiaddr{}
 	addrs := []multiaddr.Multiaddr{}
 	for _, l := range e.config.ListenAddresses {
 	for _, l := range e.config.ListenAddresses {

+ 3 - 2
pkg/node/node.go

@@ -23,6 +23,7 @@ import (
 	"github.com/libp2p/go-libp2p"
 	"github.com/libp2p/go-libp2p"
 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/network"
 	"github.com/libp2p/go-libp2p-core/network"
+	"github.com/libp2p/go-libp2p/p2p/net/conngater"
 
 
 	protocol "github.com/mudler/edgevpn/pkg/protocol"
 	protocol "github.com/mudler/edgevpn/pkg/protocol"
 
 
@@ -38,8 +39,8 @@ type Node struct {
 	inputCh chan *hub.Message
 	inputCh chan *hub.Message
 	seed    int64
 	seed    int64
 	host    host.Host
 	host    host.Host
-
-	ledger *blockchain.Ledger
+	cg      *conngater.BasicConnectionGater
+	ledger  *blockchain.Ledger
 }
 }
 
 
 var defaultLibp2pOptions = []libp2p.Option{
 var defaultLibp2pOptions = []libp2p.Option{

+ 32 - 3
pkg/node/node_test.go

@@ -34,13 +34,14 @@ var _ = Describe("Node", func() {
 
 
 	l := Logger(logger.New(log.LevelFatal))
 	l := Logger(logger.New(log.LevelFatal))
 
 
-	e := New(FromBase64(true, true, token), WithStore(&blockchain.MemoryStore{}), l)
-	e2 := New(FromBase64(true, true, token), WithStore(&blockchain.MemoryStore{}), l)
-
 	Context("Connection", func() {
 	Context("Connection", func() {
 		It("see each other node ID", func() {
 		It("see each other node ID", func() {
 			ctx, cancel := context.WithCancel(context.Background())
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
 			defer cancel()
+
+			e := New(FromBase64(true, true, token), WithStore(&blockchain.MemoryStore{}), l)
+			e2 := New(FromBase64(true, true, token), WithStore(&blockchain.MemoryStore{}), l)
+
 			e.Start(ctx)
 			e.Start(ctx)
 			e2.Start(ctx)
 			e2.Start(ctx)
 
 
@@ -48,5 +49,33 @@ var _ = Describe("Node", func() {
 				return e.Host().Network().Peers()
 				return e.Host().Network().Peers()
 			}, 100*time.Second, 1*time.Second).Should(ContainElement(e2.Host().ID()))
 			}, 100*time.Second, 1*time.Second).Should(ContainElement(e2.Host().ID()))
 		})
 		})
+
+	})
+
+	Context("connection gater", func() {
+		It("blacklists", func() {
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			e := New(
+				WithBlacklist("1.1.1.1/32", "1.1.1.0/24"),
+				FromBase64(true, true, token),
+				WithStore(&blockchain.MemoryStore{}),
+				l,
+			)
+
+			e.Start(ctx)
+			addrs := e.ConnectionGater().ListBlockedAddrs()
+			peers := e.ConnectionGater().ListBlockedPeers()
+			subs := e.ConnectionGater().ListBlockedSubnets()
+			Expect(len(addrs)).To(Equal(0))
+			Expect(len(peers)).To(Equal(0))
+			Expect(len(subs)).To(Equal(2))
+
+			ips := []string{}
+			for _, s := range subs {
+				ips = append(ips, s.String())
+			}
+			Expect(ips).To(ContainElements("1.1.1.1/32", "1.1.1.0/24"))
+		})
 	})
 	})
 })
 })

+ 7 - 0
pkg/node/options.go

@@ -60,6 +60,13 @@ func WithInterfaceAddress(i string) func(cfg *Config) error {
 	}
 	}
 }
 }
 
 
+func WithBlacklist(i ...string) func(cfg *Config) error {
+	return func(cfg *Config) error {
+		cfg.Blacklist = i
+		return nil
+	}
+}
+
 func Logger(l log.StandardLogger) func(cfg *Config) error {
 func Logger(l log.StandardLogger) func(cfg *Config) error {
 	return func(cfg *Config) error {
 	return func(cfg *Config) error {
 		cfg.Logger = l
 		cfg.Logger = l