Browse Source

Merge pull request #64795 from RandomShaper/fix_saferefcount

Prevent misuse of SafeRefCount
Rémi Verschelde 2 years ago
parent
commit
d46568205d
2 changed files with 27 additions and 7 deletions
  1. 7 7
      core/error/error_macros.h
  2. 20 0
      core/templates/safe_refcount.h

+ 7 - 7
core/error/error_macros.h

@@ -33,7 +33,7 @@
 
 #include "core/typedefs.h"
 
-#include "core/templates/safe_refcount.h"
+#include <atomic> // We'd normally use safe_refcount.h, but that would cause circular includes.
 
 class String;
 
@@ -737,10 +737,10 @@ void _err_flush_stdout();
  */
 #define WARN_DEPRECATED                                                                                                                                           \
 	if (true) {                                                                                                                                                   \
-		static SafeFlag warning_shown;                                                                                                                            \
-		if (!warning_shown.is_set()) {                                                                                                                            \
+		static std::atomic<bool> warning_shown;                                                                                                                   \
+		if (!warning_shown.load()) {                                                                                                                              \
 			_err_print_error(FUNCTION_STR, __FILE__, __LINE__, "This method has been deprecated and will be removed in the future.", false, ERR_HANDLER_WARNING); \
-			warning_shown.set();                                                                                                                                  \
+			warning_shown.store(true);                                                                                                                            \
 		}                                                                                                                                                         \
 	} else                                                                                                                                                        \
 		((void)0)
@@ -750,10 +750,10 @@ void _err_flush_stdout();
  */
 #define WARN_DEPRECATED_MSG(m_msg)                                                                                                                                       \
 	if (true) {                                                                                                                                                          \
-		static SafeFlag warning_shown;                                                                                                                                   \
-		if (!warning_shown.is_set()) {                                                                                                                                   \
+		static std::atomic<bool> warning_shown;                                                                                                                          \
+		if (!warning_shown.load()) {                                                                                                                                     \
 			_err_print_error(FUNCTION_STR, __FILE__, __LINE__, "This method has been deprecated and will be removed in the future.", m_msg, false, ERR_HANDLER_WARNING); \
-			warning_shown.set();                                                                                                                                         \
+			warning_shown.store(true);                                                                                                                                   \
 		}                                                                                                                                                                \
 	} else                                                                                                                                                               \
 		((void)0)

+ 20 - 0
core/templates/safe_refcount.h

@@ -33,6 +33,10 @@
 
 #include "core/typedefs.h"
 
+#ifdef DEV_ENABLED
+#include "core/error/error_macros.h"
+#endif
+
 #include <atomic>
 #include <type_traits>
 
@@ -163,6 +167,16 @@ public:
 class SafeRefCount {
 	SafeNumeric<uint32_t> count;
 
+#ifdef DEV_ENABLED
+	_ALWAYS_INLINE_ void _check_unref_sanity() {
+		// This won't catch every misuse, but it's better than nothing.
+		CRASH_COND_MSG(count.get() == 0,
+				"Trying to unreference a SafeRefCount which is already zero is wrong and a symptom of it being misused.\n"
+				"Upon a SafeRefCount reaching zero any object whose lifetime is tied to it, as well as the ref count itself, must be destroyed.\n"
+				"Moreover, to guarantee that, no multiple threads should be racing to do the final unreferencing to zero.");
+	}
+#endif
+
 public:
 	_ALWAYS_INLINE_ bool ref() { // true on success
 		return count.conditional_increment() != 0;
@@ -173,10 +187,16 @@ public:
 	}
 
 	_ALWAYS_INLINE_ bool unref() { // true if must be disposed of
+#ifdef DEV_ENABLED
+		_check_unref_sanity();
+#endif
 		return count.decrement() == 0;
 	}
 
 	_ALWAYS_INLINE_ uint32_t unrefval() { // 0 if must be disposed of
+#ifdef DEV_ENABLED
+		_check_unref_sanity();
+#endif
 		return count.decrement();
 	}