Переглянути джерело

Fix another not-fun bug. Also exterminate a memory leak and do a few optimizations.

Adam Ierymenko 5 роки тому
батько
коміт
e5f2314055

+ 46 - 41
cmd/zt_service_tests/certificate.go

@@ -92,36 +92,39 @@ func TestCertificate() bool {
 	c.MaxPathLength = 9999
 	c.MaxPathLength = 9999
 	c.Signature = []byte("qwerty")
 	c.Signature = []byte("qwerty")
 
 
-	cc := c.CCertificate()
-	if cc == nil {
-		fmt.Println("  Error converting Certificate to ZT_Certificate")
-		return false
-	}
-
-	c2 := zerotier.NewCertificateFromCCertificate(cc.C)
-	if c2 == nil {
-		fmt.Println("  Error converting ZT_Certificate to Certificate")
-		return false
-	}
-
-	j, _ := json.Marshal(c)
-	j2, _ := json.Marshal(c2)
-	if !bytes.Equal(j, j2) {
-		j, _ = json.MarshalIndent(c, "", "  ")
-		fmt.Print("  Deep equality test failed: certificates do not match! (see dumps below)\n\n")
-		fmt.Println(string(j))
-		fmt.Println(string(j2))
-		return false
+	for k := 0; k < 1; k++ {
+		cc := c.CCertificate()
+		if cc == nil {
+			fmt.Println("  Error converting Certificate to ZT_Certificate")
+			return false
+		}
+		c2 := zerotier.NewCertificateFromCCertificate(cc)
+		if c2 == nil {
+			fmt.Println("  Error converting ZT_Certificate to Certificate")
+			return false
+		}
+		zerotier.DeleteCCertificate(cc)
+
+		j, _ := json.Marshal(c)
+		j2, _ := json.Marshal(c2)
+		if !bytes.Equal(j, j2) {
+			j, _ = json.MarshalIndent(c, "", "  ")
+			j2, _ = json.MarshalIndent(c2, "", "  ")
+			fmt.Print("  Deep equality test failed: certificates do not match! (see dumps below)\n\n")
+			fmt.Println(string(j))
+			fmt.Println(string(j2))
+			return false
+		}
 	}
 	}
 
 
 	fmt.Printf("Checking certificate marshal/unmarshal... ")
 	fmt.Printf("Checking certificate marshal/unmarshal... ")
-	for k := 0; k < 1024; k++ {
+	for k := 0; k < 1; k++ {
 		cb, err := c.Marshal()
 		cb, err := c.Marshal()
 		if err != nil {
 		if err != nil {
 			fmt.Printf("marshal FAILED (%s)\n", err.Error())
 			fmt.Printf("marshal FAILED (%s)\n", err.Error())
 			return false
 			return false
 		}
 		}
-		c2, err = zerotier.NewCertificateFromBytes(cb, false)
+		c2, err := zerotier.NewCertificateFromBytes(cb, false)
 		if err != nil {
 		if err != nil {
 			fmt.Printf("unmarshal FAILED (%s)\n", err.Error())
 			fmt.Printf("unmarshal FAILED (%s)\n", err.Error())
 			return false
 			return false
@@ -139,25 +142,27 @@ func TestCertificate() bool {
 	fmt.Println("OK")
 	fmt.Println("OK")
 
 
 	fmt.Printf("Checking certificate CSR sign/verify... ")
 	fmt.Printf("Checking certificate CSR sign/verify... ")
-	csr, err := zerotier.NewCertificateCSR(&c.Subject, uniqueId, uniqueIdPrivate)
-	if err != nil {
-		fmt.Printf("CSR generate FAILED (%s)\n", err.Error())
-		return false
-	}
-	fmt.Printf("CSR size: %d ", len(csr))
-	csr2, err := zerotier.NewCertificateFromBytes(csr, false)
-	if err != nil {
-		fmt.Printf("CSR decode FAILED (%s)\n", err.Error())
-		return false
-	}
-	signedCert, err := csr2.Sign(id)
-	if err != nil {
-		fmt.Printf("CSR sign FAILED (%s)\n", err.Error())
-		return false
-	}
-	if len(signedCert.Signature) == 0 {
-		fmt.Println("CSR sign FAILED (no signature found)", err.Error())
-		return false
+	for k := 0; k < 1; k++ {
+		csr, err := zerotier.NewCertificateCSR(&c.Subject, uniqueId, uniqueIdPrivate)
+		if err != nil {
+			fmt.Printf("CSR generate FAILED (%s)\n", err.Error())
+			return false
+		}
+		fmt.Printf("CSR size: %d ", len(csr))
+		csr2, err := zerotier.NewCertificateFromBytes(csr, false)
+		if err != nil {
+			fmt.Printf("CSR decode FAILED (%s)\n", err.Error())
+			return false
+		}
+		signedCert, err := csr2.Sign(id)
+		if err != nil {
+			fmt.Printf("CSR sign FAILED (%s)\n", err.Error())
+			return false
+		}
+		if len(signedCert.Signature) == 0 {
+			fmt.Println("CSR sign FAILED (no signature found)", err.Error())
+			return false
+		}
 	}
 	}
 	fmt.Println("OK")
 	fmt.Println("OK")
 
 

+ 3 - 2
cmd/zt_service_tests/zt_service_tests.go

@@ -1,10 +1,11 @@
 package main
 package main
 
 
-import "os"
+import (
+	"os"
+)
 
 
 func main() {
 func main() {
 	if !TestCertificate() {
 	if !TestCertificate() {
 		os.Exit(1)
 		os.Exit(1)
 	}
 	}
 }
 }
-

+ 19 - 27
core/Certificate.cpp

@@ -39,28 +39,13 @@ Certificate &Certificate::operator=(const ZT_Certificate &cert)
 {
 {
 	m_clear();
 	m_clear();
 
 
-	ZT_Certificate *const sup = this;
-	Utils::copy< sizeof(ZT_Certificate) >(sup, &cert);
-
-	// Zero these since we must explicitly attach all the objects from
-	// the other certificate to copy them into our containers.
-	this->subject.identities = nullptr;
-	this->subject.identityCount = 0;
-	this->subject.networks = nullptr;
-	this->subject.networkCount = 0;
-	this->subject.certificates = nullptr;
-	this->subject.certificateCount = 0;
-	this->subject.updateURLs = nullptr;
-	this->subject.updateURLCount = 0;
-	this->subject.uniqueId = nullptr;
-	this->subject.uniqueIdProofSignature = nullptr;
-	this->subject.uniqueIdSize = 0;
-	this->subject.uniqueIdProofSignatureSize = 0;
-	this->extendedAttributes = nullptr;
-	this->extendedAttributesSize = 0;
-	this->issuer = nullptr;
-	this->signature = nullptr;
-	this->signatureSize = 0;
+	Utils::copy< 48 >(this->serialNo, cert.serialNo);
+	this->flags = cert.flags;
+	this->timestamp = cert.timestamp;
+	this->validity[0] = cert.validity[0];
+	this->validity[1] = cert.validity[1];
+
+	this->subject.timestamp = cert.subject.timestamp;
 
 
 	if (cert.subject.identities) {
 	if (cert.subject.identities) {
 		for (unsigned int i = 0; i < cert.subject.identityCount; ++i) {
 		for (unsigned int i = 0; i < cert.subject.identityCount; ++i) {
@@ -95,6 +80,8 @@ Certificate &Certificate::operator=(const ZT_Certificate &cert)
 		}
 		}
 	}
 	}
 
 
+	Utils::copy< sizeof(ZT_Certificate_Name) >(&(this->subject.name), &(cert.subject.name));
+
 	if ((cert.subject.uniqueId) && (cert.subject.uniqueIdSize > 0)) {
 	if ((cert.subject.uniqueId) && (cert.subject.uniqueIdSize > 0)) {
 		m_subjectUniqueId.assign(cert.subject.uniqueId, cert.subject.uniqueId + cert.subject.uniqueIdSize);
 		m_subjectUniqueId.assign(cert.subject.uniqueId, cert.subject.uniqueId + cert.subject.uniqueIdSize);
 		this->subject.uniqueId = m_subjectUniqueId.data();
 		this->subject.uniqueId = m_subjectUniqueId.data();
@@ -111,12 +98,16 @@ Certificate &Certificate::operator=(const ZT_Certificate &cert)
 		this->issuer = &(m_identities.front());
 		this->issuer = &(m_identities.front());
 	}
 	}
 
 
+	Utils::copy< sizeof(ZT_Certificate_Name) >(&(this->issuerName), &(cert.issuerName));
+
 	if ((cert.extendedAttributes) && (cert.extendedAttributesSize > 0)) {
 	if ((cert.extendedAttributes) && (cert.extendedAttributesSize > 0)) {
 		m_extendedAttributes.assign(cert.extendedAttributes, cert.extendedAttributes + cert.extendedAttributesSize);
 		m_extendedAttributes.assign(cert.extendedAttributes, cert.extendedAttributes + cert.extendedAttributesSize);
 		this->extendedAttributes = m_extendedAttributes.data();
 		this->extendedAttributes = m_extendedAttributes.data();
 		this->extendedAttributesSize = (unsigned int)m_extendedAttributes.size();
 		this->extendedAttributesSize = (unsigned int)m_extendedAttributes.size();
 	}
 	}
 
 
+	this->maxPathLength = cert.maxPathLength;
+
 	if ((cert.signature) && (cert.signatureSize > 0)) {
 	if ((cert.signature) && (cert.signatureSize > 0)) {
 		m_signature.assign(cert.signature, cert.signature + cert.signatureSize);
 		m_signature.assign(cert.signature, cert.signature + cert.signatureSize);
 		this->signature = m_signature.data();
 		this->signature = m_signature.data();
@@ -512,7 +503,7 @@ ZT_CertificateError Certificate::verify() const
 Vector< uint8_t > Certificate::createCSR(const ZT_Certificate_Subject &s, const void *uniqueId, unsigned int uniqueIdSize, const void *uniqueIdPrivate, unsigned int uniqueIdPrivateSize)
 Vector< uint8_t > Certificate::createCSR(const ZT_Certificate_Subject &s, const void *uniqueId, unsigned int uniqueIdSize, const void *uniqueIdPrivate, unsigned int uniqueIdPrivateSize)
 {
 {
 	ZT_Certificate_Subject sc;
 	ZT_Certificate_Subject sc;
-	Utils::copy< sizeof(ZT_Certificate_Subject) >(&sc,&s);
+	Utils::copy< sizeof(ZT_Certificate_Subject) >(&sc, &s);
 
 
 	if ((uniqueId) && (uniqueIdSize > 0) && (uniqueIdPrivate) && (uniqueIdPrivateSize > 0)) {
 	if ((uniqueId) && (uniqueIdSize > 0) && (uniqueIdPrivate) && (uniqueIdPrivateSize > 0)) {
 		sc.uniqueId = reinterpret_cast<const uint8_t *>(uniqueId);
 		sc.uniqueId = reinterpret_cast<const uint8_t *>(uniqueId);
@@ -531,14 +522,15 @@ Vector< uint8_t > Certificate::createCSR(const ZT_Certificate_Subject &s, const
 		uint8_t h[ZT_SHA384_DIGEST_SIZE];
 		uint8_t h[ZT_SHA384_DIGEST_SIZE];
 		SHA384(h, enc.data(), (unsigned int)enc.size());
 		SHA384(h, enc.data(), (unsigned int)enc.size());
 		enc.clear();
 		enc.clear();
-		if (
-			(reinterpret_cast<const uint8_t *>(uniqueId)[0] == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384) &&
-			(uniqueIdSize == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_SIZE) &&
-			(uniqueIdPrivateSize == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_PRIVATE_SIZE)) {
+		if ((reinterpret_cast<const uint8_t *>(uniqueId)[0] == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384) &&
+		    (uniqueIdSize == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_SIZE) &&
+		    (uniqueIdPrivateSize == ZT_CERTIFICATE_UNIQUE_ID_TYPE_NIST_P_384_PRIVATE_SIZE)) {
 			uint8_t sig[ZT_ECC384_SIGNATURE_SIZE];
 			uint8_t sig[ZT_ECC384_SIGNATURE_SIZE];
 			ECC384ECDSASign(reinterpret_cast<const uint8_t *>(uniqueIdPrivate), h, sig);
 			ECC384ECDSASign(reinterpret_cast<const uint8_t *>(uniqueIdPrivate), h, sig);
+
 			sc.uniqueIdProofSignature = sig;
 			sc.uniqueIdProofSignature = sig;
 			sc.uniqueIdProofSignatureSize = ZT_ECC384_SIGNATURE_SIZE;
 			sc.uniqueIdProofSignatureSize = ZT_ECC384_SIGNATURE_SIZE;
+
 			d.clear();
 			d.clear();
 			m_encodeSubject(sc, d, false);
 			m_encodeSubject(sc, d, false);
 			d.encode(enc);
 			d.encode(enc);

+ 2 - 2
core/Dictionary.hpp

@@ -48,7 +48,7 @@ public:
 
 
 	~Dictionary();
 	~Dictionary();
 
 
-	///*
+	/*
 	ZT_INLINE void dump() const
 	ZT_INLINE void dump() const
 	{
 	{
 		printf("\n--\n");
 		printf("\n--\n");
@@ -73,7 +73,7 @@ public:
 		}
 		}
 		printf("--\n");
 		printf("--\n");
 	}
 	}
-	//*/
+	*/
 
 
 	/**
 	/**
 	 * Get a reference to a value
 	 * Get a reference to a value

+ 0 - 28
core/Utils.hpp

@@ -710,9 +710,6 @@ static ZT_INLINE void storeLittleEndian(void *const p, const I i) noexcept
  * and requires no branching or function calls. Specialized memcpy() can still
  * and requires no branching or function calls. Specialized memcpy() can still
  * be faster for large memory regions, but ZeroTier doesn't copy anything
  * be faster for large memory regions, but ZeroTier doesn't copy anything
  * much larger than 16KiB.
  * much larger than 16KiB.
- *
- * A templated version for statically known sizes is provided since this can
- * allow some nice optimizations in some cases.
  */
  */
 
 
 /**
 /**
@@ -733,31 +730,6 @@ static ZT_INLINE void copy(void *dest, const void *src) noexcept
 #endif
 #endif
 }
 }
 
 
-#ifndef ZT_NO_UNALIGNED_ACCESS
-template<>
-ZT_INLINE void copy<4>(void *dest, const void *src) noexcept
-{
-	*reinterpret_cast<uint32_t *>(dest) = *reinterpret_cast<const uint32_t *>(src);
-}
-template<>
-ZT_INLINE void copy<8>(void *dest, const void *src) noexcept
-{
-	*reinterpret_cast<uint64_t *>(dest) = *reinterpret_cast<const uint64_t *>(src);
-}
-template<>
-ZT_INLINE void copy<12>(void *dest, const void *src) noexcept
-{
-	*reinterpret_cast<uint64_t *>(dest) = *reinterpret_cast<const uint64_t *>(src);
-	*reinterpret_cast<uint32_t *>(reinterpret_cast<uint8_t *>(dest) + 8) = *reinterpret_cast<const uint32_t *>(reinterpret_cast<const uint8_t *>(src) + 8);
-}
-template<>
-ZT_INLINE void copy<16>(void *dest, const void *src) noexcept
-{
-	*reinterpret_cast<uint64_t *>(dest) = *reinterpret_cast<const uint64_t *>(src);
-	*reinterpret_cast<uint64_t *>(reinterpret_cast<uint8_t *>(dest) + 8) = *reinterpret_cast<const uint64_t *>(reinterpret_cast<const uint8_t *>(src) + 8);
-}
-#endif
-
 /**
 /**
  * Copy memory block whose size is known at run time
  * Copy memory block whose size is known at run time
  *
  *

+ 23 - 26
pkg/zerotier/certificate.go

@@ -20,7 +20,6 @@ import "C"
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"runtime"
 	"unsafe"
 	"unsafe"
 )
 )
 
 
@@ -87,14 +86,6 @@ type Certificate struct {
 	Signature          []byte             `json:"signature,omitempty"`
 	Signature          []byte             `json:"signature,omitempty"`
 }
 }
 
 
-// CCertificate just wraps a C pointer so a Go finalizer can be attached to it.
-// This allows CCertificate() to be used without requiring the caller to
-// explicitly free memory. Ensure that a pointer to this structure is held until
-// the underlying C memory is no longer needed.
-type CCertificate struct {
-	C unsafe.Pointer
-}
-
 func certificateErrorToError(cerr int) error {
 func certificateErrorToError(cerr int) error {
 	switch cerr {
 	switch cerr {
 	case C.ZT_CERTIFICATE_ERROR_NONE:
 	case C.ZT_CERTIFICATE_ERROR_NONE:
@@ -274,11 +265,14 @@ func NewCertificateFromCCertificate(ccptr unsafe.Pointer) *Certificate {
 	return c
 	return c
 }
 }
 
 
+// DeleteCCertificate deletes a ZT_Certificate object returned by Certificate.CCertificate()
+func DeleteCCertificate(cc unsafe.Pointer) {
+	C.ZT_Certificate_delete((*C.ZT_Certificate)(cc))
+}
+
 // CCertificate creates a C ZT_Certificate structure from the content of a Certificate.
 // CCertificate creates a C ZT_Certificate structure from the content of a Certificate.
-//
-// This will return nil if an error occurs, which would indicate an invalid C
-// structure or one with invalid values.
-func (c *Certificate) CCertificate() *CCertificate {
+// It must be deleted with DeleteCCertificate.
+func (c *Certificate) CCertificate() unsafe.Pointer {
 	var cc C.ZT_Certificate
 	var cc C.ZT_Certificate
 	var subjectIdentities []C.ZT_Certificate_Identity
 	var subjectIdentities []C.ZT_Certificate_Identity
 	var subjectNetworks []C.ZT_Certificate_Network
 	var subjectNetworks []C.ZT_Certificate_Network
@@ -401,13 +395,7 @@ func (c *Certificate) CCertificate() *CCertificate {
 	// HACK: pass pointer to cc as uintptr to disable Go's protection against "Go pointers to
 	// HACK: pass pointer to cc as uintptr to disable Go's protection against "Go pointers to
 	// Go pointers," as the C function called here will make a deep clone and then we are going
 	// Go pointers," as the C function called here will make a deep clone and then we are going
 	// to throw away 'cc' and its components.
 	// to throw away 'cc' and its components.
-	cc2 := &CCertificate{C: unsafe.Pointer(C._ZT_Certificate_clone2(C.uintptr_t(uintptr(unsafe.Pointer(&cc)))))}
-	runtime.SetFinalizer(cc2, func(obj interface{}) {
-		if obj != nil {
-			C.ZT_Certificate_delete((*C.ZT_Certificate)(obj.(*CCertificate).C))
-		}
-	})
-	return cc2
+	return unsafe.Pointer(C._ZT_Certificate_clone2(C.uintptr_t(uintptr(unsafe.Pointer(&cc)))))
 }
 }
 
 
 // Marshal encodes this certificate as a byte array.
 // Marshal encodes this certificate as a byte array.
@@ -416,9 +404,10 @@ func (c *Certificate) Marshal() ([]byte, error) {
 	if cc == nil {
 	if cc == nil {
 		return nil, ErrInternal
 		return nil, ErrInternal
 	}
 	}
+	defer DeleteCCertificate(cc)
 	var encoded [16384]byte
 	var encoded [16384]byte
 	encodedSize := C.int(16384)
 	encodedSize := C.int(16384)
-	rv := int(C.ZT_Certificate_encode((*C.ZT_Certificate)(cc.C), unsafe.Pointer(&encoded[0]), &encodedSize))
+	rv := int(C.ZT_Certificate_encode((*C.ZT_Certificate)(cc), unsafe.Pointer(&encoded[0]), &encodedSize))
 	if rv != 0 {
 	if rv != 0 {
 		return nil, fmt.Errorf("Certificate encode error %d", rv)
 		return nil, fmt.Errorf("Certificate encode error %d", rv)
 	}
 	}
@@ -437,9 +426,10 @@ func (c *Certificate) Sign(id *Identity) (*Certificate, error) {
 	if ctmp == nil {
 	if ctmp == nil {
 		return nil, ErrInternal
 		return nil, ErrInternal
 	}
 	}
+	defer DeleteCCertificate(ctmp)
 	var signedCert [16384]byte
 	var signedCert [16384]byte
 	signedCertSize := C.int(16384)
 	signedCertSize := C.int(16384)
-	rv := int(C.ZT_Certificate_sign((*C.ZT_Certificate)(ctmp.C), id.cIdentity(), unsafe.Pointer(&signedCert[0]), &signedCertSize))
+	rv := int(C.ZT_Certificate_sign((*C.ZT_Certificate)(ctmp), id.cIdentity(), unsafe.Pointer(&signedCert[0]), &signedCertSize))
 	if rv != 0 {
 	if rv != 0 {
 		return nil, fmt.Errorf("signing failed: error %d", rv)
 		return nil, fmt.Errorf("signing failed: error %d", rv)
 	}
 	}
@@ -452,7 +442,14 @@ func (c *Certificate) Verify() error {
 	if cc == nil {
 	if cc == nil {
 		return ErrInternal
 		return ErrInternal
 	}
 	}
-	return certificateErrorToError(int(C.ZT_Certificate_verify((*C.ZT_Certificate)(cc.C))))
+	defer DeleteCCertificate(cc)
+	return certificateErrorToError(int(C.ZT_Certificate_verify((*C.ZT_Certificate)(cc))))
+}
+
+// String returns a compact JSON representation of this certificate.
+func (c *Certificate) String() string {
+	j, _ := json.Marshal(c)
+	return string(j)
 }
 }
 
 
 // JSON returns this certificate as a human-readable indented JSON string.
 // JSON returns this certificate as a human-readable indented JSON string.
@@ -503,15 +500,15 @@ func NewCertificateCSR(subject *CertificateSubject, uniqueId []byte, uniqueIdPri
 	if ctmp == nil {
 	if ctmp == nil {
 		return nil, ErrInternal
 		return nil, ErrInternal
 	}
 	}
+	defer DeleteCCertificate(ctmp)
 
 
 	var csr [16384]byte
 	var csr [16384]byte
 	csrSize := C.int(16384)
 	csrSize := C.int(16384)
-	cc := (*C.ZT_Certificate)(ctmp.C)
+	cc := (*C.ZT_Certificate)(ctmp)
 	rv := int(C.ZT_Certificate_newCSR(&(cc.subject), uid, C.int(len(uniqueId)), uidp, C.int(len(uniqueIdPrivate)), unsafe.Pointer(&csr[0]), &csrSize))
 	rv := int(C.ZT_Certificate_newCSR(&(cc.subject), uid, C.int(len(uniqueId)), uidp, C.int(len(uniqueIdPrivate)), unsafe.Pointer(&csr[0]), &csrSize))
 	if rv != 0 {
 	if rv != 0 {
-		return nil, fmt.Errorf("newCSR error %d", rv)
+		return nil, fmt.Errorf("ZT_Certificate_newCSR() failed: %d", rv)
 	}
 	}
-	ctmp = nil
 
 
 	return append(make([]byte, 0, int(csrSize)), csr[0:int(csrSize)]...), nil
 	return append(make([]byte, 0, int(csrSize)), csr[0:int(csrSize)]...), nil
 }
 }

+ 0 - 1
pkg/zerotier/identity.go

@@ -26,7 +26,6 @@ import (
 	"unsafe"
 	"unsafe"
 )
 )
 
 
-// Constants from node/Identity.hpp (must be the same)
 const (
 const (
 	IdentityTypeC25519 = 0
 	IdentityTypeC25519 = 0
 	IdentityTypeP384   = 1
 	IdentityTypeP384   = 1