Browse Source

Fix SQLite3.create_aggretate function, it never worked before, my bad

mingodad 5 years ago
parent
commit
5629c23fac
2 changed files with 60 additions and 19 deletions
  1. 20 18
      SquiLu-ext/sq_sqlite3.cpp
  2. 40 1
      SquiLu/samples/test-sqlite3.nut

+ 20 - 18
SquiLu-ext/sq_sqlite3.cpp

@@ -3093,11 +3093,11 @@ static SQRESULT sq_sqlite3_context__tostring(HSQUIRRELVM v)
 static SQRESULT sq_sqlite3_context_check_aggregate(HSQUIRRELVM v, sq_sqlite3_context_st *ctx)
 {
     sq_sqlite3_sdb_func *func = (sq_sqlite3_sdb_func*)sqlite3_user_data(ctx->ctx);
-    if (sq_isclosure(func->fn_finalize))
+    if (!sq_isclosure(func->fn_finalize))
     {
         return sq_throwerror(v, "attempt to call aggregate method from scalar function");
     }
-    return 1;
+    return SQ_OK;
 }
 
 static SQRESULT sq_sqlite3_context_user_data(HSQUIRRELVM v)
@@ -3113,10 +3113,12 @@ static SQRESULT sq_sqlite3_context_aggregate_data(HSQUIRRELVM v)
 {
     SQ_FUNC_VARS(v);
     GET_sqlite3_context_INSTANCE();
-    sq_sqlite3_context_check_aggregate(v, self);
+    SQRESULT rc = sq_sqlite3_context_check_aggregate(v, self);
+    if(rc != SQ_OK) return rc;
     if(_top_ < 2)
     {
         sq_pushobject(v, self->udata);
+        return 1;
     }
     else
     {
@@ -3349,7 +3351,7 @@ static void db_sql_normal_function(sqlite3_context *context, int argc, sqlite3_v
     HSQUIRRELVM v = func->sdb->v;
 
     int n;
-    sq_sqlite3_context_st *ctx;
+    sq_sqlite3_context_st *ctx = NULL;
     SQInteger top = sq_gettop(v);
 
     /* ensure there is enough space in the stack */
@@ -3375,7 +3377,6 @@ static void db_sql_normal_function(sqlite3_context *context, int argc, sqlite3_v
         if(sq_rawget(v, -2) != SQ_OK)
         {
             /* not yet created? */
-            sq_poptop(v); //remove null
             new_context_instance(v, &ctx);
             sq_pushuserpointer(v, p);
             sq_push(v, -2);
@@ -3391,7 +3392,7 @@ static void db_sql_normal_function(sqlite3_context *context, int argc, sqlite3_v
     }
 
     /* set context */
-    ctx->ctx = context;
+    if(ctx) ctx->ctx = context;
 
     if (sq_call(v, argc + 2, SQFalse, SQFalse) != SQ_OK)   //2 = roottable + ctx
     {
@@ -3405,12 +3406,9 @@ static void db_sql_finalize_function(sqlite3_context *context)
     sq_sqlite3_sdb_func *func = (sq_sqlite3_sdb_func*)sqlite3_user_data(context);
     HSQUIRRELVM v = func->sdb->v;
     void *p = sqlite3_aggregate_context(context, 1); /* minimal mem usage */
-    sq_sqlite3_context_st *ctx;
     SQInteger top = sq_gettop(v);
-
     sq_pushobject(v, func->fn_finalize);
     sq_pushroottable(v);
-
     /* i think it is OK to use assume that using a light user data
     ** as an entry on SquiLu REGISTRY table will be unique */
     sq_pushregistrytable(v);
@@ -3419,19 +3417,23 @@ static void db_sql_finalize_function(sqlite3_context *context)
     /* context table */
     if(sq_deleteslot(v, -2, SQTrue) != SQ_OK)
     {
-        /* not yet created? - shouldn't happen in finalize function */
-        sq_pop(v, 1);
-        new_context_instance(v, &ctx);
-        sq_pushuserpointer(v, p);
-        sq_push(v, -2);
-        sq_rawset(v, -4);
+        sqlite3_result_error(context, "Unexpected missing sqlite3_aggregate_context", -1);
+        sq_settop(v, top);
+        return;
     }
     sq_remove(v, -2); //registrytable
 
-    /* set context */
-    ctx->ctx = context;
+    sq_sqlite3_context_st *self;
+    SQRESULT rc = sq_getinstanceup(v, -1, (void**)&self, (void*)sq_sqlite3_context_TAG);
+    if(rc != SQ_OK)
+    {
+        sqlite3_result_error(context, "Unexpected missing sqlite3_aggregate_context on stack", -1);
+        sq_settop(v, top);
+        return;
+    }
+    self->ctx = context;
 
-    if (sq_call(v, 1, SQFalse, SQFalse) != SQ_OK)
+    if (sq_call(v, 2, SQFalse, SQFalse) != SQ_OK) //2 = roottable + ctx
     {
         sqlite3_result_error(context, sq_getlasterror_str(v), -1);
     }

+ 40 - 1
SquiLu/samples/test-sqlite3.nut

@@ -152,6 +152,14 @@ for(local i=0; i<count; ++i){
 }
 print("SQL prepared function took:", os.clock() -now);
 
+now = os.clock();
+for(local i=0; i<count; ++i){
+	stmt.reset();
+	stmt.step()
+	local val = stmt[0];
+}
+print("SQL prepared2 function took:", os.clock() -now);
+
 stmt.finalize();
 stmt_squilu.finalize();
 
@@ -161,5 +169,36 @@ stmt = db.prepare("select * from test_slice");
 if(stmt.next_row()) print("col_slice", stmt.col_slice(0, 2, 5));
 stmt.finalize();
 
+local function sq_concat_xStep(ctx, value)
+{
+	local buf = ctx.aggregate_data();
+	if(!buf)
+	{
+		buf = blob();
+		ctx.aggregate_data(buf);
+		buf.write(value);//first value
+	}
+	else
+	{
+		//print("sq_concat_xStep", ctx, buf, buf.len(), value);
+		buf.write( "::", value);
+	}
+}
+
+local function sq_concat_xFinal(ctx)
+{
+	local buf = ctx.aggregate_data();
+	//print("sq_concat_xFinal", ctx, buf, buf.len());
+	//print(buf.tostring());
+	ctx.result_text(buf.tostring());
+	buf.clear();
+}
+
+db.create_aggregate("sq_concat", 1, sq_concat_xStep, sq_concat_xFinal);
+db.exec_dml("insert into test_slice values('another text')");
+print(db.exec_get_one("select group_concat(value, '::') gc from test_slice"));
+print(db.exec_get_one("select sq_concat(value) gc from test_slice"));
+print(db.exec_get_one("select sq_concat(value) gc from test_slice"));
+
 
-db.close();
+db.close();