Browse Source

Fix callbacks

Adam Ierymenko 5 years ago
parent
commit
e2f3996843
3 changed files with 41 additions and 26 deletions
  1. 25 18
      go/native/GoGlue.cpp
  2. 1 1
      go/native/GoGlue.h
  3. 15 7
      go/pkg/zerotier/node.go

+ 25 - 18
go/native/GoGlue.cpp

@@ -88,6 +88,7 @@ struct ZT_GoNodeThread
 
 
 struct ZT_GoNode_Impl
 struct ZT_GoNode_Impl
 {
 {
+	void *goUserPtr;
 	Node *node;
 	Node *node;
 	volatile int64_t nextBackgroundTaskDeadline;
 	volatile int64_t nextBackgroundTaskDeadline;
 
 
@@ -109,15 +110,15 @@ const char *ZT_PLATFORM_DEFAULT_HOMEPATH = defaultHomePath.c_str();
 /****************************************************************************/
 /****************************************************************************/
 
 
 /* These functions are implemented in Go in pkg/ztnode/node-callbacks.go */
 /* These functions are implemented in Go in pkg/ztnode/node-callbacks.go */
-extern "C" int goPathCheckFunc(ZT_GoNode *,uint64_t,int,const void *,int);
-extern "C" int goPathLookupFunc(ZT_GoNode *,uint64_t,int,int *,uint8_t [16],int *);
-extern "C" void goStateObjectPutFunc(ZT_GoNode *,int,const uint64_t [2],const void *,int);
-extern "C" int goStateObjectGetFunc(ZT_GoNode *,int,const uint64_t [2],void *,unsigned int);
-extern "C" void goDNSResolverFunc(ZT_GoNode *,const uint8_t *,int,const char *,uintptr_t);
-extern "C" void goVirtualNetworkConfigFunc(ZT_GoNode *,ZT_GoTap *,uint64_t,int,const ZT_VirtualNetworkConfig *);
-extern "C" void goZtEvent(ZT_GoNode *,int,const void *);
-extern "C" void goHandleTapAddedMulticastGroup(ZT_GoNode *,ZT_GoTap *,uint64_t,uint64_t,uint32_t);
-extern "C" void goHandleTapRemovedMulticastGroup(ZT_GoNode *,ZT_GoTap *,uint64_t,uint64_t,uint32_t);
+extern "C" int goPathCheckFunc(void *,uint64_t,int,const void *,int);
+extern "C" int goPathLookupFunc(void *,uint64_t,int,int *,uint8_t [16],int *);
+extern "C" void goStateObjectPutFunc(void *,int,const uint64_t [2],const void *,int);
+extern "C" int goStateObjectGetFunc(void *,int,const uint64_t [2],void *,unsigned int);
+extern "C" void goDNSResolverFunc(void *,const uint8_t *,int,const char *,uintptr_t);
+extern "C" void goVirtualNetworkConfigFunc(void *,ZT_GoTap *,uint64_t,int,const ZT_VirtualNetworkConfig *);
+extern "C" void goZtEvent(void *,int,const void *);
+extern "C" void goHandleTapAddedMulticastGroup(void *,ZT_GoTap *,uint64_t,uint64_t,uint32_t);
+extern "C" void goHandleTapRemovedMulticastGroup(void *,ZT_GoTap *,uint64_t,uint64_t,uint32_t);
 
 
 static void ZT_GoNode_VirtualNetworkConfigFunction(
 static void ZT_GoNode_VirtualNetworkConfigFunction(
 	ZT_Node *node,
 	ZT_Node *node,
@@ -128,7 +129,7 @@ static void ZT_GoNode_VirtualNetworkConfigFunction(
 	enum ZT_VirtualNetworkConfigOperation op,
 	enum ZT_VirtualNetworkConfigOperation op,
 	const ZT_VirtualNetworkConfig *cfg)
 	const ZT_VirtualNetworkConfig *cfg)
 {
 {
-	goVirtualNetworkConfigFunc(reinterpret_cast<ZT_GoNode *>(uptr),reinterpret_cast<ZT_GoTap *>(*nptr),nwid,op,cfg);
+	goVirtualNetworkConfigFunc(reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,reinterpret_cast<ZT_GoTap *>(*nptr),nwid,op,cfg);
 }
 }
 
 
 static void ZT_GoNode_VirtualNetworkFrameFunction(
 static void ZT_GoNode_VirtualNetworkFrameFunction(
@@ -155,7 +156,7 @@ static void ZT_GoNode_EventCallback(
 	enum ZT_Event et,
 	enum ZT_Event et,
 	const void *data)
 	const void *data)
 {
 {
-	goZtEvent(reinterpret_cast<ZT_GoNode *>(uptr),et,data);
+	goZtEvent(reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,et,data);
 }
 }
 
 
 static void ZT_GoNode_StatePutFunction(
 static void ZT_GoNode_StatePutFunction(
@@ -167,7 +168,12 @@ static void ZT_GoNode_StatePutFunction(
 	const void *data,
 	const void *data,
 	int len)
 	int len)
 {
 {
-	goStateObjectPutFunc(reinterpret_cast<ZT_GoNode *>(uptr),objType,id,data,len);
+	goStateObjectPutFunc(
+		reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,
+		objType,
+		id,
+		data,
+		len);
 }
 }
 
 
 static int ZT_GoNode_StateGetFunction(
 static int ZT_GoNode_StateGetFunction(
@@ -180,7 +186,7 @@ static int ZT_GoNode_StateGetFunction(
 	unsigned int buflen)
 	unsigned int buflen)
 {
 {
 	return goStateObjectGetFunc(
 	return goStateObjectGetFunc(
-		reinterpret_cast<ZT_GoNode *>(uptr),
+		reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,
 		(int)objType,
 		(int)objType,
 		id,
 		id,
 		buf,
 		buf,
@@ -252,14 +258,14 @@ static int ZT_GoNode_PathCheckFunction(
 	switch(sa->ss_family) {
 	switch(sa->ss_family) {
 		case AF_INET:
 		case AF_INET:
 			return goPathCheckFunc(
 			return goPathCheckFunc(
-				reinterpret_cast<ZT_GoNode *>(uptr),
+				reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,
 				ztAddress,
 				ztAddress,
 				AF_INET,
 				AF_INET,
 				&(reinterpret_cast<const struct sockaddr_in *>(sa)->sin_addr.s_addr),
 				&(reinterpret_cast<const struct sockaddr_in *>(sa)->sin_addr.s_addr),
 				Utils::ntoh((uint16_t)reinterpret_cast<const struct sockaddr_in *>(sa)->sin_port));
 				Utils::ntoh((uint16_t)reinterpret_cast<const struct sockaddr_in *>(sa)->sin_port));
 		case AF_INET6:
 		case AF_INET6:
 			return goPathCheckFunc(
 			return goPathCheckFunc(
-				reinterpret_cast<ZT_GoNode *>(uptr),
+				reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,
 				ztAddress,
 				ztAddress,
 				AF_INET6,
 				AF_INET6,
 				reinterpret_cast<const struct sockaddr_in6 *>(sa)->sin6_addr.s6_addr,
 				reinterpret_cast<const struct sockaddr_in6 *>(sa)->sin6_addr.s6_addr,
@@ -280,7 +286,7 @@ static int ZT_GoNode_PathLookupFunction(
 	uint8_t ip[16];
 	uint8_t ip[16];
 	int port = 0;
 	int port = 0;
 	const int result = goPathLookupFunc(
 	const int result = goPathLookupFunc(
-		reinterpret_cast<ZT_GoNode *>(uptr),
+		reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,
 		ztAddress,
 		ztAddress,
 		desiredAddressFamily,
 		desiredAddressFamily,
 		&family,
 		&family,
@@ -315,12 +321,12 @@ static void ZT_GoNode_DNSResolver(
 {
 {
 	uint8_t t[256];
 	uint8_t t[256];
 	for(unsigned int i=0;(i<numTypes)&&(i<256);++i) t[i] = (uint8_t)types[i];
 	for(unsigned int i=0;(i<numTypes)&&(i<256);++i) t[i] = (uint8_t)types[i];
-	goDNSResolverFunc(reinterpret_cast<ZT_GoNode *>(uptr),t,(int)numTypes,name,requestId);
+	goDNSResolverFunc(reinterpret_cast<ZT_GoNode *>(uptr)->goUserPtr,t,(int)numTypes,name,requestId);
 }
 }
 
 
 /****************************************************************************/
 /****************************************************************************/
 
 
-extern "C" ZT_GoNode *ZT_GoNode_new(const char *workingPath)
+extern "C" ZT_GoNode *ZT_GoNode_new(const char *workingPath,uintptr_t userPtr)
 {
 {
 	try {
 	try {
 		struct ZT_Node_Callbacks cb;
 		struct ZT_Node_Callbacks cb;
@@ -336,6 +342,7 @@ extern "C" ZT_GoNode *ZT_GoNode_new(const char *workingPath)
 
 
 		ZT_GoNode_Impl *gn = new ZT_GoNode_Impl;
 		ZT_GoNode_Impl *gn = new ZT_GoNode_Impl;
 		const int64_t now = OSUtils::now();
 		const int64_t now = OSUtils::now();
+		gn->goUserPtr = reinterpret_cast<void *>(userPtr);
 		gn->node = new Node(reinterpret_cast<void *>(gn),nullptr,&cb,now);
 		gn->node = new Node(reinterpret_cast<void *>(gn),nullptr,&cb,now);
 		gn->nextBackgroundTaskDeadline = now;
 		gn->nextBackgroundTaskDeadline = now;
 		gn->path = workingPath;
 		gn->path = workingPath;

+ 1 - 1
go/native/GoGlue.h

@@ -44,7 +44,7 @@ extern const char *ZT_PLATFORM_DEFAULT_HOMEPATH;
 
 
 /****************************************************************************/
 /****************************************************************************/
 
 
-ZT_GoNode *ZT_GoNode_new(const char *workingPath);
+ZT_GoNode *ZT_GoNode_new(const char *workingPath,uintptr_t userPtr);
 
 
 void ZT_GoNode_delete(ZT_GoNode *gn);
 void ZT_GoNode_delete(ZT_GoNode *gn);
 
 

+ 15 - 7
go/pkg/zerotier/node.go

@@ -267,11 +267,18 @@ func NewNode(basePath string) (*Node, error) {
 		return nil, errors.New("unable to bind to primary port")
 		return nil, errors.New("unable to bind to primary port")
 	}
 	}
 
 
+	nodesByUserPtrLock.Lock()
+	nodesByUserPtr[uintptr(unsafe.Pointer(n))] = n
+	nodesByUserPtrLock.Unlock()
+
 	cPath := C.CString(basePath)
 	cPath := C.CString(basePath)
-	n.gn = C.ZT_GoNode_new(cPath)
+	n.gn = C.ZT_GoNode_new(cPath, C.uintptr_t(uintptr(unsafe.Pointer(n))))
 	C.free(unsafe.Pointer(cPath))
 	C.free(unsafe.Pointer(cPath))
 	if n.gn == nil {
 	if n.gn == nil {
 		n.log.Println("FATAL: node initialization failed")
 		n.log.Println("FATAL: node initialization failed")
+		nodesByUserPtrLock.Lock()
+		delete(nodesByUserPtr, uintptr(unsafe.Pointer(n)))
+		nodesByUserPtrLock.Unlock()
 		return nil, ErrNodeInitFailed
 		return nil, ErrNodeInitFailed
 	}
 	}
 	n.zn = (*C.ZT_Node)(C.ZT_GoNode_getNode(n.gn))
 	n.zn = (*C.ZT_Node)(C.ZT_GoNode_getNode(n.gn))
@@ -282,6 +289,9 @@ func NewNode(basePath string) (*Node, error) {
 	n.id, err = NewIdentityFromString(idString)
 	n.id, err = NewIdentityFromString(idString)
 	if err != nil {
 	if err != nil {
 		n.log.Printf("FATAL: node's identity does not seem valid (%s)", string(idString))
 		n.log.Printf("FATAL: node's identity does not seem valid (%s)", string(idString))
+		nodesByUserPtrLock.Lock()
+		delete(nodesByUserPtr, uintptr(unsafe.Pointer(n)))
+		nodesByUserPtrLock.Unlock()
 		C.ZT_GoNode_delete(n.gn)
 		C.ZT_GoNode_delete(n.gn)
 		return nil, err
 		return nil, err
 	}
 	}
@@ -289,15 +299,13 @@ func NewNode(basePath string) (*Node, error) {
 	n.apiServer, n.tcpApiServer, err = createAPIServer(basePath, n)
 	n.apiServer, n.tcpApiServer, err = createAPIServer(basePath, n)
 	if err != nil {
 	if err != nil {
 		n.log.Printf("FATAL: unable to start API server: %s", err.Error())
 		n.log.Printf("FATAL: unable to start API server: %s", err.Error())
+		nodesByUserPtrLock.Lock()
+		delete(nodesByUserPtr, uintptr(unsafe.Pointer(n)))
+		nodesByUserPtrLock.Unlock()
 		C.ZT_GoNode_delete(n.gn)
 		C.ZT_GoNode_delete(n.gn)
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	gnRawAddr := uintptr(unsafe.Pointer(n.gn))
-	nodesByUserPtrLock.Lock()
-	nodesByUserPtr[gnRawAddr] = n
-	nodesByUserPtrLock.Unlock()
-
 	n.online = 0
 	n.online = 0
 	n.running = 1
 	n.running = 1
 
 
@@ -411,7 +419,7 @@ func (n *Node) Close() {
 		n.runLock.Unlock()
 		n.runLock.Unlock()
 
 
 		nodesByUserPtrLock.Lock()
 		nodesByUserPtrLock.Lock()
-		delete(nodesByUserPtr, uintptr(unsafe.Pointer(n.gn)))
+		delete(nodesByUserPtr, uintptr(unsafe.Pointer(n)))
 		nodesByUserPtrLock.Unlock()
 		nodesByUserPtrLock.Unlock()
 	}
 	}
 }
 }