Преглед на файлове

change hdecode() to newdecoder()

Xavier Wang преди 4 години
родител
ревизия
8ef845677b
променени са 2 файла, в които са добавени 189 реда и са изтрити 90 реда
  1. 151 85
      mp.c
  2. 38 5
      test.lua

+ 151 - 85
mp.c

@@ -280,14 +280,29 @@ static void lmp_writeext(lmp_Buffer *B, int type, const char *s, size_t len) {
 static int lmp_addfloat(lmp_Buffer *B, int idx, int len)
 { lmp_writefloat(B, lua_tonumber(B->L, idx), len); return 1; }
 
-static int lmp_error(lmp_Buffer *B, int idx, int prev, const char *fmt, ...) {
+static int lmp_relindex(int idx, int onstack)
+{ return idx > 0 || idx <= LUA_REGISTRYINDEX ? idx : idx - onstack; }
+
+static int lmp_error(lmp_Buffer *B, int idx, const char *fmt, ...) {
+    va_list l;
+    va_start(l, fmt);
+    lua_pushvfstring(B->L, fmt, l);
+    va_end(l);
+    lua_replace(B->L, lmp_relindex(idx, 1));
+    lua_settop(B->L, idx);
+    return 0;
+}
+
+static int lmp_chain(lmp_Buffer *B, int idx, int prev, const char *fmt, ...) {
     va_list l;
     va_start(l, fmt);
     lua_pushvfstring(B->L, fmt, l);
     va_end(l);
-    if (prev) lua_pushfstring(B->L, "%s;\n\t%s",
-                 lua_tostring(B->L, -1), lua_tostring(B->L, prev));
-    lua_replace(B->L, idx), lua_settop(B->L, idx);
+    lua_pushliteral(B->L, ";\n\t");
+    lua_pushvalue(B->L, lmp_relindex(prev, 2));
+    lua_concat(B->L, 3);
+    lua_replace(B->L, lmp_relindex(idx, 1));
+    lua_settop(B->L, idx);
     return 0;
 }
 
@@ -295,7 +310,7 @@ static int lmp_addinteger(lmp_Buffer *B, int idx, int uint) {
     int isint;
     lua_Integer i = lua53_tointegerx(B->L, idx, &isint);
     if (!isint)
-        return lmp_error(B, idx,0, "integer expected, got %s",
+        return lmp_error(B, idx, "integer expected, got %s",
                 luaL_typename(B->L, idx));
     if (uint) lmp_writeint(B, (lmp_U64)i, 1);
     else      lmp_writeint(B, (lmp_U64)i, 0);
@@ -310,11 +325,53 @@ static int lmp_addstring(lmp_Buffer *B, int idx, int type) {
     return 1;
 }
 
+static int lmp_addext(lmp_Buffer *B, int idx, int fetch) {
+    int type, tt, vt;
+    size_t len;
+    const char *s;
+    tt = fetch ? lua53_getfield(B->L, idx, "type") : lua_type(B->L, idx);
+    if (tt != LUA_TNUMBER)
+        return lmp_error(B, idx,
+                "integer expected for extension type, got %s",
+                lua_typename(B->L, tt));
+    type = (int)lua_tointeger(B->L, fetch ? -1 : idx);
+    if (type < -128 || type > 127)
+        return lmp_error(B, idx, "invalid extension type: %d", type);
+    vt = fetch ? lua53_getfield(B->L, idx, "value") : lua_type(B->L, idx+1);
+    if (vt != LUA_TSTRING)
+        return lmp_error(B, idx,
+                "string expected for extension value, got %s",
+                lua_typename(B->L, vt));
+    s = lua_tolstring(B->L, fetch ? -1 : idx+1, &len);
+    if (len > LMP_MAX_SIZE) return lmp_toobig(B,idx,"extension",(int)len);
+    lmp_writeext(B, type, s, len);
+    if (fetch) lua_pop(B->L, 2);
+    return 1;
+}
+
+static int lmp_handlerresult(lmp_Buffer *B, int idx, int top) {
+    int r;
+    const char *type;
+    if ((type = lua_tostring(B->L, top)) == NULL)
+        return lmp_error(B, idx, "type expected from handler, got %s",
+                luaL_typename(B->L, top));
+    r = *type == 'e' ? lmp_addext(B, top+1, 0) : lmp_pack(B, top+1, type, 0);
+    if (r ==  0) return lmp_chain(B, idx,top+1, "error from handler");
+    if (r == -1) return lmp_error(B, idx, "invalid msgpack.type '%s'", type);
+    return (lua_pop(B->L, 3), 1);
+}
+
+static int lmp_encoderesult(lmp_Buffer *B, int idx, int r, int hidx) {
+    if (r < 0 && hidx == 0)
+        return lmp_error(B, idx, "invalid type '%s'", lua_typename(B->L, -r));
+    return r;
+}
+
 static int lmp_addarray(lmp_Buffer *B, int idx, int hidx) {
     int i, len = (int)luaL_len(B->L, idx);
     int top = lua_gettop(B->L);
     if (top > LMP_MAX_STACK || !lua_checkstack(B->L, 5))
-        return lmp_error(B, idx,0, "array level too deep");
+        return lmp_error(B, idx, "array level too deep");
     if (len < 16)
         lmp_addchar(B, 0x90 + len);
     else if (len < 0x10000) {
@@ -325,9 +382,17 @@ static int lmp_addarray(lmp_Buffer *B, int idx, int hidx) {
         lmp_writeuint(B, len, 4);
     }
     for (i = 1; i <= len; ++i) {
-        if (!lmp_encode(B, top+1, lua53_geti(B->L, idx, i), hidx))
-            return lmp_error(B, idx,top+1, "invalid element '%d' in array", i);
-        lua_pop(B->L, 1);
+        int r = lmp_encode(B, top+1, lua53_geti(B->L, idx, i), hidx);
+        if ((r = lmp_encoderesult(B, top+1, r, hidx)) < 0) {
+            lua_pushvalue(B->L, hidx);
+            lua_insert(B->L, -2);
+            lua_pushinteger(B->L, i);
+            lua_pushvalue(B->L, idx);
+            lua_call(B->L, 3, 3);
+            r = lmp_handlerresult(B, top+1, top+1);
+        }
+        if (!r) return lmp_chain(B, idx,top+1, "invalid element '%d' in array", i);
+        lua_settop(B->L, top);
     }
     return 1;
 }
@@ -354,63 +419,44 @@ static int lmp_addmap(lmp_Buffer *B, int idx, int hidx) {
     unsigned off = B->len + 1, count = 0;
     int top = lua_gettop(B->L);
     if (top > LMP_MAX_STACK || !lua_checkstack(B->L, 10))
-        return lmp_error(B, idx,0, "map level too deep");
+        return lmp_error(B, idx, "map level too deep");
     lmp_addchar(B, 0x80);
     lua_pushnil(B->L);
     for (; lua_next(B->L, idx); ++count) {
-        if (!lmp_encode(B, top+1, 0, hidx))
-            return lmp_error(B, idx,top+1, "invalid key in map");
-        if (!lmp_encode(B, top+2, 0, hidx))
-            return lmp_error(B, idx,top+2, "invalid value for key '%s' in map",
+        int r = lmp_encode(B, top+1, 0, hidx);
+        if ((r = lmp_encoderesult(B, top+1, r, hidx)) < 0) {
+            lua_pushvalue(B->L, hidx);
+            lua_pushvalue(B->L, top+1);
+            lua_pushnil(B->L); /* for key, the key is nil */
+            lua_pushvalue(B->L, idx);
+            lua_call(B->L, 3, 3);
+            r = lmp_handlerresult(B, top+1, top+3);
+        }
+        if (!r) return lmp_chain(B, idx, top+1, "invalid key in map");
+        r = lmp_encode(B, top+2, 0, hidx);
+        if ((r = lmp_encoderesult(B, top+2, r, hidx)) < 0) {
+            lua_pushvalue(B->L, hidx);
+            lua_insert(B->L, -2);
+            lua_pushvalue(B->L, top+1);
+            lua_pushvalue(B->L, idx);
+            lua_call(B->L, 3, 3);
+            r = lmp_handlerresult(B, top+2, top+2);
+        }
+        if (!r)
+            return lmp_chain(B, idx,top+2, "invalid value for key '%s' in map",
                     luaL_tolstring(B->L, top+1, NULL));
-        lua_pop(B->L, 1);
+        lua_settop(B->L, top+1);
     }
     lmp_fixmapszie(B, off, count);
     return 1;
 }
 
-static int lmp_addext(lmp_Buffer *B, int idx, int fetch) {
-    int type, tt, vt;
-    size_t len;
-    const char *s;
-    tt = fetch ? lua53_getfield(B->L, idx, "type") : lua_type(B->L, idx);
-    if (tt != LUA_TNUMBER)
-        return lmp_error(B, idx,0,
-                "integer expected for extension type, got %s",
-                lua_typename(B->L, tt));
-    type = (int)lua_tointeger(B->L, fetch ? -1 : idx);
-    if (type < -128 || type > 127)
-        return lmp_error(B, idx,0, "invalid extension type: %d", type);
-    vt = fetch ? lua53_getfield(B->L, idx, "value") : lua_type(B->L, idx+1);
-    if (vt != LUA_TSTRING)
-        return lmp_error(B, idx,0,
-                "string expected for extension value, got %s",
-                lua_typename(B->L, vt));
-    s = lua_tolstring(B->L, fetch ? -1 : idx+1, &len);
-    if (len > LMP_MAX_SIZE) return lmp_toobig(B,idx,"extension",(int)len);
-    lmp_writeext(B, type, s, len);
-    if (fetch) lua_pop(B->L, 2);
-    return 1;
-}
-
-static int lmp_handlerresult(lmp_Buffer *B, int idx, int top) {
-    int r;
-    const char *type;
-    if ((type = lua_tostring(B->L, top)) == NULL)
-        return lmp_error(B, idx,0, "type expected from handler, got %s",
-                luaL_typename(B->L, top));
-    r = *type == 'e' ? lmp_addext(B, top+1, 0) : lmp_pack(B, top+1, type, 0);
-    if (r ==  0) return lmp_error(B, idx,top+1, "error from handler");
-    if (r == -1) return lmp_error(B, idx,0, "invalid msgpack.type '%s'", type);
-    return (lua_pop(B->L, 3), 1);
-}
-
 static int lmp_addhandler(lmp_Buffer *B, int idx, int fetch) {
     int top = lua_gettop(B->L)+1;
     if (!fetch)
         lua_pushvalue(B->L, idx);
     else if (lua53_getfield(B->L, idx, "pack") == LUA_TNIL)
-        return lmp_error(B, idx,0, "'pack' field expected in handler object");
+        return lmp_error(B, idx, "'pack' field expected in handler object");
     if (fetch) lua_pushvalue(B->L, idx);
     lua_call(B->L, fetch, 3);
     return lmp_handlerresult(B, idx, top);
@@ -422,11 +468,11 @@ static int lmp_check(lmp_Buffer *B, int idx, int fetch, int type) {
         rt = lua_type(B->L, idx);
     else {
         if ((rt = lua53_getfield(B->L, idx, "value")) == LUA_TNIL)
-            return lmp_error(B, idx,0, "'value' field expected in wrapper object");
+            return lmp_error(B, idx, "'value' field expected in wrapper object");
         lua_replace(B->L, idx);
     }
     return !type || rt == type ? 1 :
-        lmp_error(B, idx,0, "%s expected, got %s",
+        lmp_error(B, idx, "%s expected, got %s",
                 lua_typename(B->L, type), luaL_typename(B->L, idx));
 }
 
@@ -483,22 +529,28 @@ static int lmp_encode(lmp_Buffer *B, int idx, int type, int hidx) {
     case LUA_TFUNCTION: return lmp_addhandler(B, idx, 0);
     case LUA_TTABLE:    return lmp_addtable(B, idx, hidx);
     }
-    if (hidx) {
-        int top = lua_gettop(B->L)+1;
-        lua_pushvalue(B->L, idx);
-        lua_pushvalue(B->L, hidx);
-        lua_insert(B->L, -2);
-        lua_call(B->L, 1, 3);
-        return lmp_handlerresult(B, idx, top);
+    return -type;
+}
+
+static int Lencode_aux(lua_State *L) {
+    lmp_Buffer *B = (lmp_Buffer*)lua_touserdata(L, 1);
+    int i, top = lua_gettop(L);
+    B->L = L;
+    for (i = 2; i <= top; ++i) {
+        int r = lmp_encode(B, i, 0, 0);
+        if (!(r = lmp_encoderesult(B, i, r, 0)))
+            return luaL_error(L, "bad argument to #%d: %s",
+                    i-1, lua_tostring(L, i));
     }
-    return lmp_error(B, idx,0, "invalid type '%s'", lua_typename(B->L, type));
+    lua_pushlstring(L, (const char*)lmp_data(B), B->len);
+    return 1;
 }
 
-static int lmp_encode_helper(lua_State *L, lua_CFunction encode) {
+static int Lencode(lua_State *L) {
     lmp_Buffer B;
     int r;
     memset(&B, 0, sizeof(B));
-    lua_pushcfunction(L, encode);
+    lua_pushcfunction(L, Lencode_aux);
     lua_insert(L, 1);
     lua_pushlightuserdata(L, &B);
     lua_insert(L, 2);
@@ -507,35 +559,49 @@ static int lmp_encode_helper(lua_State *L, lua_CFunction encode) {
     return r ? 1 : luaL_error(L, "%s", lua_tostring(L, -1));
 }
 
-static int Lencode_aux(lua_State *L) {
+static int Lencoder_aux(lua_State *L) {
     lmp_Buffer *B = (lmp_Buffer*)lua_touserdata(L, 1);
     int i, top = lua_gettop(L);
     B->L = L;
-    for (i = 2; i <= top; ++i)
-        if (!lmp_encode(B, i, 0, 0))
-            return luaL_error(L, "bad argument to #%d: %s",
-                    i-1, lua_tostring(L, i));
+    for (i = 3; i <= top; ++i) {
+        int r = lmp_encode(B, i, 0, 2);
+        if ((r = lmp_encoderesult(B, i, r, 2)) < 0) {
+            lua_pushvalue(B->L, 2);
+            lua_pushvalue(B->L, i);
+            lua_call(B->L, 1, 3);
+            r = lmp_handlerresult(B, i, top+1);
+        }
+        if (r == 0) return luaL_error(L, "bad argument to #%d: %s",
+                i-2, lua_tostring(L, i));
+    }
     lua_pushlstring(L, (const char*)lmp_data(B), B->len);
     return 1;
 }
 
-static int Lhencode_aux(lua_State *L) {
-    lmp_Buffer *B = (lmp_Buffer*)lua_touserdata(L, 1);
-    int i, top = lua_gettop(L);
-    B->L = L;
-    for (i = 3; i <= top; ++i)
-        if (!lmp_encode(B, i, 0, 2))
-            return luaL_error(L, "bad argument to #%d: %s",
-                    i-1, lua_tostring(L, i));
-    lua_pushlstring(L, (const char*)lmp_data(B), B->len);
-    return 1;
+static int Lencoder(lua_State *L) {
+    lmp_Buffer B;
+    int r;
+    memset(&B, 0, sizeof(B));
+    lua_pushcfunction(L, Lencoder_aux);
+    lua_insert(L, 1);
+    lua_pushlightuserdata(L, &B);
+    lua_insert(L, 2);
+    lua_pushvalue(L, lua_upvalueindex(1));
+    lua_insert(L, 3);
+    r = lua_pcall(L, lua_gettop(L)-1, 1, 0) == LUA_OK;
+    lmp_resetbuffer(&B);
+    return r ? 1 : luaL_error(L, "%s", lua_tostring(L, -1));
 }
 
-static int Lencode(lua_State *L)
-{ return lmp_encode_helper(L, Lencode_aux); }
-
-static int Lhencode(lua_State *L)
-{ return lmp_encode_helper(L, Lhencode_aux); }
+static int Lnewencoder(lua_State *L) {
+    if (lua_isnoneornil(L, 1))
+        lua_pushcfunction(L, Lencode);
+    else {
+        lua_pushvalue(L, 1);
+        lua_pushcclosure(L, Lencoder, 1);
+    }
+    return 1;
+}
 
 
 /* decode */
@@ -792,7 +858,7 @@ LUALIB_API int luaopen_mp(lua_State *L) {
         ENTRY(map),
         ENTRY(meta),
         ENTRY(encode),
-        ENTRY(hencode),
+        ENTRY(newencoder),
         ENTRY(decode),
         ENTRY(fromhex),
         ENTRY(tohex),

+ 38 - 5
test.lua

@@ -127,6 +127,7 @@ end
 function _G.test_nil()
    check_eq(nil)
    eq(mp.decode(mp.encode(mp.null)), nil)
+   eq(mp.decode(mp.newencoder()(mp.null)), nil)
    eq(tostring(mp.null), "null")
 end
 
@@ -169,15 +170,46 @@ function _G.test_handler()
 
    local co = coroutine.create(function()end)
    local c = 0
-   local v = mp.hencode(function(v)
+   local v = mp.newencoder(function(v, k, t)
+      assert(k == nil)
+      assert(t == nil)
       if v == co then
          c = c + 1
          return "int", 1
       end
       return "nil"
-   end, co, co, co)
-   eq(c, 3)
-   eq(v, "\1\1\1")
+   end)(co, co, co)
+   eq(c, 3); eq(v, "\1\1\1")
+   c = 0; v = mp.newencoder(function(v, k, t)
+      assert(k >= 1 and k <= 3)
+      assert(type(t) == "table")
+      if v == co then
+         c = c + 1
+         return "int", 1
+      end
+      return "nil"
+   end)({co, co, co})
+   eq(c, 3); eq(v, "\x93\1\1\1")
+   c = 0; v = mp.newencoder(function(v, k, t)
+      assert(k == "foo")
+      assert(type(t) == "table")
+      if v == co then
+         c = c + 1
+         return "int", 1
+      end
+      return "nil"
+   end)({foo = co})
+   eq(c, 1); eq(v, "\x81\xA3foo\1")
+   c = 0; v = mp.newencoder(function(v, k, t)
+      assert(k == nil)
+      assert(type(t) == "table")
+      if v == co then
+         c = c + 1
+         return "int", 1
+      end
+      return "nil"
+   end)({[co] = "bar"})
+   eq(c, 1); eq(v, "\x81\1\xA3bar")
 end
 
 function _G.test_error()
@@ -234,4 +266,5 @@ else
    os.exit(u.LuaUnit.run(), true)
 end
 
--- cc: run='rm -f *.gcda; time lua test.lua; gcov mp.c'
+-- unixcc: run='rm -f *.gcda; time lua test.lua; gcov mp.c'
+-- win32cc: run='del /s/q *.gcda & lua test.lua & gcov mp.c'