Parcourir la source

Fixed a number of memory leaks in the Lua scripting system where memory for objects or arrays passed from Lua to C was not being cleaned up properly.

Fixes #357.
Steve Grenier il y a 13 ans
Parent
commit
31ae00ce14

+ 31 - 40
gameplay-luagen/src/FunctionBinding.cpp

@@ -11,7 +11,7 @@ static inline void outputBindingInvocation(ostream& o, const FunctionBinding& b,
 static inline void outputGetParam(ostream& o, const FunctionBinding::Param& p, int i, int indentLevel, bool offsetIndex = false);
 static inline void outputMatchedBinding(ostream& o, const FunctionBinding& b, unsigned int paramCount, unsigned int indentLevel);
 static inline void outputReturnValue(ostream& o, const FunctionBinding& b, int indentLevel);
-
+static inline std::string getTypeName(const FunctionBinding::Param& param);
 
 FunctionBinding::Param::Param(FunctionBinding::Param::Type type, Kind kind, const string& info) : 
     type(type), kind(kind), info(info), hasDefaultValue(false)
@@ -450,66 +450,55 @@ bool FunctionBinding::signaturesMatch(const FunctionBinding& b1, const FunctionB
     return false;
 }
 
-ostream& operator<<(ostream& o, const FunctionBinding::Param& param)
+static inline std::string getTypeName(const FunctionBinding::Param& param)
 {
     switch (param.type)
     {
     case FunctionBinding::Param::TYPE_VOID:
-        o << "void";
-        break;
+        return "void";
     case FunctionBinding::Param::TYPE_BOOL:
-        o << "bool";
-        break;
+        return "bool";
     case FunctionBinding::Param::TYPE_CHAR:
-        o << "char";
-        break;
+        return "char";
     case FunctionBinding::Param::TYPE_SHORT:
-        o << "short";
-        break;
+        return "short";
     case FunctionBinding::Param::TYPE_INT:
-        o << "int";
-        break;
+        return "int";
     case FunctionBinding::Param::TYPE_LONG:
-        o << "long";
-        break;
+        return "long";
     case FunctionBinding::Param::TYPE_UCHAR:
-        o << "unsigned char";
-        break;
+        return "unsigned char";
     case FunctionBinding::Param::TYPE_USHORT:
-        o << "unsigned short";
-        break;
+        return "unsigned short";
     case FunctionBinding::Param::TYPE_UINT:
-        o << "unsigned int";
-        break;
+        return "unsigned int";
     case FunctionBinding::Param::TYPE_ULONG:
-        o << "unsigned long";
-        break;
+        return "unsigned long";
     case FunctionBinding::Param::TYPE_FLOAT:
-        o << "float";
-        break;
+        return "float";
     case FunctionBinding::Param::TYPE_DOUBLE:
-        o << "double";
-        break;
+        return "double";
     case FunctionBinding::Param::TYPE_ENUM:
-        o << Generator::getInstance()->getIdentifier(param.info);
-        break;
+        return Generator::getInstance()->getIdentifier(param.info).c_str();
     case FunctionBinding::Param::TYPE_STRING:
         if (param.info == "string")
-            o << "std::string";
+            return "std::string";
         else
-            o << "const char";
-        break;
+            return "const char";
     case FunctionBinding::Param::TYPE_OBJECT:
     case FunctionBinding::Param::TYPE_CONSTRUCTOR:
-        o << Generator::getInstance()->getIdentifier(param.info);
-        break;
+        return Generator::getInstance()->getIdentifier(param.info).c_str();
     case FunctionBinding::Param::TYPE_UNRECOGNIZED:
-        o << param.info;
-        break;
+        return param.info.c_str();
     case FunctionBinding::Param::TYPE_DESTRUCTOR:
     default:
-        break;
+        return "";
     }
+}
+
+ostream& operator<<(ostream& o, const FunctionBinding::Param& param)
+{
+    o << getTypeName(param);
 
     if (param.kind == FunctionBinding::Param::KIND_POINTER)
         o << "*";
@@ -735,7 +724,11 @@ static inline void outputGetParam(ostream& o, const FunctionBinding::Param& p, i
     case FunctionBinding::Param::TYPE_STRING:
     case FunctionBinding::Param::TYPE_ENUM:
         indent(o, indentLevel);
-        o << p << " param" << i + 1 << " = ";
+        if (p.kind == FunctionBinding::Param::KIND_POINTER)
+            o << "ScriptUtil::LuaArray<" << getTypeName(p) << ">";
+        else
+            o << p;
+        o << " param" << i + 1 << " = ";
         break;
     default:
         // Ignore these cases.
@@ -820,9 +813,7 @@ static inline void outputGetParam(ostream& o, const FunctionBinding::Param& p, i
         break;
     case FunctionBinding::Param::TYPE_OBJECT:
         indent(o, indentLevel);
-        o << p;
-        if (p.kind != FunctionBinding::Param::KIND_POINTER)
-            o << "*";
+        o << "ScriptUtil::LuaArray<" << getTypeName(p) << ">";
         o << " param" << i + 1 << " = ";
         o << "ScriptUtil::getObjectPointer<";
         o << Generator::getInstance()->getIdentifier(p.info) << ">(" << paramIndex;

+ 14 - 14
gameplay/src/ScriptController.cpp

@@ -22,21 +22,21 @@
     if (size <= 0) \
         return NULL; \
     \
-    /* Create an array to store the values. */ \
-    type* values = (type*)malloc(sizeof(type)*size); \
+    /* Declare a LuaArray to store the values. */ \
+	LuaArray<type> arr(size); \
     \
     /* Push the first key. */ \
     lua_pushnil(sc->_lua); \
     int i = 0; \
     for (; lua_next(sc->_lua, index) != 0 && i < size; i++) \
     { \
-        values[i] = (checkFunc(sc->_lua, -1)); \
+        arr[i] = (checkFunc(sc->_lua, -1)); \
         \
         /* Remove the value we just retrieved, but leave the key for the next iteration. */ \
         lua_pop(sc->_lua, 1); \
     } \
     \
-    return values
+    return arr
 
 namespace gameplay
 {
@@ -260,52 +260,52 @@ void ScriptUtil::addStringFromEnumConversionFunction(luaStringEnumConversionFunc
     Game::getInstance()->getScriptController()->_stringFromEnum.push_back(stringFromEnum);
 }
 
-bool* ScriptUtil::getBoolPointer(int index)
+ScriptUtil::LuaArray<bool> ScriptUtil::getBoolPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(bool, luaCheckBool);
 }
 
-short* ScriptUtil::getShortPointer(int index)
+ScriptUtil::LuaArray<short> ScriptUtil::getShortPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(short, (short)luaL_checkint);
 }
 
-int* ScriptUtil::getIntPointer(int index)
+ScriptUtil::LuaArray<int> ScriptUtil::getIntPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(int, (int)luaL_checkint);
 }
 
-long* ScriptUtil::getLongPointer(int index)
+ScriptUtil::LuaArray<long> ScriptUtil::getLongPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(long, (long)luaL_checkint);
 }
 
-unsigned char* ScriptUtil::getUnsignedCharPointer(int index)
+ScriptUtil::LuaArray<unsigned char> ScriptUtil::getUnsignedCharPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(unsigned char, (unsigned char)luaL_checkunsigned);
 }
 
