Browse Source

patterns now accept '\0' as a regular character

Roberto Ierusalimschy 15 years ago
parent
commit
4541243355
1 changed file with 38 additions and 28 deletions
  1. 38 28
      lstrlib.c

+ 38 - 28
lstrlib.c

@@ -1,5 +1,5 @@
 /*
-** $Id: lstrlib.c,v 1.148 2010/01/04 16:37:19 roberto Exp roberto $
+** $Id: lstrlib.c,v 1.149 2010/04/09 16:14:46 roberto Exp roberto $
 ** Standard library for string operations and pattern-matching
 ** See Copyright Notice in lua.h
 */
@@ -180,7 +180,8 @@ static int str_dump (lua_State *L) {
 
 typedef struct MatchState {
   const char *src_init;  /* init of source string */
-  const char *src_end;  /* end (`\0') of source string */
+  const char *src_end;  /* end ('\0') of source string */
+  const char *p_end;  /* end ('\0') of pattern */
   lua_State *L;
   int level;  /* total number of captures (finished or unfinished) */
   struct {
@@ -213,16 +214,16 @@ static int capture_to_close (MatchState *ms) {
 static const char *classend (MatchState *ms, const char *p) {
   switch (*p++) {
     case L_ESC: {
-      if (*p == '\0')
+      if (p == ms->p_end)
         luaL_error(ms->L, "malformed pattern (ends with " LUA_QL("%%") ")");
       return p+1;
     }
     case '[': {
       if (*p == '^') p++;
       do {  /* look for a `]' */
-        if (*p == '\0')
+        if (p == ms->p_end)
           luaL_error(ms->L, "malformed pattern (missing " LUA_QL("]") ")");
-        if (*(p++) == L_ESC && *p != '\0')
+        if (*(p++) == L_ESC && p < ms->p_end)
           p++;  /* skip escapes (e.g. `%]') */
       } while (*p != ']');
       return p+1;
@@ -246,7 +247,7 @@ static int match_class (int c, int cl) {
     case 'u' : res = isupper(c); break;
     case 'w' : res = isalnum(c); break;
     case 'x' : res = isxdigit(c); break;
-    case 'z' : res = (c == 0); break;
+    case 'z' : res = (c == 0); break;  /* deprecated option */
     default: return (cl == c);
   }
   return (islower(cl) ? res : !res);
@@ -291,8 +292,9 @@ static const char *match (MatchState *ms, const char *s, const char *p);
 
 static const char *matchbalance (MatchState *ms, const char *s,
                                    const char *p) {
-  if (*p == 0 || *(p+1) == 0)
-    luaL_error(ms->L, "unbalanced pattern");
+  if (p >= ms->p_end - 1)
+    luaL_error(ms->L, "malformed pattern "
+                      "(missing arguments to " LUA_QL("%%b") ")");
   if (*s != *p) return NULL;
   else {
     int b = *p;
@@ -375,6 +377,8 @@ static const char *match_capture (MatchState *ms, const char *s, int l) {
 
 static const char *match (MatchState *ms, const char *s, const char *p) {
   init: /* using goto's to optimize tail recursion */
+  if (p == ms->p_end)  /* end of pattern? */
+    return s;  /* match succeeded */
   switch (*p) {
     case '(': {  /* start capture */
       if (*(p+1) == ')')  /* position capture? */
@@ -385,11 +389,8 @@ static const char *match (MatchState *ms, const char *s, const char *p) {
     case ')': {  /* end capture */
       return end_capture(ms, s, p+1);
     }
-    case '\0': {  /* end of pattern */
-      return s;  /* match succeeded */
-    }
     case '$': {
-      if (*(p+1) == '\0')  /* is the `$' the last char in pattern? */
+      if ((p+1) == ms->p_end)  /* is the `$' the last char in pattern? */
         return (s == ms->src_end) ? s : NULL;  /* check end of string */
       else goto dflt;
     }
@@ -419,12 +420,12 @@ static const char *match (MatchState *ms, const char *s, const char *p) {
           if (s == NULL) return NULL;
           p+=2; goto init;  /* else return match(ms, s, p+2) */
         }
-        default: break;  /* go through to 'dflt' */
+        default: goto dflt;
       }
     }
     default: dflt: {  /* pattern class plus optional suffix */
       const char *ep = classend(ms, p);  /* points to what is next */
-      int m = s<ms->src_end && singlematch(uchar(*s), p, ep);
+      int m = s < ms->src_end && singlematch(uchar(*s), p, ep);
       switch (*ep) {
         case '?': {  /* optional */
           const char *res;
@@ -504,32 +505,36 @@ static int push_captures (MatchState *ms, const char *s, const char *e) {
 
 
 static int str_find_aux (lua_State *L, int find) {
-  size_t l1, l2;
-  const char *s = luaL_checklstring(L, 1, &l1);
-  const char *p = luaL_checklstring(L, 2, &l2);
-  size_t init = posrelat(luaL_optinteger(L, 3, 1), l1);
+  size_t ls, lp;
+  const char *s = luaL_checklstring(L, 1, &ls);
+  const char *p = luaL_checklstring(L, 2, &lp);
+  size_t init = posrelat(luaL_optinteger(L, 3, 1), ls);
   if (init < 1) init = 1;
-  else if (init > l1 + 1) {  /* start after string's end? */
+  else if (init > ls + 1) {  /* start after string's end? */
     lua_pushnil(L);  /* cannot find anything */
     return 1;
   }
   if (find && (lua_toboolean(L, 4) ||  /* explicit request? */
       strpbrk(p, SPECIALS) == NULL)) {  /* or no special characters? */
     /* do a plain search */
-    const char *s2 = lmemfind(s + init - 1, l1 - init + 1, p, l2);
+    const char *s2 = lmemfind(s + init - 1, ls - init + 1, p, lp);
     if (s2) {
       lua_pushinteger(L, s2 - s + 1);
-      lua_pushinteger(L, s2 - s + l2);
+      lua_pushinteger(L, s2 - s + lp);
       return 2;
     }
   }
   else {
     MatchState ms;
-    int anchor = (*p == '^') ? (p++, 1) : 0;
     const char *s1 = s + init - 1;
+    int anchor = (*p == '^');
+    if (anchor) {
+      p++; lp--;  /* skip anchor character */
+    }
     ms.L = L;
     ms.src_init = s;
-    ms.src_end = s + l1;
+    ms.src_end = s + ls;
+    ms.p_end = p + lp;
     do {
       const char *res;
       ms.level = 0;
@@ -561,13 +566,14 @@ static int str_match (lua_State *L) {
 
 static int gmatch_aux (lua_State *L) {
   MatchState ms;
-  size_t ls;
+  size_t ls, lp;
   const char *s = lua_tolstring(L, lua_upvalueindex(1), &ls);
-  const char *p = lua_tostring(L, lua_upvalueindex(2));
+  const char *p = lua_tolstring(L, lua_upvalueindex(2), &lp);
   const char *src;
   ms.L = L;
   ms.src_init = s;
   ms.src_end = s+ls;
+  ms.p_end = p + lp;
   for (src = s + (size_t)lua_tointeger(L, lua_upvalueindex(3));
        src <= ms.src_end;
        src++) {
@@ -659,12 +665,12 @@ static void add_value (MatchState *ms, luaL_Buffer *b, const char *s,
 
 
 static int str_gsub (lua_State *L) {
-  size_t srcl;
+  size_t srcl, lp;
   const char *src = luaL_checklstring(L, 1, &srcl);
-  const char *p = luaL_checkstring(L, 2);
+  const char *p = luaL_checklstring(L, 2, &lp);
   int tr = lua_type(L, 3);
   size_t max_s = luaL_optinteger(L, 4, srcl+1);
-  int anchor = (*p == '^') ? (p++, 1) : 0;
+  int anchor = (*p == '^');
   size_t n = 0;
   MatchState ms;
   luaL_Buffer b;
@@ -672,9 +678,13 @@ static int str_gsub (lua_State *L) {
                    tr == LUA_TFUNCTION || tr == LUA_TTABLE, 3,
                       "string/function/table expected");
   luaL_buffinit(L, &b);
+  if (anchor) {
+    p++; lp--;  /* skip anchor character */
+  }
   ms.L = L;
   ms.src_init = src;
   ms.src_end = src+srcl;
+  ms.p_end = p + lp;
   while (n < max_s) {
     const char *e;
     ms.level = 0;