Browse Source

Limit recursion depth in string.match() et al.

Mike Pall 13 years ago
parent
commit
ff00a78f3a
2 changed files with 41 additions and 26 deletions
  1. src/lib_string.c (+40 -26)
  2. src/lj_errmsg.h (+1 -0)

src/lib_string.c (+40 -26)

@@ -148,6 +148,7 @@ typedef struct MatchState {
   const char *src_end;  /* end (`\0') of source string */
   const char *src_end;  /* end (`\0') of source string */
   lua_State *L;
   lua_State *L;
   int level;  /* total number of captures (finished or unfinished) */
   int level;  /* total number of captures (finished or unfinished) */
+  int depth;
   struct {
   struct {
     const char *init;
     const char *init;
     ptrdiff_t len;
     ptrdiff_t len;
@@ -339,22 +340,26 @@ static const char *match_capture(MatchState *ms, const char *s, int l)
 
 
 static const char *match(MatchState *ms, const char *s, const char *p)
 static const char *match(MatchState *ms, const char *s, const char *p)
 {
 {
+  if (++ms->depth > LJ_MAX_XLEVEL)
+    lj_err_caller(ms->L, LJ_ERR_STRPATX);
   init: /* using goto's to optimize tail recursion */
   init: /* using goto's to optimize tail recursion */
   switch (*p) {
   switch (*p) {
   case '(':  /* start capture */
   case '(':  /* start capture */
     if (*(p+1) == ')')  /* position capture? */
     if (*(p+1) == ')')  /* position capture? */
-      return start_capture(ms, s, p+2, CAP_POSITION);
+      s = start_capture(ms, s, p+2, CAP_POSITION);
     else
     else
-      return start_capture(ms, s, p+1, CAP_UNFINISHED);
+      s = start_capture(ms, s, p+1, CAP_UNFINISHED);
+    break;
   case ')':  /* end capture */
   case ')':  /* end capture */
-    return end_capture(ms, s, p+1);
+    s = end_capture(ms, s, p+1);
+    break;
   case L_ESC:
   case L_ESC:
     switch (*(p+1)) {
     switch (*(p+1)) {
     case 'b':  /* balanced string? */
     case 'b':  /* balanced string? */
       s = matchbalance(ms, s, p+2);
       s = matchbalance(ms, s, p+2);
-      if (s == NULL) return NULL;
+      if (s == NULL) break;
       p+=4;
       p+=4;
-      goto init;  /* else return match(ms, s, p+4); */
+      goto init;  /* else s = match(ms, s, p+4); */
     case 'f': {  /* frontier? */
     case 'f': {  /* frontier? */
       const char *ep; char previous;
       const char *ep; char previous;
       p += 2;
       p += 2;
@@ -363,50 +368,59 @@ static const char *match(MatchState *ms, const char *s, const char *p)
       ep = classend(ms, p);  /* points to what is next */
       ep = classend(ms, p);  /* points to what is next */
       previous = (s == ms->src_init) ? '\0' : *(s-1);
       previous = (s == ms->src_init) ? '\0' : *(s-1);
       if (matchbracketclass(uchar(previous), p, ep-1) ||
       if (matchbracketclass(uchar(previous), p, ep-1) ||
-	 !matchbracketclass(uchar(*s), p, ep-1)) return NULL;
+	 !matchbracketclass(uchar(*s), p, ep-1)) { s = NULL; break; }
       p=ep;
       p=ep;
-      goto init;  /* else return match(ms, s, ep); */
+      goto init;  /* else s = match(ms, s, ep); */
       }
       }
     default:
     default:
       if (lj_char_isdigit(uchar(*(p+1)))) {  /* capture results (%0-%9)? */
       if (lj_char_isdigit(uchar(*(p+1)))) {  /* capture results (%0-%9)? */
 	s = match_capture(ms, s, uchar(*(p+1)));
 	s = match_capture(ms, s, uchar(*(p+1)));
-	if (s == NULL) return NULL;
+	if (s == NULL) break;
 	p+=2;
 	p+=2;
-	goto init;  /* else return match(ms, s, p+2) */
+	goto init;  /* else s = match(ms, s, p+2) */
       }
       }
       goto dflt;  /* case default */
       goto dflt;  /* case default */
     }
     }
+    break;
   case '\0':  /* end of pattern */
   case '\0':  /* end of pattern */
-    return s;  /* match succeeded */
+    break;  /* match succeeded */
   case '$':
   case '$':
-    if (*(p+1) == '\0')  /* is the `$' the last char in pattern? */
-      return (s == ms->src_end) ? s : NULL;  /* check end of string */
-    else
-      goto dflt;
+    /* is the `$' the last char in pattern? */
+    if (*(p+1) != '\0') goto dflt;
+    if (s != ms->src_end) s = NULL;  /* check end of string */
+    break;
   default: dflt: {  /* it is a pattern item */
   default: dflt: {  /* it is a pattern item */
     const char *ep = classend(ms, p);  /* points to what is next */
     const char *ep = classend(ms, p);  /* points to what is next */
     int m = s<ms->src_end && singlematch(uchar(*s), p, ep);
     int m = s<ms->src_end && singlematch(uchar(*s), p, ep);
     switch (*ep) {
     switch (*ep) {
     case '?': {  /* optional */
     case '?': {  /* optional */
       const char *res;
       const char *res;
-      if (m && ((res=match(ms, s+1, ep+1)) != NULL))
-	return res;
+      if (m && ((res=match(ms, s+1, ep+1)) != NULL)) {
+	s = res;
+	break;
+      }
       p=ep+1;
       p=ep+1;
-      goto init;  /* else return match(ms, s, ep+1); */
+      goto init;  /* else s = match(ms, s, ep+1); */
       }
       }
     case '*':  /* 0 or more repetitions */
     case '*':  /* 0 or more repetitions */
-      return max_expand(ms, s, p, ep);
+      s = max_expand(ms, s, p, ep);
+      break;
     case '+':  /* 1 or more repetitions */
     case '+':  /* 1 or more repetitions */
-      return (m ? max_expand(ms, s+1, p, ep) : NULL);
+      s = (m ? max_expand(ms, s+1, p, ep) : NULL);
+      break;
     case '-':  /* 0 or more repetitions (minimum) */
     case '-':  /* 0 or more repetitions (minimum) */
-      return min_expand(ms, s, p, ep);
+      s = min_expand(ms, s, p, ep);
+      break;
     default:
     default:
-      if (!m) return NULL;
-      s++; p=ep;
-      goto init;  /* else return match(ms, s+1, ep); */
+      if (m) { s++; p=ep; goto init; }  /* else s = match(ms, s+1, ep); */
+      s = NULL;
+      break;
     }
     }
+    break;
     }
     }
   }
   }
+  ms->depth--;
+  return s;
 }
 }
 
 
 static const char *lmemfind(const char *s1, size_t l1,
 static const char *lmemfind(const char *s1, size_t l1,
@@ -495,7 +509,7 @@ static int str_find_aux(lua_State *L, int find)
     ms.src_end = s+l1;
     ms.src_end = s+l1;
     do {
     do {
       const char *res;
       const char *res;
-      ms.level = 0;
+      ms.level = ms.depth = 0;
       if ((res=match(&ms, s1, p)) != NULL) {
       if ((res=match(&ms, s1, p)) != NULL) {
 	if (find) {
 	if (find) {
 	  lua_pushinteger(L, s1-s+1);  /* start */
 	  lua_pushinteger(L, s1-s+1);  /* start */
@@ -534,7 +548,7 @@ LJLIB_NOREG LJLIB_CF(string_gmatch_aux)
   ms.src_end = s + str->len;
   ms.src_end = s + str->len;
   for (; src <= ms.src_end; src++) {
   for (; src <= ms.src_end; src++) {
     const char *e;
     const char *e;
-    ms.level = 0;
+    ms.level = ms.depth = 0;
     if ((e = match(&ms, src, p)) != NULL) {
     if ((e = match(&ms, src, p)) != NULL) {
       int32_t pos = (int32_t)(e - s);
       int32_t pos = (int32_t)(e - s);
       if (e == src) pos++;  /* Ensure progress for empty match. */
       if (e == src) pos++;  /* Ensure progress for empty match. */
@@ -628,7 +642,7 @@ LJLIB_CF(string_gsub)
   ms.src_end = src+srcl;
   ms.src_end = src+srcl;
   while (n < max_s) {
   while (n < max_s) {
     const char *e;
     const char *e;
-    ms.level = 0;
+    ms.level = ms.depth = 0;
     e = match(&ms, src, p);
     e = match(&ms, src, p);
     if (e) {
     if (e) {
       n++;
       n++;

src/lj_errmsg.h (+1 -0)

@@ -91,6 +91,7 @@ ERRDEF(STRPATC,	"invalid pattern capture")
 ERRDEF(STRPATE,	"malformed pattern (ends with " LUA_QL("%") ")")
 ERRDEF(STRPATE,	"malformed pattern (ends with " LUA_QL("%") ")")
 ERRDEF(STRPATM,	"malformed pattern (missing " LUA_QL("]") ")")
 ERRDEF(STRPATM,	"malformed pattern (missing " LUA_QL("]") ")")
 ERRDEF(STRPATU,	"unbalanced pattern")
 ERRDEF(STRPATU,	"unbalanced pattern")
+ERRDEF(STRPATX,	"pattern too complex")
 ERRDEF(STRCAPI,	"invalid capture index")
 ERRDEF(STRCAPI,	"invalid capture index")
 ERRDEF(STRCAPN,	"too many captures")
 ERRDEF(STRCAPN,	"too many captures")
 ERRDEF(STRCAPU,	"unfinished capture")
 ERRDEF(STRCAPU,	"unfinished capture")