-unsigned short* ScriptUtil::getUnsignedShortPointer(int index)
+ScriptUtil::LuaArray<unsigned short> ScriptUtil::getUnsignedShortPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(unsigned short, (unsigned short)luaL_checkunsigned);
 }
 
-unsigned int* ScriptUtil::getUnsignedIntPointer(int index)
+ScriptUtil::LuaArray<unsigned int> ScriptUtil::getUnsignedIntPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(unsigned int, (unsigned int)luaL_checkunsigned);
 }
 
-unsigned long* ScriptUtil::getUnsignedLongPointer(int index)
+ScriptUtil::LuaArray<unsigned long> ScriptUtil::getUnsignedLongPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(unsigned long, (unsigned long)luaL_checkunsigned);
 }
 
-float* ScriptUtil::getFloatPointer(int index)
+ScriptUtil::LuaArray<float> ScriptUtil::getFloatPointer(int index)
 {
     GENERATE_LUA_GET_POINTER(float, (float)luaL_checknumber);
 }
 
-double* ScriptUtil::getDoublePointer(int index)
+ScriptUtil::LuaArray<double> ScriptUtil::getDoublePointer(int index)
 {
     GENERATE_LUA_GET_POINTER(double, (double)luaL_checknumber);
 }

+ 93 - 23
gameplay/src/ScriptController.h

@@ -29,6 +29,75 @@ struct LuaObject
     bool owns;
 };
 
