Browse Source

Merge pull request #46197 from RandomShaper/volatile_robustness

Improve robustness of atomics
Rémi Verschelde 4 years ago
parent
commit
b4aba47969
2 changed files with 30 additions and 12 deletions
  1. 20 12
      core/templates/cowdata.h
  2. 10 0
      core/templates/safe_refcount.h

+ 20 - 12
core/templates/cowdata.h

@@ -45,8 +45,9 @@ class CharString;
 template <class T, class V>
 class VMap;
 
-// CowData is relying on this to be true
-static_assert(sizeof(SafeNumeric<uint32_t>) == sizeof(uint32_t));
+#if !defined(NO_THREADS)
+SAFE_NUMERIC_TYPE_PUN_GUARANTEES(uint32_t)
+#endif
 
 template <class T>
 class CowData {
@@ -114,7 +115,7 @@ private:
 	void _unref(void *p_data);
 	void _ref(const CowData *p_from);
 	void _ref(const CowData &p_from);
-	void _copy_on_write();
+	uint32_t _copy_on_write();
 
 public:
 	void operator=(const CowData<T> &p_from) { _ref(p_from); }
@@ -217,20 +218,21 @@ void CowData<T>::_unref(void *p_data) {
 }
 
 template <class T>
-void CowData<T>::_copy_on_write() {
+uint32_t CowData<T>::_copy_on_write() {
 	if (!_ptr) {
-		return;
+		return 0;
 	}
 
 	SafeNumeric<uint32_t> *refc = _get_refcount();
 
-	if (unlikely(refc->get() > 1)) {
+	uint32_t rc = refc->get();
+	if (unlikely(rc > 1)) {
 		/* in use by more than me */
 		uint32_t current_size = *_get_size();
 
 		uint32_t *mem_new = (uint32_t *)Memory::alloc_static(_get_alloc_size(current_size), true);
 
-		reinterpret_cast<SafeNumeric<uint32_t> *>(mem_new - 2)->set(1); //refcount
+		new (mem_new - 2, sizeof(uint32_t), "") SafeNumeric<uint32_t>(1); //refcount
 		*(mem_new - 1) = current_size; //size
 
 		T *_data = (T *)(mem_new);
@@ -247,7 +249,10 @@ void CowData<T>::_copy_on_write() {
 
 		_unref(_ptr);
 		_ptr = _data;
+
+		rc = 1;
 	}
+	return rc;
 }
 
 template <class T>
@@ -268,7 +273,7 @@ Error CowData<T>::resize(int p_size) {
 	}
 
 	// possibly changing size, copy on write
-	_copy_on_write();
+	uint32_t rc = _copy_on_write();
 
 	size_t current_alloc_size = _get_alloc_size(current_size);
 	size_t alloc_size;
@@ -281,13 +286,15 @@ Error CowData<T>::resize(int p_size) {
 				uint32_t *ptr = (uint32_t *)Memory::alloc_static(alloc_size, true);
 				ERR_FAIL_COND_V(!ptr, ERR_OUT_OF_MEMORY);
 				*(ptr - 1) = 0; //size, currently none
-				reinterpret_cast<SafeNumeric<uint32_t> *>(ptr - 2)->set(1); //refcount
+				new (ptr - 2, sizeof(uint32_t), "") SafeNumeric<uint32_t>(1); //refcount
 
 				_ptr = (T *)ptr;
 
 			} else {
-				void *_ptrnew = (T *)Memory::realloc_static(_ptr, alloc_size, true);
+				uint32_t *_ptrnew = (uint32_t *)Memory::realloc_static(_ptr, alloc_size, true);
 				ERR_FAIL_COND_V(!_ptrnew, ERR_OUT_OF_MEMORY);
+				new (_ptrnew - 2, sizeof(uint32_t), "") SafeNumeric<uint32_t>(rc); //refcount
+
 				_ptr = (T *)(_ptrnew);
 			}
 		}
@@ -314,8 +321,9 @@ Error CowData<T>::resize(int p_size) {
 		}
 
 		if (alloc_size != current_alloc_size) {
-			void *_ptrnew = (T *)Memory::realloc_static(_ptr, alloc_size, true);
+			uint32_t *_ptrnew = (uint32_t *)Memory::realloc_static(_ptr, alloc_size, true);
 			ERR_FAIL_COND_V(!_ptrnew, ERR_OUT_OF_MEMORY);
+			new (_ptrnew - 2, sizeof(uint32_t), "") SafeNumeric<uint32_t>(rc); //refcount
 
 			_ptr = (T *)(_ptrnew);
 		}
@@ -362,7 +370,7 @@ void CowData<T>::_ref(const CowData &p_from) {
 		return; //nothing to do
 	}
 
-	if (p_from._get_refcount()->increment() > 0) { // could reference
+	if (p_from._get_refcount()->conditional_increment() > 0) { // could reference
 		_ptr = p_from._ptr;
 	}
 }

+ 10 - 0
core/templates/safe_refcount.h

@@ -47,10 +47,18 @@
 //   value and, as an important benefit, you can be sure the value is properly synchronized
 //   even with threads that are already running.
 
+// This is used in very specific areas of the engine where it's critical that these guarantees are held
+#define SAFE_NUMERIC_TYPE_PUN_GUARANTEES(m_type)                    \
+	static_assert(sizeof(SafeNumeric<m_type>) == sizeof(m_type));   \
+	static_assert(alignof(SafeNumeric<m_type>) == alignof(m_type)); \
+	static_assert(std::is_trivially_destructible<std::atomic<m_type>>::value);
+
 template <class T>
 class SafeNumeric {
 	std::atomic<T> value;
 
+	static_assert(std::atomic<T>::is_always_lock_free);
+
 public:
 	_ALWAYS_INLINE_ void set(T p_value) {
 		value.store(p_value, std::memory_order_release);
@@ -128,6 +136,8 @@ public:
 class SafeFlag {
 	std::atomic_bool flag;
 
+	static_assert(std::atomic_bool::is_always_lock_free);
+
 public:
 	_ALWAYS_INLINE_ bool is_set() const {
 		return flag.load(std::memory_order_acquire);