Browse Source

All that cgo crap compiles!

Adam Ierymenko 5 years ago
parent
commit
ccc9be2d4d
3 changed files with 107 additions and 39 deletions
  1. 12 0
      go/pkg/zerotier/multicastgroup.go
  2. 60 11
      go/pkg/zerotier/network.go
  3. 35 28
      go/pkg/zerotier/node.go

+ 12 - 0
go/pkg/zerotier/multicastgroup.go

@@ -18,3 +18,15 @@ type MulticastGroup struct {
 	MAC MAC
 	ADI uint32
 }
+
+// Less returns true if this MulticastGroup is less than another.
+func (mg *MulticastGroup) Less(mg2 *MulticastGroup) bool {
+	return (mg.MAC < mg2.MAC || (mg.MAC == mg2.MAC && mg.ADI < mg2.ADI))
+}
+
+// key returns an array usable as a key for a map[]
+func (mg *MulticastGroup) key() (k [2]uint64) {
+	k[0] = uint64(mg.MAC)
+	k[1] = uint64(mg.ADI)
+	return
+}

+ 60 - 11
go/pkg/zerotier/network.go

@@ -17,6 +17,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"net"
+	"sort"
 	"strconv"
 	"sync"
 )
@@ -111,18 +112,22 @@ type NetworkLocalSettings struct {
 
 // Network is a currently joined network
 type Network struct {
-	id         NetworkID
-	tap        Tap
-	config     NetworkConfig
-	settings   NetworkLocalSettings // locked by configLock
-	configLock sync.RWMutex
+	node                       *Node
+	id                         NetworkID
+	tap                        Tap
+	config                     NetworkConfig
+	settings                   NetworkLocalSettings // locked by configLock
+	multicastSubscriptions     map[[2]uint64]*MulticastGroup
+	configLock                 sync.RWMutex
+	multicastSubscriptionsLock sync.RWMutex
 }
 
-// NewNetwork creates a new network with default settings
-func NewNetwork(id NetworkID, t Tap) (*Network, error) {
-	return &Network{
-		id:  id,
-		tap: t,
+// newNetwork creates a new network with default settings
+func newNetwork(node *Node, id NetworkID, t Tap) (*Network, error) {
+	n := &Network{
+		node: node,
+		id:   id,
+		tap:  t,
 		config: NetworkConfig{
 			ID:     id,
 			Status: NetworkStatusRequestConfiguration,
@@ -134,7 +139,18 @@ func NewNetwork(id NetworkID, t Tap) (*Network, error) {
 			AllowGlobalRoutes:         false,
 			AllowDefaultRouteOverride: false,
 		},
-	}, nil
+		multicastSubscriptions: make(map[[2]uint64]*MulticastGroup),
+	}
+
+	t.AddMulticastGroupChangeHandler(func(added bool, mg *MulticastGroup) {
+		if added {
+			n.MulticastSubscribe(mg)
+		} else {
+			n.MulticastUnsubscribe(mg)
+		}
+	})
+
+	return n, nil
 }
 
 // ID gets this network's unique ID
@@ -153,6 +169,39 @@ func (n *Network) Tap() Tap { return n.tap }
 // SetLocalSettings modifies this network's local settings
 func (n *Network) SetLocalSettings(ls *NetworkLocalSettings) { n.updateConfig(nil, ls) }
 
+// MulticastSubscribe subscribes to a multicast group
+func (n *Network) MulticastSubscribe(mg *MulticastGroup) {
+	k := mg.key()
+	n.multicastSubscriptionsLock.Lock()
+	if _, have := n.multicastSubscriptions[k]; have {
+		n.multicastSubscriptionsLock.Unlock()
+		return
+	}
+	n.multicastSubscriptions[k] = mg
+	n.multicastSubscriptionsLock.Unlock()
+	n.node.multicastSubscribe(uint64(n.id), mg)
+}
+
+// MulticastUnsubscribe removes a subscription to a multicast group
+func (n *Network) MulticastUnsubscribe(mg *MulticastGroup) {
+	n.multicastSubscriptionsLock.Lock()
+	delete(n.multicastSubscriptions, mg.key())
+	n.multicastSubscriptionsLock.Unlock()
+	n.node.multicastUnsubscribe(uint64(n.id), mg)
+}
+
+// MulticastSubscriptions returns an array of all multicast subscriptions for this network
+func (n *Network) MulticastSubscriptions() []*MulticastGroup {
+	n.multicastSubscriptionsLock.RLock()
+	mgs := make([]*MulticastGroup, 0, len(n.multicastSubscriptions))
+	for _, mg := range n.multicastSubscriptions {
+		mgs = append(mgs, mg)
+	}
+	n.multicastSubscriptionsLock.RUnlock()
+	sort.Slice(mgs, func(a, b int) bool { return mgs[a].Less(mgs[b]) })
+	return mgs
+}
+
 func (n *Network) networkConfigRevision() uint64 {
 	n.configLock.RLock()
 	defer n.configLock.RUnlock()

+ 35 - 28
go/pkg/zerotier/node.go

@@ -40,8 +40,6 @@ const (
 	NetworkStatusOK                   int = C.ZT_NETWORK_STATUS_OK
 	NetworkStatusAccessDenied         int = C.ZT_NETWORK_STATUS_ACCESS_DENIED
 	NetworkStatusNotFound             int = C.ZT_NETWORK_STATUS_NOT_FOUND
-	NetworkStatusPortError            int = C.ZT_NETWORK_STATUS_PORT_ERROR
-	NetworkStatusClientTooOld         int = C.ZT_NETWORK_STATUS_CLIENT_TOO_OLD
 
 	NetworkTypePrivate int = C.ZT_NETWORK_TYPE_PRIVATE
 	NetworkTypePublic  int = C.ZT_NETWORK_TYPE_PUBLIC
@@ -200,7 +198,7 @@ func (n *Node) Join(nwid uint64, tap Tap) (*Network, error) {
 		return nil, ErrTapInitFailed
 	}
 
-	nw, err := NewNetwork(NetworkID(nwid), &nativeTap{tap: unsafe.Pointer(ntap), enabled: 1})
+	nw, err := newNetwork(n, NetworkID(nwid), &nativeTap{tap: unsafe.Pointer(ntap), enabled: 1})
 	if err != nil {
 		C.ZT_GoNode_leave(n.gn, C.uint64_t(nwid))
 		return nil, err
@@ -224,19 +222,19 @@ func (n *Node) Leave(nwid uint64) error {
 // AddStaticRoot adds a statically defined root server to this node.
 // If a static root with the given identity already exists this will update its IP and port information.
 func (n *Node) AddStaticRoot(id *Identity, addrs []net.Addr) {
-	var saddrs []*C.struct_sockaddr_storage
+	var saddrs []C.struct_sockaddr_storage
 	for _, a := range addrs {
 		aa, _ := a.(*net.UDPAddr)
 		if aa != nil {
 			ss := new(C.struct_sockaddr_storage)
 			if makeSockaddrStorage(aa.IP, aa.Port, ss) {
-				saddrs = append(saddrs, ss)
+				saddrs = append(saddrs, *ss)
 			}
 		}
 	}
 	if len(saddrs) > 0 {
 		ids := C.CString(id.String())
-		C.ZT_Node_setStaticRoot(n.zn, ids, &saddrs[0], C.uint(len(saddrs)))
+		C.ZT_Node_setStaticRoot(unsafe.Pointer(n.zn), ids, &saddrs[0], C.uint(len(saddrs)))
 		C.free(unsafe.Pointer(ids))
 	}
 }
@@ -244,7 +242,7 @@ func (n *Node) AddStaticRoot(id *Identity, addrs []net.Addr) {
 // RemoveStaticRoot removes a statically defined root server from this node.
 func (n *Node) RemoveStaticRoot(id *Identity) {
 	ids := C.CString(id.String())
-	C.ZT_Node_removeStaticRoot(n.zn, ids)
+	C.ZT_Node_removeStaticRoot(unsafe.Pointer(n.zn), ids)
 	C.free(unsafe.Pointer(ids))
 }
 
@@ -254,9 +252,9 @@ func (n *Node) RemoveStaticRoot(id *Identity) {
 func (n *Node) AddDynamicRoot(dnsName string, locator []byte) {
 	dn := C.CString(dnsName)
 	if len(locator) > 0 {
-		C.ZT_Node_setDynamicRoot(n.zn, dn, unsafe.Pointer(&locator[0]), C.uint(len(locator)))
+		C.ZT_Node_setDynamicRoot(unsafe.Pointer(n.zn), dn, unsafe.Pointer(&locator[0]), C.uint(len(locator)))
 	} else {
-		C.ZT_Node_setDynamicRoot(n.zn, dn, nil, 0)
+		C.ZT_Node_setDynamicRoot(unsafe.Pointer(n.zn), dn, nil, 0)
 	}
 	C.free(unsafe.Pointer(dn))
 }
@@ -264,41 +262,50 @@ func (n *Node) AddDynamicRoot(dnsName string, locator []byte) {
 // RemoveDynamicRoot removes a dynamic root from this node.
 func (n *Node) RemoveDynamicRoot(dnsName string) {
 	dn := C.CString(dnsName)
-	C.ZT_Node_removeDynamicRoot(n.zn, dn)
+	C.ZT_Node_removeDynamicRoot(unsafe.Pointer(n.zn), dn)
 	C.free(unsafe.Pointer(dn))
 }
 
 // ListRoots retrieves a list of root servers on this node and their preferred and online status.
 func (n *Node) ListRoots() []Root {
 	var roots []Root
-	rl := C.ZT_Node_listRoots(n.zn, C.int64_t(TimeMs()))
+	rl := C.ZT_Node_listRoots(unsafe.Pointer(n.zn), C.int64_t(TimeMs()))
 	if rl != nil {
 		for i := 0; i < int(rl.count); i++ {
-			id, err := NewIdentityFromString(C.GoString(rl.roots[i].identity))
+			root := (*C.ZT_Root)(unsafe.Pointer(uintptr(unsafe.Pointer(rl)) + C.sizeof_ZT_RootList))
+			id, err := NewIdentityFromString(C.GoString(root.identity))
 			if err == nil {
 				var addrs []net.Addr
-				for j := 0; j < int(rl.roots[i].addressCount); j++ {
-					a := sockaddrStorageToUDPAddr(&rl.roots[i].addresses[j])
+				for j := uintptr(0); j < uintptr(root.addressCount); j++ {
+					a := sockaddrStorageToUDPAddr((*C.struct_sockaddr_storage)(unsafe.Pointer(uintptr(unsafe.Pointer(root.addresses)) + (j * C.sizeof_struct_sockaddr_storage))))
 					if a != nil {
 						addrs = append(addrs, a)
 					}
 				}
 				roots = append(roots, Root{
-					DNSName:   C.GoString(rl.roots[i].dnsName),
+					DNSName:   C.GoString(root.dnsName),
 					Identity:  id,
 					Addresses: addrs,
-					Preferred: (rl.roots[i].preferred != 0),
-					Online:    (rl.roots[i].online != 0),
+					Preferred: (root.preferred != 0),
+					Online:    (root.online != 0),
 				})
 			}
 		}
-		defer C.ZT_Node_freeQueryResult(n.zn, unsafe.Pointer(rl))
+		defer C.ZT_Node_freeQueryResult(unsafe.Pointer(n.zn), unsafe.Pointer(rl))
 	}
 	return roots
 }
 
 //////////////////////////////////////////////////////////////////////////////
 
+func (n *Node) multicastSubscribe(nwid uint64, mg *MulticastGroup) {
+	C.ZT_Node_multicastSubscribe(unsafe.Pointer(n.zn), nil, C.uint64_t(nwid), C.uint64_t(mg.MAC), C.ulong(mg.ADI))
+}
+
+func (n *Node) multicastUnsubscribe(nwid uint64, mg *MulticastGroup) {
+	C.ZT_Node_multicastUnsubscribe(unsafe.Pointer(n.zn), C.uint64_t(nwid), C.uint64_t(mg.MAC), C.ulong(mg.ADI))
+}
+
 func (n *Node) pathCheck(ztAddress uint64, af int, ip net.IP, port int) bool {
 	return true
 }
@@ -500,9 +507,9 @@ func goVirtualNetworkConfigFunc(gn, tapP unsafe.Pointer, nwid C.uint64_t, op C.i
 				return
 			}
 			var nc NetworkConfig
-			nc.ID = uint64(ncc.nwid)
+			nc.ID = NetworkID(ncc.nwid)
 			nc.MAC = MAC(ncc.mac)
-			nc.Name = C.GoString(ncc.name)
+			nc.Name = C.GoString(&ncc.name[0])
 			nc.Status = int(ncc.status)
 			nc.Type = int(ncc._type)
 			nc.MTU = int(ncc.mtu)
@@ -698,16 +705,16 @@ func (t *nativeTap) AddRoute(r *Route) error {
 		if len(r.Target.IP) == 4 {
 			mask, _ := r.Target.Mask.Size()
 			if len(r.Via) == 4 {
-				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet, unsafe.Pointer(&r.Via[0]), C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet, unsafe.Pointer(&r.Via[0]), C.uint(r.Metric)))
 			} else {
-				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.uint(r.Metric)))
 			}
 		} else if len(r.Target.IP) == 16 {
 			mask, _ := r.Target.Mask.Size()
 			if len(r.Via) == 4 {
-				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet6, unsafe.Pointer(&r.Via[0]), C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet6, unsafe.Pointer(&r.Via[0]), C.uint(r.Metric)))
 			} else {
-				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_addRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.uint(r.Metric)))
 			}
 		}
 	}
@@ -724,16 +731,16 @@ func (t *nativeTap) RemoveRoute(r *Route) error {
 		if len(r.Target.IP) == 4 {
 			mask, _ := r.Target.Mask.Size()
 			if len(r.Via) == 4 {
-				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet, unsafe.Pointer(&r.Via[0]), C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet, unsafe.Pointer(&r.Via[0]), C.uint(r.Metric)))
 			} else {
-				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.uint(r.Metric)))
 			}
 		} else if len(r.Target.IP) == 16 {
 			mask, _ := r.Target.Mask.Size()
 			if len(r.Via) == 4 {
-				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet6, unsafe.Pointer(&r.Via[0]), C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), afInet6, unsafe.Pointer(&r.Via[0]), C.uint(r.Metric)))
 			} else {
-				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.int(r.Metric)))
+				rc = int(C.ZT_GoTap_removeRoute(t.tap, afInet6, unsafe.Pointer(&r.Target.IP[0]), C.int(mask), 0, nil, C.uint(r.Metric)))
 			}
 		}
 	}