+/**
+ * Stores a Lua parameter of an array/pointer type that is passed from Lua to C.
+ * Handles automatic cleanup of any temporary memory associated with the array.
+ * @script{ignore}
+ */
+template <typename T>
+class LuaArray
+{
+public:
+
+	/**
+	 * Creates a LuaArray to store a single pointer value.
+	 */
+	LuaArray(T* param);
+
+	/**
+	 * Allocates a LuaArray to store an array of values.
+	 *
+	 * Individual items in the array can be set using the 
+	 * set(unsigned int, const T&) method.
+	 * 
+	 * @param object Parameter object.
+	 * @param count Number of elements to store in the parameter.
+	 */
+	LuaArray(int count);
+
+	/**
+	 * Copy construcotr.
+	 */
+	LuaArray(const LuaArray<T>& copy);
+
+	/**
+	 * Destructor.
+	 */
+	~LuaArray();
+
+	/**
+	 * Assignment operator.
+	 */
+	LuaArray<T>& operator = (const LuaArray<T>& p);
+
+	/**
+	 * Copies the value of the object pointed to by itemPtr into the specified
+     * index of this LuaArray's array.
+	 */
+	void set(unsigned int index, const T* itemPtr);
+
+	/**
+	 * Conversion operator from LuaArray to T*.
+	 */
+	operator T* () const;
+
+    /**
+     * Overloades [] operator to get/set item value at index.
+     */
+    T& operator[] (int index);
+
+private:
+
+	struct Data
+	{
+        Data() : value(NULL), refCount(0) { }
+		typename T* value;
+		int refCount;
+	};
+
+	Data* _data;
+};
+
 /**
  * Registers the given library with Lua.
  * 
@@ -115,7 +184,7 @@ void addStringFromEnumConversionFunction(luaStringEnumConversionFunction stringF
  * @return The pointer.
  * @script{ignore}
  */
-bool* getBoolPointer(int index);
+LuaArray<bool> getBoolPointer(int index);
 
 /**
  * Gets a pointer to a short (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -124,7 +193,7 @@ bool* getBoolPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-short* getShortPointer(int index);
+LuaArray<short> getShortPointer(int index);
 
 /**
  * Gets a pointer to an int (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -133,7 +202,7 @@ short* getShortPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-int* getIntPointer(int index);
+LuaArray<int> getIntPointer(int index);
 
 /**
  * Gets a pointer to a long (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -142,7 +211,7 @@ int* getIntPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-long* getLongPointer(int index);
+LuaArray<long> getLongPointer(int index);
 
 /**
  * Gets a pointer to an unsigned char (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -151,7 +220,7 @@ long* getLongPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-unsigned char* getUnsignedCharPointer(int index);
+LuaArray<unsigned char> getUnsignedCharPointer(int index);
 
 /**
  * Gets a pointer to an unsigned short (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -160,7 +229,7 @@ unsigned char* getUnsignedCharPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-unsigned short* getUnsignedShortPointer(int index);
+LuaArray<unsigned short> getUnsignedShortPointer(int index);
 
 /**
  * Gets a pointer to an unsigned int (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -169,7 +238,7 @@ unsigned short* getUnsignedShortPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-unsigned int* getUnsignedIntPointer(int index);
+LuaArray<unsigned int> getUnsignedIntPointer(int index);
 
 /**
  * Gets a pointer to an unsigned long (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -178,7 +247,7 @@ unsigned int* getUnsignedIntPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-unsigned long* getUnsignedLongPointer(int index);
+LuaArray<unsigned long> getUnsignedLongPointer(int index);
 
 /**
  * Gets a pointer to a float (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -187,7 +256,7 @@ unsigned long* getUnsignedLongPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-float* getFloatPointer(int index);
+LuaArray<float> getFloatPointer(int index);
 
 /**
  * Gets a pointer to a double (as an array-use SAFE_DELETE_ARRAY to clean up) for the given stack index.
@@ -196,7 +265,7 @@ float* getFloatPointer(int index);
  * @return The pointer.
  * @script{ignore}
  */
-double* getDoublePointer(int index);
+LuaArray<double> getDoublePointer(int index);
 
 /**
  * Gets an object pointer of the given type for the given stack index.
@@ -209,7 +278,8 @@ double* getDoublePointer(int index);
  *      is not an object or if the object is not derived from the given type.
  * @script{ignore}
  */
