浏览代码

Fixes and add prepared statements.

mingodad 10 年之前
父节点
当前提交
58c7f48ab0
共有 1 个文件被更改,包括 378 次插入21 次删除
  1. 378 21
      SquiLu-ext/sq_postgresql.cpp

+ 378 - 21
SquiLu-ext/sq_postgresql.cpp

@@ -2,10 +2,24 @@
 
 #include "squirrel.h"
 #include "libpq-fe.h"
+//#include "pg_type.h"
 #include <string.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include "sqstdblobimpl.h"
+SQ_OPT_STRING_STRLEN();
+
+#define BOOLOID 16
+#define BYTEAOID 17
+#define CHAROID 18
+#define INT8OID 20
+#define INT2OID 21
+#define INT4OID 23
+#define FLOAT4OID 700
+#define FLOAT8OID 701
+#define TIMESTAMPOID 1114
+#define TIMESTAMPTZOID 1184
+#define VARCHAROID 1043
 
 #include "dynamic_library.h"
 
@@ -18,6 +32,11 @@ local pgsql_functions = [
                         const char *query,
                         int nParams,
                         const Oid *paramTypes"],
+    ["void", "PQfreemem", "void *ptr"],
+    ["char *", "PQescapeLiteral", "PGconn *conn, const char *str, size_t length"],
+    ["char *", "PQescapeIdentifier", "PGconn *conn, const char *str, size_t length"],
+    ["unsigned char *", "PQescapeByteaConn", "PGconn *conn, const unsigned char *str, size_t length"],
+    ["unsigned char *", "PQunescapeBytea", "const unsigned char *str, size_t length"],
     ["PGresult *", "PQdescribePrepared", "PGconn *conn, const char *stmtName"],
     ["int", "PQnparams", "const PGresult *res"],
     ["Oid", "PQparamtype", "const PGresult *res, int param_number"],
@@ -98,6 +117,16 @@ typedef PGresult * (*PQprepare_t)(PGconn *conn,
                         int nParams,
                         const Oid *paramTypes);
 static PQprepare_t dlPQprepare = 0;
+typedef void (*PQfreemem_t)(void *ptr);
+static PQfreemem_t dlPQfreemem = 0;
+typedef char * (*PQescapeLiteral_t)(PGconn *conn, const char *str, size_t length);
+static PQescapeLiteral_t dlPQescapeLiteral = 0;
+typedef char * (*PQescapeIdentifier_t)(PGconn *conn, const char *str, size_t length);
+static PQescapeIdentifier_t dlPQescapeIdentifier = 0;
+typedef unsigned char * (*PQescapeByteaConn_t)(PGconn *conn, const unsigned char *str, size_t length);
+static PQescapeByteaConn_t dlPQescapeByteaConn = 0;
+typedef unsigned char * (*PQunescapeBytea_t)(const unsigned char *str, size_t length);
+static PQunescapeBytea_t dlPQunescapeBytea = 0;
 typedef PGresult * (*PQdescribePrepared_t)(PGconn *conn, const char *stmtName);
 static PQdescribePrepared_t dlPQdescribePrepared = 0;
 typedef int (*PQnparams_t)(const PGresult *res);
@@ -196,6 +225,16 @@ dlPQfinish = (PQfinish_t) libpq.dlsym("PQfinish");
 if(!dlPQfinish) return false;
 dlPQprepare = (PQprepare_t) libpq.dlsym("PQprepare");
 if(!dlPQprepare) return false;
+dlPQfreemem = (PQfreemem_t) libpq.dlsym("PQfreemem");
+if(!dlPQfreemem) return false;
+dlPQescapeLiteral = (PQescapeLiteral_t) libpq.dlsym("PQescapeLiteral");
+if(!dlPQescapeLiteral) return false;
+dlPQescapeIdentifier = (PQescapeIdentifier_t) libpq.dlsym("PQescapeIdentifier");
+if(!dlPQescapeIdentifier) return false;
+dlPQescapeByteaConn = (PQescapeByteaConn_t) libpq.dlsym("PQescapeByteaConn");
+if(!dlPQescapeByteaConn) return false;
+dlPQunescapeBytea = (PQunescapeBytea_t) libpq.dlsym("PQunescapeBytea");
+if(!dlPQunescapeBytea) return false;
 dlPQdescribePrepared = (PQdescribePrepared_t) libpq.dlsym("PQdescribePrepared");
 if(!dlPQdescribePrepared) return false;
 dlPQnparams = (PQnparams_t) libpq.dlsym("PQnparams");
@@ -282,11 +321,11 @@ static SQRESULT get_pgsql_instance(HSQUIRRELVM v, SQInteger idx, PGconn **self){
 	return _rc_;
 }
 
-#define GET_pgsql_INSTANCE_AT(idx) \
-	PGconn *self=NULL; \
+#define GET_pgsql_INSTANCE_AT(idx) \
+	PGconn *self=NULL; \
 	if((_rc_ = get_pgsql_instance(v,idx,&self)) < 0) return _rc_;
 
-#define GET_pgsql_INSTANCE() GET_pgsql_INSTANCE_AT(1)
+#define GET_pgsql_INSTANCE() GET_pgsql_INSTANCE_AT(1)
 
 static const SQChar *PostgreSQL_Result_TAG = _SC("PostgreSQL_Result");
 static const SQChar *_curr_row_key = _SC("_curr_row");
@@ -298,11 +337,11 @@ static SQRESULT get_pgsql_result_instance(HSQUIRRELVM v, SQInteger idx, PGresult
 	return _rc_;
 }
 
-#define GET_pgsql_result_INSTANCE_AT(idx) \
-	PGresult *self=NULL; \
+#define GET_pgsql_result_INSTANCE_AT(idx) \
+	PGresult *self=NULL; \
 	if((_rc_ = get_pgsql_result_instance(v,idx,&self)) < 0) return _rc_;
 
-#define GET_pgsql_result_INSTANCE() GET_pgsql_result_INSTANCE_AT(1)
+#define GET_pgsql_result_INSTANCE() GET_pgsql_result_INSTANCE_AT(1)
 
 static SQRESULT sq_pgsql_result_releasehook(SQUserPointer p, SQInteger size, HSQUIRRELVM v)
 {
@@ -341,6 +380,14 @@ static SQRESULT sq_pgsql_result_col_name(HSQUIRRELVM v){
 	return 1;
 }
 
+static SQRESULT sq_pgsql_result_col_type(HSQUIRRELVM v){
+	SQ_FUNC_VARS_NO_TOP(v);
+	GET_pgsql_result_INSTANCE();
+	SQ_GET_INTEGER(v, 2, col);
+	sq_pushinteger(v, dlPQftype(self, col));
+	return 1;
+}
+
 static SQRESULT sq_pgsql_result_col_index(HSQUIRRELVM v){
 	SQ_FUNC_VARS_NO_TOP(v);
 	GET_pgsql_result_INSTANCE();
@@ -441,6 +488,7 @@ static SQRegFunction sq_pgsql_result_methods[] =
 	_DECL_FUNC(col_name,  2, _SC("xi")),
 	_DECL_FUNC(col_index,  2, _SC("xs")),
 	_DECL_FUNC(col_value,  2, _SC("x i|s")),
+	_DECL_FUNC(col_type,  2, _SC("x i|s")),
 	_DECL_FUNC(row_as_array,  -1, _SC("xi")),
 	{0,0}
 };
@@ -448,8 +496,12 @@ static SQRegFunction sq_pgsql_result_methods[] =
 
 struct PgSqlStatement {
     PGconn *db;
-    PGresult *result;
     char name[64];
+    int param_count;
+    int isGetPrepared;
+	void **param_values;
+	int *param_sizes;
+	int *param_types;
 };
 
 static const SQChar *PostgreSQL_Statement_TAG = _SC("PostgreSQL_Statement");
@@ -461,22 +513,42 @@ static SQRESULT get_pgsql_statement_instance(HSQUIRRELVM v, SQInteger idx, PgSql
 	return _rc_;
 }
 
-#define GET_pgsql_statement_INSTANCE_AT(idx) \
-	PgSqlStatement *self=NULL; \
+#define GET_pgsql_statement_INSTANCE_AT(idx) \
+	PgSqlStatement *self=NULL; \
 	if((_rc_ = get_pgsql_statement_instance(v,idx,&self)) < 0) return _rc_;
 
-#define GET_pgsql_statement_INSTANCE() GET_pgsql_statement_INSTANCE_AT(1)
+#define GET_pgsql_statement_INSTANCE() GET_pgsql_statement_INSTANCE_AT(1)
 
 static SQRESULT sq_pgsql_statement_releasehook(SQUserPointer p, SQInteger size, HSQUIRRELVM v)
 {
 	PgSqlStatement *self = ((PgSqlStatement *)p);
-	if (self && self->result){
-        char sql[128];
-        snprintf(sql, sizeof(sql), "DEALLOCATE '%s'", self->name);
-        PGresult *qres = dlPQexec(self->db, sql);
-        bool is_ok = dlPQresultStatus(qres) != PGRES_BAD_RESPONSE;
-        dlPQclear(qres);
-        if(is_ok) dlPQclear(self->result);
+	if (self){
+        if(!self->isGetPrepared)
+        {
+            char sql[128];
+            snprintf(sql, sizeof(sql), "DEALLOCATE %s", self->name);
+            PGresult *qres = dlPQexec(self->db, sql);
+            dlPQclear(qres);
+        }
+        if(self->param_count)
+        {
+            SQUnsignedInteger the_size;
+            if(self->param_values)
+            {
+                    the_size = self->param_count * sizeof(self->param_values);
+                    sq_free(self->param_values, the_size);
+            }
+            if(self->param_sizes)
+            {
+                    the_size = self->param_count * sizeof(self->param_sizes);
+                    sq_free(self->param_sizes, the_size);
+            }
+            if(self->param_types)
+            {
+                    the_size = self->param_count * sizeof(self->param_types);
+                    sq_free(self->param_types, the_size);
+            }
+        }
         sq_free(self, sizeof(PgSqlStatement));
 	}
 	return 0;
@@ -490,10 +562,221 @@ static SQRESULT sq_pgsql_statement_close(HSQUIRRELVM v){
 	return 0;
 }
 
+/*
+static SQRESULT sq_pgsql_statement_exec(HSQUIRRELVM v){
+	SQ_FUNC_VARS(v);
+	GET_pgsql_statement_INSTANCE();
+	SQ_OPT_INTEGER(v, 3, result_type, 0);
+	SQInteger psize = sq_getsize(v, 2);
+	void **param_values;
+	int *param_sizes;
+	int *param_types;
+	bool bresult = false;
+	SQBool bval;
+	PGresult *qres;
+    SQUnsignedInteger param_values_size = psize * sizeof(param_values);
+    SQUnsignedInteger param_sizes_size = psize * sizeof(param_sizes);
+
+	if(self->result)
+    {
+        dlPQclear(self->result);
+        self->result = NULL;
+    }
+
+	param_values = (void **)sq_malloc(param_values_size);
+    memset(param_values, 0, param_values_size);
+	param_sizes = (int *)sq_malloc(param_sizes_size);
+    memset(param_sizes, 0, param_sizes_size);
+	param_types = (int *)sq_malloc(param_sizes_size);
+    memset(param_types, 0, param_sizes_size);
+
+    for(SQInteger i=0; i < psize; ++i)
+    {
+        sq_pushinteger(v, i);
+        if(sq_get(v, 2) == SQ_OK)
+        {
+            switch(sq_gettype(v, -1))
+            {
+            case OT_NULL:
+                param_values[i] = NULL;
+                param_sizes[i] = 0;
+                param_types[i] = 0;
+                sq_poptop(v);
+                break;
+            case OT_BOOL:
+                sq_getbool(v, -1, &bval);
+                param_values[i] = (void*)(bval == SQTrue ? "1" : "0");
+                param_sizes[i] = 1;
+                param_types[i] = BOOLOID;
+                sq_poptop(v);
+                break;
+            case OT_INTEGER:
+                param_values[i] = (void*)sq_tostring(v, -1);
+                param_sizes[i] = (int)sq_getsize(v, -1);
+                param_types[i] = INT4OID;
+                break;
+            case OT_FLOAT:
+                param_values[i] = (void*)sq_tostring(v, -1);
+                param_sizes[i] = (int)sq_getsize(v, -1);
+                param_types[i] = FLOAT8OID;
+                break;
+            case OT_STRING:
+                param_values[i] = (void*)sq_tostring(v, -1);
+                param_sizes[i] = (int)sq_getsize(v, -1);
+                param_types[i] = VARCHAROID;
+                break;
+            default:
+
+                goto cleanup;
+            }
+        }
+    }
+
+	qres = dlPQexecPrepared(self->db, self->name, psize, (const char**)param_values, param_sizes, NULL, result_type);
+	bresult = dlPQresultStatus(qres) == PGRES_COMMAND_OK;
+	if(bresult) self->result = qres;
+	else dlPQclear(qres);
+cleanup:
+	sq_free(param_values, param_values_size);
+	sq_free(param_sizes, param_sizes_size);
+	sq_settop(v, _top_);
+	sq_pushbool(v, bresult);
+	return 1;
+}
+*/
+
+#define SQ_EXEC_DML 1
+#define SQ_EXEC_SCALAR 2
+#define SQ_EXEC_QUERY 3
+
+
+static SQRESULT sq_pgsql_statement_exec(HSQUIRRELVM v, int exec_type){
+	SQ_FUNC_VARS(v);
+	GET_pgsql_statement_INSTANCE();
+	SQ_OPT_INTEGER(v, 3, result_type, 0);
+	SQInteger psize = sq_getsize(v, 2);
+	if(psize != self->param_count)
+    {
+        return sq_throwerror(v, _SC("Wrong number of paramters, exptexted %d"), self->param_count);
+    }
+	int result = SQ_ERROR;
+	SQBool bval;
+	const SQChar *str_val;
+	PGresult *qres;
+    SQUnsignedInteger param_values_size = psize * sizeof(self->param_values);
+    SQUnsignedInteger param_sizes_size = psize * sizeof(self->param_sizes);
+	if(!self->param_values)
+    {
+        //only allocate once
+        self->param_values = (void **)sq_malloc(param_values_size);
+        self->param_sizes = (int *)sq_malloc(param_sizes_size);
+    }
+
+    memset(self->param_values, 0, param_values_size);
+    memset(self->param_sizes, 0, param_sizes_size);
+
+    sq_reservestack(v, psize*2);
+
+    for(SQInteger i=0; i < psize; ++i)
+    {
+        sq_pushinteger(v, i);
+        if(sq_get(v, 2) == SQ_OK)
+        {
+            switch(sq_gettype(v, -1))
+            {
+            case OT_NULL:
+                sq_poptop(v);
+                break;
+            case OT_BOOL:
+                sq_getbool(v, -1, &bval);
+                self->param_values[i] = (void*)(bval == SQTrue ? "1" : "0");
+                self->param_sizes[i] = 1;
+                sq_poptop(v);
+                break;
+            case OT_INTEGER:
+            case OT_FLOAT:
+            case OT_STRING:
+                sq_tostring(v, -1);
+                sq_getstring(v, -1, &str_val);
+                self->param_values[i] = (void*)str_val;
+                self->param_sizes[i] = (int)sq_getsize(v, -1);
+                break;
+            default:
+                result = sq_throwerror(v, _SC("Unknow parameter type at pos %d"), i);
+                goto cleanup;
+            }
+        }
+    }
+
+	qres = dlPQexecPrepared(self->db, self->name, psize,
+                         (const char**)self->param_values, self->param_sizes, NULL, result_type);
+	result = dlPQresultStatus(qres);
+	if(result == PGRES_COMMAND_OK || result == PGRES_TUPLES_OK)
+    {
+        if(exec_type == SQ_EXEC_DML)
+        {
+            sq_pushinteger(v, atoi(dlPQcmdTuples(qres)));
+            dlPQclear(qres);
+        }
+        else if(exec_type == SQ_EXEC_SCALAR)
+        {
+            int ntuples = dlPQntuples(qres);
+            int nfields = dlPQnfields(qres);
+            if(exec_type == SQ_EXEC_SCALAR && (ntuples == 1) && (nfields > 0))
+            {
+                result = atoi(dlPQgetvalue(qres, 0, 0));
+                sq_pushinteger(v, result);
+                dlPQclear(qres);
+            }
+            else
+            {
+                sq_pushnull(v);
+            }
+        }
+        else if(exec_type == SQ_EXEC_QUERY)
+        {
+            sq_pushroottable(v);
+            sq_pushstring(v, PostgreSQL_Result_TAG, -1);
+            if(sq_get(v, -2) == SQ_OK){
+                if(sq_createinstance(v, -1) == SQ_OK){
+                    sq_setinstanceup(v, -1, qres);
+                    sq_setreleasehook(v, -1, sq_pgsql_result_releasehook);
+                    sq_pushstring(v, _curr_row_key, -1);
+                    sq_pushinteger(v, -1);
+                    sq_set(v, -3);
+                }
+            }
+        }
+        result = 1;
+    }
+	else
+    {
+        dlPQclear(qres);
+        //sq_settop(v, _top_);
+        result = sq_throwerror(v, dlPQerrorMessage(self->db));
+    }
+cleanup:
+	return result;
+}
+
+static SQRESULT sq_pgsql_statement_exec_query(HSQUIRRELVM v){
+    return sq_pgsql_statement_exec(v, SQ_EXEC_QUERY);
+}
+
+static SQRESULT sq_pgsql_statement_exec_scalar(HSQUIRRELVM v){
+    return sq_pgsql_statement_exec(v, SQ_EXEC_SCALAR);
+}
+
+static SQRESULT sq_pgsql_statement_exec_dml(HSQUIRRELVM v){
+    return sq_pgsql_statement_exec(v, SQ_EXEC_DML);
+}
 
 #define _DECL_FUNC(name,nparams,tycheck) {_SC(#name),  sq_pgsql_statement_##name,nparams,tycheck}
 static SQRegFunction sq_pgsql_statement_methods[] =
 {
+	_DECL_FUNC(exec_query,  -2, _SC("xai")),
+	_DECL_FUNC(exec_scalar,  2, _SC("xa")),
+	_DECL_FUNC(exec_dml,  2, _SC("xa")),
 	_DECL_FUNC(close,  1, _SC("x")),
 	{0,0}
 };
@@ -601,18 +884,41 @@ static SQRESULT sq_pgsql_exec_query(HSQUIRRELVM v){
     return sq_throwerror(v, dlPQerrorMessage(self));
 }
 
-static SQRESULT sq_pgsql_prepare(HSQUIRRELVM v){
+static SQRESULT sq_pgsql_do_prepare(HSQUIRRELVM v, int isGetPrepared){
 	SQ_FUNC_VARS_NO_TOP(v);
 	GET_pgsql_INSTANCE();
     SQ_GET_STRING(v, 2, szSQL);
+    PGresult *qres;
+    bool bresult;
 
     PgSqlStatement *stmt = (PgSqlStatement*)sq_malloc(sizeof(PgSqlStatement));
+    memset(stmt, 0, sizeof(PgSqlStatement));
     stmt->db = self;
-    snprintf(stmt->name, sizeof(stmt->name), "sq_pg_preared_stmt_%p", stmt);
+    stmt->isGetPrepared = isGetPrepared;
 
-    stmt->result = dlPQprepare(self, stmt->name, szSQL, 0, NULL);
+    if(isGetPrepared)
+    {
+        snprintf(stmt->name, sizeof(stmt->name), "%s", szSQL);
+        qres = dlPQdescribePrepared(self, stmt->name);
+        bresult = dlPQresultStatus(qres) == PGRES_COMMAND_OK;
+        if(bresult) stmt->param_count = dlPQnparams(qres);
+        dlPQclear(qres);
+    }
+    else
+    {
+        snprintf(stmt->name, sizeof(stmt->name), "sq_stmt_%p_%p", self, stmt);
+        qres = dlPQprepare(self, stmt->name, szSQL, 0, NULL);
+        bresult = dlPQresultStatus(qres) == PGRES_COMMAND_OK;
+        dlPQclear(qres);
+    }
 
-    if(dlPQresultStatus(stmt->result) == PGRES_COMMAND_OK){
+    if(bresult){
+        if(!isGetPrepared)
+        {
+            qres = dlPQdescribePrepared(self, stmt->name);
+            stmt->param_count = dlPQnparams(qres);
+            dlPQclear(qres);
+        }
         sq_pushroottable(v);
         sq_pushstring(v, PostgreSQL_Statement_TAG, -1);
         if(sq_get(v, -2) == SQ_OK){
@@ -627,6 +933,14 @@ static SQRESULT sq_pgsql_prepare(HSQUIRRELVM v){
     return sq_throwerror(v, dlPQerrorMessage(self));
 }
 
+static SQRESULT sq_pgsql_prepare(HSQUIRRELVM v){
+    return sq_pgsql_do_prepare(v, 0);
+}
+
+static SQRESULT sq_pgsql_get_prepared(HSQUIRRELVM v){
+    return sq_pgsql_do_prepare(v, 1);
+}
+
 static SQRESULT sq_pgsql_error_message(HSQUIRRELVM v){
 	SQ_FUNC_VARS_NO_TOP(v);
 	GET_pgsql_INSTANCE();
@@ -767,6 +1081,45 @@ static SQRESULT sq_pgsql_delete_blob_field(HSQUIRRELVM v){
 	return 1;
 }
 
+static SQRESULT sq_pgsql_escape_string(HSQUIRRELVM v){
+	SQ_FUNC_VARS_NO_TOP(v);
+	GET_pgsql_INSTANCE();
+	SQ_GET_STRING(v, 2, str);
+	char *escaped_str = dlPQescapeLiteral(self, str, str_size);
+	if(escaped_str) {
+		sq_pushstring(v, escaped_str, -1);
+		dlPQfreemem(escaped_str);
+		return 1;
+	}
+	return sq_throwerror(v, _SC("could not allocate escaped string"));
+}
+
+static SQRESULT sq_pgsql_escape_bytea(HSQUIRRELVM v){
+	SQ_FUNC_VARS_NO_TOP(v);
+	GET_pgsql_INSTANCE();
+	SQ_GET_STRING(v, 2, str);
+	char *escaped_str = (char*)dlPQescapeByteaConn(self, (const unsigned char*)str, str_size);
+	if(escaped_str) {
+		sq_pushstring(v, escaped_str, -1);
+		dlPQfreemem(escaped_str);
+		return 1;
+	}
+	return sq_throwerror(v, _SC("could not allocate escaped bytea"));
+}
+
+static SQRESULT sq_pgsql_unescape_bytea(HSQUIRRELVM v){
+	SQ_FUNC_VARS_NO_TOP(v);
+	GET_pgsql_INSTANCE();
+	SQ_GET_STRING(v, 2, str);
+	char *escaped_str = (char*)dlPQunescapeBytea((const unsigned char*)str, str_size);
+	if(escaped_str) {
+		sq_pushstring(v, escaped_str, -1);
+		dlPQfreemem(escaped_str);
+		return 1;
+	}
+	return sq_throwerror(v, _SC("could not allocate unescaped bytea"));
+}
+
 #define _DECL_FUNC(name,nparams,tycheck) {_SC(#name),  sq_pgsql_##name,nparams,tycheck}
 static SQRegFunction sq_pgsql_methods[] =
 {
@@ -776,12 +1129,16 @@ static SQRegFunction sq_pgsql_methods[] =
 	_DECL_FUNC(exec_scalar,  2, _SC("xs")),
 	_DECL_FUNC(exec_query,  2, _SC("xs")),
 	_DECL_FUNC(prepare,  2, _SC("xs")),
+	_DECL_FUNC(get_prepared,  2, _SC("xs")),
 	_DECL_FUNC(error_message,  1, _SC("x")),
 	_DECL_FUNC(version,  1, _SC("x")),
 	_DECL_FUNC(get_blob_field,  2, _SC("xi")),
 	_DECL_FUNC(insert_blob_field,  3, _SC("xsb")),
 	_DECL_FUNC(update_blob_field,  3, _SC("xisb")),
 	_DECL_FUNC(delete_blob_field,  2, _SC("xi")),
+	_DECL_FUNC(escape_string,  2, _SC("xs")),
+	_DECL_FUNC(escape_bytea,  2, _SC("xs")),
+	_DECL_FUNC(unescape_bytea,  2, _SC("xs")),
 	{0,0}
 };
 #undef _DECL_FUNC