Преглед на файлове

Bug: Some patterns can overflow the C stack, due to recursion
(Took the opportunity to refactor function 'match')

Roberto Ierusalimschy преди 13 години
родител
ревизия
6625cbecd1
променени са 1 файла, в които са добавени 124 реда и са изтрити 77 реда
  1. 124 77
      lstrlib.c

+ 124 - 77
lstrlib.c

@@ -1,5 +1,5 @@
 /*
-** $Id: lstrlib.c,v 1.175 2012/04/20 13:16:48 roberto Exp roberto $
+** $Id: lstrlib.c,v 1.176 2012/05/23 15:37:09 roberto Exp roberto $
 ** Standard library for string operations and pattern-matching
 ** See Copyright Notice in lua.h
 */
@@ -194,7 +194,9 @@ static int str_dump (lua_State *L) {
 #define CAP_UNFINISHED	(-1)
 #define CAP_POSITION	(-2)
 
+
 typedef struct MatchState {
+  int matchdepth;  /* control for recursive depth (to avoid C stack overflow) */
   const char *src_init;  /* init of source string */
   const char *src_end;  /* end ('\0') of source string */
   const char *p_end;  /* end ('\0') of pattern */
@@ -207,6 +209,16 @@ typedef struct MatchState {
 } MatchState;
 
 
+/* recursive function */
+static const char *match (MatchState *ms, const char *s, const char *p);
+
+
+/* maximum recursion depth for 'match' */
+#if !defined(MAXCCALLS)
+#define MAXCCALLS	200
+#endif
+
+
 #define L_ESC		'%'
 #define SPECIALS	"^$*+?.([%-"
 
@@ -294,19 +306,22 @@ static int matchbracketclass (int c, const char *p, const char *ec) {
 }
 
 
-static int singlematch (int c, const char *p, const char *ep) {
-  switch (*p) {
-    case '.': return 1;  /* matches any char */
-    case L_ESC: return match_class(c, uchar(*(p+1)));
-    case '[': return matchbracketclass(c, p, ep-1);
-    default:  return (uchar(*p) == c);
+static int singlematch (MatchState *ms, const char *s, const char *p,
+                        const char *ep) {
+  if (s >= ms->src_end)
+    return 0;
+  else {
+    int c = uchar(*s);
+    switch (*p) {
+      case '.': return 1;  /* matches any char */
+      case L_ESC: return match_class(c, uchar(*(p+1)));
+      case '[': return matchbracketclass(c, p, ep-1);
+      default:  return (uchar(*p) == c);
+    }
   }
 }
 
 
-static const char *match (MatchState *ms, const char *s, const char *p);
-
-
 static const char *matchbalance (MatchState *ms, const char *s,
                                    const char *p) {
   if (p >= ms->p_end - 1)
@@ -331,7 +346,7 @@ static const char *matchbalance (MatchState *ms, const char *s,
 static const char *max_expand (MatchState *ms, const char *s,
                                  const char *p, const char *ep) {
   ptrdiff_t i = 0;  /* counts maximum expand for item */
-  while ((s+i)<ms->src_end && singlematch(uchar(*(s+i)), p, ep))
+  while (singlematch(ms, s + i, p, ep))
     i++;
   /* keeps trying to match with the maximum repetitions */
   while (i>=0) {
@@ -349,7 +364,7 @@ static const char *min_expand (MatchState *ms, const char *s,
     const char *res = match(ms, s, ep+1);
     if (res != NULL)
       return res;
-    else if (s<ms->src_end && singlematch(uchar(*s), p, ep))
+    else if (singlematch(ms, s, p, ep))
       s++;  /* try with one more repetition */
     else return NULL;
   }
@@ -393,79 +408,105 @@ static const char *match_capture (MatchState *ms, const char *s, int l) {
 
 
 static const char *match (MatchState *ms, const char *s, const char *p) {
+  if (ms->matchdepth-- == 0)
+    luaL_error(ms->L, "pattern too complex");
   init: /* using goto's to optimize tail recursion */
-  if (p == ms->p_end)  /* end of pattern? */
-    return s;  /* match succeeded */
-  switch (*p) {
-    case '(': {  /* start capture */
-      if (*(p+1) == ')')  /* position capture? */
-        return start_capture(ms, s, p+2, CAP_POSITION);
-      else
-        return start_capture(ms, s, p+1, CAP_UNFINISHED);
-    }
-    case ')': {  /* end capture */
-      return end_capture(ms, s, p+1);
-    }
-    case '$': {
-      if ((p+1) == ms->p_end)  /* is the `$' the last char in pattern? */
-        return (s == ms->src_end) ? s : NULL;  /* check end of string */
-      else goto dflt;
-    }
-    case L_ESC: {  /* escaped sequences not in the format class[*+?-]? */
-      switch (*(p+1)) {
-        case 'b': {  /* balanced string? */
-          s = matchbalance(ms, s, p+2);
-          if (s == NULL) return NULL;
-          p+=4; goto init;  /* else return match(ms, s, p+4); */
-        }
-        case 'f': {  /* frontier? */
-          const char *ep; char previous;
-          p += 2;
-          if (*p != '[')
-            luaL_error(ms->L, "missing " LUA_QL("[") " after "
-                               LUA_QL("%%f") " in pattern");
-          ep = classend(ms, p);  /* points to what is next */
-          previous = (s == ms->src_init) ? '\0' : *(s-1);
-          if (matchbracketclass(uchar(previous), p, ep-1) ||
-             !matchbracketclass(uchar(*s), p, ep-1)) return NULL;
-          p=ep; goto init;  /* else return match(ms, s, ep); */
-        }
-        case '0': case '1': case '2': case '3':
-        case '4': case '5': case '6': case '7':
-        case '8': case '9': {  /* capture results (%0-%9)? */
-          s = match_capture(ms, s, uchar(*(p+1)));
-          if (s == NULL) return NULL;
-          p+=2; goto init;  /* else return match(ms, s, p+2) */
-        }
-        default: goto dflt;
+  if (p != ms->p_end) {  /* end of pattern? */
+    switch (*p) {
+      case '(': {  /* start capture */
+        if (*(p + 1) == ')')  /* position capture? */
+          s = start_capture(ms, s, p + 2, CAP_POSITION);
+        else
+          s = start_capture(ms, s, p + 1, CAP_UNFINISHED);
+        break;
       }
-    }
-    default: dflt: {  /* pattern class plus optional suffix */
-      const char *ep = classend(ms, p);  /* points to what is next */
-      int m = s < ms->src_end && singlematch(uchar(*s), p, ep);
-      switch (*ep) {
-        case '?': {  /* optional */
-          const char *res;
-          if (m && ((res=match(ms, s+1, ep+1)) != NULL))
-            return res;
-          p=ep+1; goto init;  /* else return match(ms, s, ep+1); */
-        }
-        case '*': {  /* 0 or more repetitions */
-          return max_expand(ms, s, p, ep);
-        }
-        case '+': {  /* 1 or more repetitions */
-          return (m ? max_expand(ms, s+1, p, ep) : NULL);
+      case ')': {  /* end capture */
+        s = end_capture(ms, s, p + 1);
+        break;
+      }
+      case '$': {
+        if ((p + 1) != ms->p_end)  /* is the `$' the last char in pattern? */
+          goto dflt;  /* no; go to default */
+        s = (s == ms->src_end) ? s : NULL;  /* check end of string */
+        break;
+      }
+      case L_ESC: {  /* escaped sequences not in the format class[*+?-]? */
+        switch (*(p + 1)) {
+          case 'b': {  /* balanced string? */
+            s = matchbalance(ms, s, p + 2);
+            if (s != NULL) {
+              p += 4; goto init;  /* return match(ms, s, p + 4); */
+            }  /* else fail (s == NULL) */
+            break;
+          }
+          case 'f': {  /* frontier? */
+            const char *ep; char previous;
+            p += 2;
+            if (*p != '[')
+              luaL_error(ms->L, "missing " LUA_QL("[") " after "
+                                 LUA_QL("%%f") " in pattern");
+            ep = classend(ms, p);  /* points to what is next */
+            previous = (s == ms->src_init) ? '\0' : *(s - 1);
+            if (!matchbracketclass(uchar(previous), p, ep - 1) &&
+               matchbracketclass(uchar(*s), p, ep - 1)) {
+              p = ep; goto init;  /* return match(ms, s, ep); */
+            }
+            s = NULL;  /* match failed */
+            break;
+          }
+          case '0': case '1': case '2': case '3':
+          case '4': case '5': case '6': case '7':
+          case '8': case '9': {  /* capture results (%0-%9)? */
+            s = match_capture(ms, s, uchar(*(p + 1)));
+            if (s != NULL) {
+              p += 2; goto init;  /* return match(ms, s, p + 2) */
+            }
+            break;
+          }
+          default: goto dflt;
         }
-        case '-': {  /* 0 or more repetitions (minimum) */
-          return min_expand(ms, s, p, ep);
+        break;
+      }
+      default: dflt: {  /* pattern class plus optional suffix */
+        const char *ep = classend(ms, p);  /* points to optional suffix */
+        /* does not match at least once? */
+        if (!singlematch(ms, s, p, ep)) {
+          if (*ep == '*' || *ep == '?' || *ep == '-') {  /* accept empty? */
+            p = ep + 1; goto init;  /* return match(ms, s, ep + 1); */
+          }
+          else  /* '+' or no suffix */
+            s = NULL;  /* fail */
         }
-        default: {
-          if (!m) return NULL;
-          s++; p=ep; goto init;  /* else return match(ms, s+1, ep); */
+        else {  /* matched once */
+          switch (*ep) {  /* handle optional suffix */
+            case '?': {  /* optional */
+              const char *res;
+              if ((res = match(ms, s + 1, ep + 1)) != NULL)
+                s = res;
+              else {
+                p = ep + 1; goto init;  /* else return match(ms, s, ep + 1); */
+              }
+              break;
+            }
+            case '+':  /* 1 or more repetitions */
+              s++;  /* 1 match already done */
+              /* go through */
+            case '*':  /* 0 or more repetitions */
+              s = max_expand(ms, s, p, ep);
+              break;
+            case '-':  /* 0 or more repetitions (minimum) */
+              s = min_expand(ms, s, p, ep);
+              break;
+            default:  /* no suffix */
+              s++; p = ep; goto init;  /* return match(ms, s + 1, ep); */
+          }
         }
+        break;
       }
     }
   }
+  ms->matchdepth++;
+  return s;
 }
 
 
@@ -561,12 +602,14 @@ static int str_find_aux (lua_State *L, int find) {
       p++; lp--;  /* skip anchor character */
     }
     ms.L = L;
+    ms.matchdepth = MAXCCALLS;
     ms.src_init = s;
     ms.src_end = s + ls;
     ms.p_end = p + lp;
     do {
       const char *res;
       ms.level = 0;
+      lua_assert(ms.matchdepth == MAXCCALLS);
       if ((res=match(&ms, s1, p)) != NULL) {
         if (find) {
           lua_pushinteger(L, s1 - s + 1);  /* start */
@@ -600,6 +643,7 @@ static int gmatch_aux (lua_State *L) {
   const char *p = lua_tolstring(L, lua_upvalueindex(2), &lp);
   const char *src;
   ms.L = L;
+  ms.matchdepth = MAXCCALLS;
   ms.src_init = s;
   ms.src_end = s+ls;
   ms.p_end = p + lp;
@@ -608,6 +652,7 @@ static int gmatch_aux (lua_State *L) {
        src++) {
     const char *e;
     ms.level = 0;
+    lua_assert(ms.matchdepth == MAXCCALLS);
     if ((e = match(&ms, src, p)) != NULL) {
       lua_Integer newstart = e-s;
       if (e == src) newstart++;  /* empty match? go at least one position */
@@ -705,12 +750,14 @@ static int str_gsub (lua_State *L) {
     p++; lp--;  /* skip anchor character */
   }
   ms.L = L;
+  ms.matchdepth = MAXCCALLS;
   ms.src_init = src;
   ms.src_end = src+srcl;
   ms.p_end = p + lp;
   while (n < max_s) {
     const char *e;
     ms.level = 0;
+    lua_assert(ms.matchdepth == MAXCCALLS);
     e = match(&ms, src, p);
     if (e) {
       n++;