-template<typename T> T* getObjectPointer(int index, const char* type, bool nonNull);
+template <typename T>
+LuaArray<T> getObjectPointer(int index, const char* type, bool nonNull);
 
 /**
  * Gets a string for the given stack index.
@@ -733,17 +803,17 @@ private:
     friend void ScriptUtil::registerFunction(const char* luaFunction, lua_CFunction cppFunction);
     friend void ScriptUtil::setGlobalHierarchyPair(std::string base, std::string derived);
     friend void ScriptUtil::addStringFromEnumConversionFunction(luaStringEnumConversionFunction stringFromEnum);
-    friend bool* ScriptUtil::getBoolPointer(int index);
-    friend short* ScriptUtil::getShortPointer(int index);
-    friend int* ScriptUtil::getIntPointer(int index);
-    friend long* ScriptUtil::getLongPointer(int index);
-    friend unsigned char* ScriptUtil::getUnsignedCharPointer(int index);
-    friend unsigned short* ScriptUtil::getUnsignedShortPointer(int index);
-    friend unsigned int* ScriptUtil::getUnsignedIntPointer(int index);
-    friend unsigned long* ScriptUtil::getUnsignedLongPointer(int index);
-    friend float* ScriptUtil::getFloatPointer(int index);
-    friend double* ScriptUtil::getDoublePointer(int index);
-    template<typename T> friend T* ScriptUtil::getObjectPointer(int index, const char* type, bool nonNull);
+    friend ScriptUtil::LuaArray<bool> ScriptUtil::getBoolPointer(int index);
+    friend ScriptUtil::LuaArray<short> ScriptUtil::getShortPointer(int index);
+    friend ScriptUtil::LuaArray<int> ScriptUtil::getIntPointer(int index);
+    friend ScriptUtil::LuaArray<long> ScriptUtil::getLongPointer(int index);
+    friend ScriptUtil::LuaArray<unsigned char> ScriptUtil::getUnsignedCharPointer(int index);
+    friend ScriptUtil::LuaArray<unsigned short> ScriptUtil::getUnsignedShortPointer(int index);
+    friend ScriptUtil::LuaArray<unsigned int> ScriptUtil::getUnsignedIntPointer(int index);
+    friend ScriptUtil::LuaArray<unsigned long> ScriptUtil::getUnsignedLongPointer(int index);
+    friend ScriptUtil::LuaArray<float> ScriptUtil::getFloatPointer(int index);
+    friend ScriptUtil::LuaArray<double> ScriptUtil::getDoublePointer(int index);
+    template<typename T> friend ScriptUtil::LuaArray<T> ScriptUtil::getObjectPointer(int index, const char* type, bool nonNull);
     friend const char* ScriptUtil::getString(int index, bool isStdString);
 
     lua_State* _lua;
@@ -839,4 +909,4 @@ template<> std::string ScriptController::executeFunction<std::string>(const char
 
 #include "ScriptController.inl"
 
-#endif
+#endif

+ 100 - 23
gameplay/src/ScriptController.inl

@@ -3,7 +3,94 @@
 namespace gameplay
 {
 
-template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type, bool nonNull)
+template <typename T>
+ScriptUtil::LuaArray<T>::LuaArray(T* param)
+{
+	_data = new ScriptUtil::LuaArray<T>::Data();
+	_data->value = param;
+
+	// Initial ref count of zero means no memory management
+	_data->refCount = 0;
+}
+
+template <typename T>
+ScriptUtil::LuaArray<T>::LuaArray(int count)
+{
+	_data = new ScriptUtil::LuaArray<T>::Data();
+
+	// Allocate a chunk of memory to store 'count' number of T.
+	// Use new instead of malloc since we track memory allocations
+	// int DebugMem configurations.
+	_data->value = (T*)new unsigned char[sizeof(T) * count];
+
+	// Positive ref count means we automatically cleanup memory
+	_data->refCount = 1;
+}
+
+template <typename T>
+ScriptUtil::LuaArray<T>::LuaArray(const ScriptUtil::LuaArray<T>& copy)
+{
+	_data = copy._data;
+	++_data->refCount;
+}
+
+template <typename T>
+ScriptUtil::LuaArray<T>::~LuaArray()
+{
+	if ((--_data->refCount) <= 0)
+	{
+        // Non managed arrays/pointers start with ref count zero, so only delete data if
+        // the decremented ref count == 0 (otherwise it will be -1).
+        if (_data->refCount == 0)
+        {
+            unsigned char* value = (unsigned char*)_data->value;
+		    SAFE_DELETE_ARRAY(value);
+        }
+
+        SAFE_DELETE(_data);
+	}
+}
+
+template <typename T>
+ScriptUtil::LuaArray<T>& ScriptUtil::LuaArray<T>::operator = (const ScriptUtil::LuaArray<T>& p)
+{
+    _data = p._data;
+	++_data->refCount;
+}
+
+template <typename T>
+void ScriptUtil::LuaArray<T>::set(unsigned int index, const T* itemPtr)
+{
+	// WARNING: The following code will only work properly for arrays of pointers
+	// to objects (i.e. T**) or for simple structs that are being passed
+	// in as read-only. Since the memory is directly copied, any member data that
+	// is modified with the object that is copied, will not modify the original object.
+	// What is even scarier is that if an array of objects that contain virtual functions
+	// is copied here, then the vtables are copied directly, meaning the new object
+	// contains a copy of a vtable that points to functions in the old object. Calling
+	// virtual fucntions on the new object would then call the functions on the old object.
+	// If the old object is deleted, the vtable on the new object would point to addressess
+	// for functions that no longer exist.
+    if (itemPtr)
+        memcpy((void*)&_data->value[index], (void*)itemPtr, sizeof(T));
+    else
+        memset((void*)&_data->value[index], 0, sizeof(T));
+}
+
+template <typename T>
+ScriptUtil::LuaArray<T>::operator T* () const
+{
+	return _data->value;
+}
+
+template <typename T>
+T& ScriptUtil::LuaArray<T>::operator[] (int index)
+{
+    return _data->value[index];
+}
+
+template<typename T>
+ScriptUtil::LuaArray<T> ScriptUtil::getObjectPointer(int index, const char* type, bool nonNull)
 {
     ScriptController* sc = Game::getInstance()->getScriptController();
     if (lua_type(sc->_lua, index) == LUA_TNIL)
@@ -13,7 +100,7 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
             GP_ERROR("Attempting to pass NULL for required non-NULL parameter at index %d (likely a reference or by-value parameter).", index);
         }
 
-        return NULL;
+        return LuaArray<T>((T*)NULL);
     }
     else if (lua_type(sc->_lua, index) == LUA_TTABLE)
     {
@@ -22,11 +109,10 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
         int size = luaL_checkint(sc->_lua, -1);
 
         if (size <= 0)
-            return NULL;
+            return LuaArray<T>((T*)NULL);
+
+		LuaArray<T> arr(size);
 
-        // Create an array to store the values.
-        T* values = (T*)malloc(sizeof(T)*size);
-        
         // Push the first key.
         lua_pushnil(sc->_lua);
         int i = 0;
@@ -42,12 +128,7 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
                     if (lua_rawequal(sc->_lua, -1, -2))
                     {
                         lua_pop(sc->_lua, 2);
-                        T* ptr = (T*)((ScriptUtil::LuaObject*)p)->instance;
-                        if (ptr)
-                            memcpy((void*)&values[i], (void*)ptr, sizeof(T));
-                        else
-                            memset((void*)&values[i], 0, sizeof(T));
-
+						arr.set(i, (T*)((ScriptUtil::LuaObject*)p)->instance);
                         lua_pop(sc->_lua, 1);
                         continue;
                     }
@@ -61,13 +142,8 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
                         if (lua_rawequal(sc->_lua, -1, -2))
                         {
                             lua_pop(sc->_lua, 2);
-                            T* ptr = (T*)((ScriptUtil::LuaObject*)p)->instance;
-                            if (ptr)
-                                memcpy((void*)&values[i], (void*)ptr, sizeof(T));
-                            else
-                                memset((void*)&values[i], 0, sizeof(T));
-
-                            lua_pop(sc->_lua, 1);
+							arr.set(i, (T*)((ScriptUtil::LuaObject*)p)->instance);
+                            lua_pop(sc->_lua, 1);
                             continue;
                         }
                         lua_pop(sc->_lua, 1);
@@ -77,7 +153,7 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
             }
         }
         
-        return values;
+        return arr;
     }
     else
     {
@@ -96,7 +172,7 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
                     {
                         GP_ERROR("Attempting to pass NULL for required non-NULL parameter at index %d (likely a reference or by-value parameter).", index);
                     }
-                    return ptr;
+                    return LuaArray<T>(ptr);
                 }
                 lua_pop(sc->_lua, 1);
 
@@ -113,7 +189,7 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
                         {
                             GP_ERROR("Attempting to pass NULL for required non-NULL parameter at index %d (likely a reference or by-value parameter).", index);
                         }
-                        return ptr;
+                        return LuaArray<T>(ptr);
                     }
                     lua_pop(sc->_lua, 1);
                 }
@@ -126,7 +202,8 @@ template<typename T>T* ScriptUtil::getObjectPointer(int index, const char* type,
         {
             GP_ERROR("Failed to retrieve a valid object pointer of type '%s' for parameter %d.", type, index);
         }
-        return NULL;
+
+        return LuaArray<T>((T*)NULL);
     }
 }