浏览代码

Pattern-Matching!
plus several other changes...

Roberto Ierusalimschy 29 年之前
父节点
当前提交
1630c2533a
共有 1 个文件被更改,包括 359 次插入135 次删除
  1. 359 135
      strlib.c

+ 359 - 135
strlib.c

@@ -3,42 +3,68 @@
 ** String library to LUA
 ** String library to LUA
 */
 */
 
 
-char *rcs_strlib="$Id: strlib.c,v 1.23 1996/04/30 21:13:55 roberto Exp roberto $";
+char *rcs_strlib="$Id: strlib.c,v 1.24 1996/05/22 21:59:07 roberto Exp roberto $";
 
 
 #include <string.h>
 #include <string.h>
 #include <stdio.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <stdlib.h>
 #include <ctype.h>
 #include <ctype.h>
-#include <limits.h>
 
 
 #include "lua.h"
 #include "lua.h"
 #include "lualib.h"
 #include "lualib.h"
 
 
 
 
-void lua_arg_error(char *funcname)
+static char *buffer = NULL;
+static size_t maxbuff = 0;
+static size_t buff_size = 0;
+
+
+static char *lua_strbuffer (unsigned long size)
+{
+  if (size > maxbuff) {
+    buffer = (buffer) ? realloc(buffer, maxbuff=size) : malloc(maxbuff=size);
+    if (buffer == NULL)
+      lua_error("memory overflow");
+  }
+  return buffer;
+}
+
+static char *openspace (unsigned long size)
 {
 {
-  char buff[100];
-  sprintf(buff, "incorrect arguments to function `%s'", funcname);
-  lua_error(buff);
+  char *buff = lua_strbuffer(buff_size+size);
+  return buff+buff_size;
+}
+
+void lua_arg_check(int cond, char *funcname)
+{
+  if (!cond) {
+    char buff[100];
+    sprintf(buff, "incorrect argument to function `%s'", funcname);
+    lua_error(buff);
+  }
 }
 }
 
 
 char *lua_check_string (int numArg, char *funcname)
 char *lua_check_string (int numArg, char *funcname)
 {
 {
   lua_Object o = lua_getparam(numArg);
   lua_Object o = lua_getparam(numArg);
-  if (!lua_isstring(o))
-    lua_arg_error(funcname);
+  lua_arg_check(lua_isstring(o), funcname);
   return lua_getstring(o);
   return lua_getstring(o);
 }
 }
 
 
+char *lua_opt_string (int numArg, char *def, char *funcname)
+{
+  return (lua_getparam(numArg) == LUA_NOOBJECT) ? def :
+                              lua_check_string(numArg, funcname);
+}
+
 double lua_check_number (int numArg, char *funcname)
 double lua_check_number (int numArg, char *funcname)
 {
 {
   lua_Object o = lua_getparam(numArg);
   lua_Object o = lua_getparam(numArg);
-  if (!lua_isnumber(o))
-    lua_arg_error(funcname);
+  lua_arg_check(lua_isnumber(o), funcname);
   return lua_getnumber(o);
   return lua_getnumber(o);
 }
 }
 
 
-static long lua_opt_number (int numArg, long def, char *funcname)
+long lua_opt_number (int numArg, long def, char *funcname)
 {
 {
   return (lua_getparam(numArg) == LUA_NOOBJECT) ? def :
   return (lua_getparam(numArg) == LUA_NOOBJECT) ? def :
                               (long)lua_check_number(numArg, funcname);
                               (long)lua_check_number(numArg, funcname);
@@ -46,54 +72,45 @@ static long lua_opt_number (int numArg, long def, char *funcname)
 
 
 char *luaI_addchar (int c)
 char *luaI_addchar (int c)
 {
 {
-  static char *buff = NULL;
-  static size_t max = 0;
-  static size_t n = 0;
-  if (n >= max)
-  {
-    if (max == 0)
-    {
-      max = 100;
-      buff = (char *)malloc(max);
-    }
-    else
-    {
-      max *= 2;
-      buff = (char *)realloc(buff, max);
-    }
-    if (buff == NULL)
-      lua_error("memory overflow");
-  }
-  buff[n++] = c;
+  if (buff_size >= maxbuff)
+    lua_strbuffer(maxbuff == 0 ? 100 : maxbuff*2);
+  buffer[buff_size++] = c;
   if (c == 0)
   if (c == 0)
-    n = 0;  /* prepare for next string */
-  return buff;
+    buff_size = 0;  /* prepare for next string */
+  return buffer;
+}
+
+static void addnchar (char *s, int n)
+{
+  char *b = openspace(n);
+  strncpy(b, s, n);
+  buff_size += n;
 }
 }
 
 
 
 
 /*
 /*
-** Return the position of the first caracter of a substring into a string
-** LUA interface:
-**			n = strfind (string, substring, init, end)
+** Interface to strtok
 */
 */
-static void str_find (void)
+static void str_tok (void)
 {
 {
- char *s1 = lua_check_string(1, "strfind");
- char *s2 = lua_check_string(2, "strfind");
- long init = lua_opt_number(3, 1, "strfind") - 1;
- char *f = (init>=0 && init<=strlen(s1)) ? strstr(s1+init,s2) : NULL;
- if (f != NULL)
- {
-  size_t pos = f-s1+1;
-  if (lua_opt_number(4, LONG_MAX, "strfind") >= pos+strlen(s2)-1)
-   lua_pushnumber (pos);
-  else
-   lua_pushnil();
- }
- else
-  lua_pushnil();
+  char *s1 = lua_check_string(1, "strtok");
+  char *del = lua_check_string(2, "strtok");
+  lua_Object t = lua_createtable();
+  int i = 1;
+  /* As strtok changes s1, and s1 is "constant",  make a copy of it */
+  s1 = strcpy(lua_strbuffer(strlen(s1+1)), s1);
+  while ((s1 = strtok(s1, del)) != NULL) {
+    lua_pushobject(t);
+    lua_pushnumber(i++);
+    lua_pushstring(s1);
+    lua_storesubscript();
+    s1 = NULL;  /* prepare for next strtok */
+  }
+  lua_pushobject(t);
+  lua_pushnumber(i-1);  /* total number of tokens */
 }
 }
 
 
+
 /*
 /*
 ** Return the string length
 ** Return the string length
 ** LUA interface:
 ** LUA interface:
@@ -101,11 +118,9 @@ static void str_find (void)
 */
 */
 static void str_len (void)
 static void str_len (void)
 {
 {
- char *s = lua_check_string(1, "strlen");
- lua_pushnumber(strlen(s));
+ lua_pushnumber(strlen(lua_check_string(1, "strlen")));
 }
 }
 
 
-
 /*
 /*
 ** Return the substring of a string, from start to end
 ** Return the substring of a string, from start to end
 ** LUA interface:
 ** LUA interface:
@@ -113,136 +128,342 @@ static void str_len (void)
 */
 */
 static void str_sub (void)
 static void str_sub (void)
 {
 {
- char *s = lua_check_string(1, "strsub");
- long start = (long)lua_check_number(2, "strsub");
- long end = lua_opt_number(3, strlen(s), "strsub");
- if (end < start || start < 1 || end > strlen(s))
-  lua_pushliteral("");
- else
- {
-   luaI_addchar(0);
-   while (start <= end) 
-     luaI_addchar(s[start++ - 1]);
-   lua_pushstring (luaI_addchar(0));
- }
+  char *s = lua_check_string(1, "strsub");
+  long start = (long)lua_check_number(2, "strsub");
+  long end = lua_opt_number(3, strlen(s), "strsub");
+  if (1 <= start && start <= end && end <= strlen(s)) {
+    luaI_addchar(0);
+    addnchar(s+start-1, end-start+1);
+    lua_pushstring(luaI_addchar(0));
+  }
+  else lua_pushliteral("");
 }
 }
 
 
 /*
 /*
-** Convert a string to lower case.
-** LUA interface:
-**			lowercase = strlower (string)
+** Transliterate a string
 */
 */
-static void str_lower (void)
+static void str_map (void)
 {
 {
-  char *s = lua_check_string(1, "strlower");
-  luaI_addchar(0);
-  while (*s)
-    luaI_addchar(tolower(*s++));
+  char *s = lua_check_string(1, "strmap");
+  char *from = lua_check_string(2, "strmap");
+  char *to = lua_opt_string(3, "", "strmap");
+  long len = strlen(to);
+  for (luaI_addchar(0); *s; s++) {
+    char *f = strrchr(from, *s);
+    if (f == NULL)
+      luaI_addchar(*s);
+    else {
+      long pos = f-from;
+      if (pos < len)
+        luaI_addchar(to[pos]);
+    }
+  }
   lua_pushstring(luaI_addchar(0));
   lua_pushstring(luaI_addchar(0));
 }
 }
 
 
-
 /*
 /*
-** Convert a string to upper case.
-** LUA interface:
-**			uppercase = strupper (string)
+** get ascii value of a character in a string
 */
 */
-static void str_upper (void)
+static void str_ascii (void)
+{
+  char *s = lua_check_string(1, "ascii");
+  long pos = lua_opt_number(2, 1, "ascii") - 1;
+  lua_arg_check(0<=pos && pos<strlen(s), "ascii");
+  lua_pushnumber((unsigned char)s[pos]);
+}
+
+
+/* pattern matching */
+
+#define ESC	'%'
+#define SPECIALS  "^$*?.([%"
+
+static char *item_end (char *p)
+{
+  switch (*p) {
+    case '\0': return p;
+    case ESC:
+      if (*(p+1) == 0) lua_error("incorrect pattern");
+      return p+2;
+    case '[': {
+      char *end = (*(p+1) == 0) ? NULL : strchr(p+2, ']');
+      if (end == NULL) lua_error("incorrect pattern");
+      return end+1;
+    }
+    default:
+      return p+1;
+  }
+}
+
+static int matchclass (int c, int cl)
+{
+  int res;
+  switch (tolower(cl)) {
+    case 'a' : res = isalpha(c); break;
+    case 'c' : res = iscntrl(c); break;
+    case 'd' : res = isdigit(c); break;
+    case 'l' : res = islower(c); break;
+    case 'p' : res = ispunct(c); break;
+    case 's' : res = isspace(c); break;
+    case 'u' : res = isupper(c); break;
+    case 'w' : res = isalnum(c); break;
+    default: return (cl == c);
+  }
+  return (islower(cl) ? res : !res);
+}
+
+static int singlematch (int c, char *p)
+{
+  if (c == 0) return 0;
+  switch (*p) {
+    case '.': return 1;
+    case ESC: return matchclass(c, *(p+1));
+    case '[': {
+      char *end = strchr(p+2, ']');
+      int sig = *(p+1) == '^' ? (p++, 0) : 1;
+      while (++p < end) {
+        if (*p == ESC) {
+          if (((p+1) < end) && matchclass(c, *++p)) return sig;
+        }
+        else if ((*(p+1) == '-') && (p+2 < end)) {
+          p+=2;
+          if (*(p-2) <= c && c <= *p) return sig;
+        }
+        else if (*p == c) return sig;
+      }
+      return !sig;
+    }
+    default: return (*p == c);
+  }
+}
+
+#define MAX_CAPT 9
+
+static struct {
+  char *init;
+  int len;  /* -1 signals unfinished capture */
+} capture[MAX_CAPT];
+
+static int num_captures;  /* only valid after a sucessful call to match */
+
+
+static void push_captures (void)
 {
 {
-  char *s = lua_check_string(1, "strupper");
+  int i;
   luaI_addchar(0);
   luaI_addchar(0);
-  while (*s)
-    luaI_addchar(toupper(*s++));
+  for (i=0; i<num_captures; i++) {
+    if (capture[i].len == -1) lua_error("unfinished capture");
+    addnchar(capture[i].init, capture[i].len);
+    lua_pushstring(luaI_addchar(0));
+  }
+}
+
+static int check_cap (int l, int level)
+{
+  l -= '1';
+  if (!(0 <= l && l < level && capture[l].len != -1))
+    lua_error("invalid capture index");
+  return l;
+}
+
+static void add_s (char *newp)
+{
+  while (*newp) {
+    if (*newp != ESC || !isdigit(*++newp))
+      luaI_addchar(*newp++);
+    else {
+      int l = check_cap(*newp++, num_captures);
+      addnchar(capture[l].init, capture[l].len);
+    }
+  }
+}
+
+static int capture_to_close (int level)
+{
+  for (level--; level>=0; level--)
+    if (capture[level].len == -1) return level;
+  lua_error("invalid pattern capture");
+  return 0;  /* to avoid warnings */
+}
+
+static char *match (char *s, char *p, int level)
+{
+  init: /* using goto's to optimize tail recursion */
+  switch (*p) {
+    case '(':  /* start capture */
+      if (level >= MAX_CAPT) lua_error("too many captures");
+      capture[level].init = s;
+      capture[level].len = -1;
+      level++; p++; goto init;  /* return match(s, p+1, level); */
+    case ')': {  /* end capture */
+      int l = capture_to_close(level);
+      char *res;
+      capture[l].len = s - capture[l].init;  /* close capture */
+      if ((res = match(s, p+1, level)) == NULL) /* match failed? */
+        capture[l].len = -1;  /* undo capture */
+      return res;
+    }
+    case ESC:  /* possibly a capture (if followed by a digit) */
+      if (!isdigit(*(p+1))) goto dflt;
+      else {
+        int l = check_cap(*(p+1), level);
+        if (strncmp(capture[l].init, s, capture[l].len) == 0) {
+          /* return match(p+2, s+capture[l].len, level); */
+          p+=2; s+=capture[l].len; goto init;
+        }
+        else return NULL;
+     }
+    case '\0': case '$':  /* (possibly) end of pattern */
+      if (*p == 0 || (*(p+1) == 0 && *s == 0)) {
+        num_captures = level;
+        return s;
+      }
+      else goto dflt;
+    default: dflt: {  /* it is a pattern item */
+      int m = singlematch(*s, p);
+      char *ep = item_end(p);  /* get what is next */
+      switch (*ep) {
+        case '*': {  /* repetition? */
+          char *res;
+          if (m && (res = match(s+1, p, level)))
+            return res;
+          p=ep+1; goto init;  /* else return match(s, ep+1, level); */
+        }
+        case '?': {  /* optional? */
+          char *res;
+          if (m && (res = match(s+1, ep+1, level)))
+            return res;
+          p=ep+1; goto init;  /* else return match(s, ep+1, level); */
+        }
+        default:
+          if (m) { s++; p=ep; goto init; }  /* return match(s+1, ep, level); */
+          else return NULL;
+      }
+    }
+  }
+}
+
+static void str_find (void)
+{
+  char *s = lua_check_string(1, "find");
+  char *p = lua_check_string(2, "find");
+  long init = lua_opt_number(3, 1, "strfind") - 1;
+  lua_arg_check(0 <= init && init <= strlen(s), "find");
+  if (strpbrk(p, SPECIALS) == NULL) {  /* no special caracters? */
+    char *s2 = strstr(s+init, p);
+    if (s2) {
+      lua_pushnumber(s2-s+1);
+      lua_pushnumber(s2-s+strlen(p));
+    }
+  }
+  else {
+    int anchor = (*p == '^') ? (p++, 1) : 0;
+    char *s1=s+init;
+    do {
+      char *res;
+      if ((res=match(s1, p, 0)) != NULL) {
+        lua_pushnumber(s1-s+1);  /* start */
+        lua_pushnumber(res-s);   /* end */
+        push_captures();
+        return;
+      }
+    } while (*s1++ && !anchor);
+  }
+}
+
+static void str_s (void)
+{
+  char *src = lua_check_string(1, "s");
+  char *p = lua_check_string(2, "s");
+  char *newp = lua_check_string(3, "s");
+  int max_s = lua_opt_number(4, strlen(src), "s");
+  int anchor = (*p == '^') ? (p++, 1) : 0;
+  int n = 0;
+  luaI_addchar(0);
+  while (*src && n < max_s) {
+    char *e;
+    if ((e=match(src, p, 0)) == NULL)
+      luaI_addchar(*src++);
+    else {
+      if (e == src) lua_error("empty pattern in substitution");  /* ??? */
+      add_s(newp);
+      src = e;
+      n++;
+    }
+    if (anchor) break;
+  }
+  addnchar(src, strlen(src));
   lua_pushstring(luaI_addchar(0));
   lua_pushstring(luaI_addchar(0));
+  lua_pushnumber(n);  /* number of substitutions */
 }
 }
 
 
-/*
-** get ascii value of a character in a string
-*/
-static void str_ascii (void)
+static void str_set (void)
 {
 {
-  char *s = lua_check_string(1, "ascii");
-  long pos = lua_opt_number(2, 1, "ascii") - 1;
-  if (pos<0 || pos>=strlen(s))
-    lua_arg_error("ascii");
-  lua_pushnumber(s[pos]);
+  char *item = lua_check_string(1, "strset");
+  int i;
+  lua_arg_check(*item_end(item) == 0, "strset");
+  luaI_addchar(0);
+  for (i=1; i<256; i++)  /* 0 cannot be part of a set */
+    if (singlematch(i, item))
+      luaI_addchar(i);
+  lua_pushstring(luaI_addchar(0));
 }
 }
 
 
+
 void luaI_addquoted (char *s)
 void luaI_addquoted (char *s)
 {
 {
   luaI_addchar('"');
   luaI_addchar('"');
-  for (; *s; s++)
-  {
-    if (*s == '"' || *s == '\\' || *s == '\n')
+  for (; *s; s++) {
+    if (strchr("\"\\\n", *s))
       luaI_addchar('\\');
       luaI_addchar('\\');
     luaI_addchar(*s);
     luaI_addchar(*s);
   }
   }
   luaI_addchar('"');
   luaI_addchar('"');
 }
 }
 
 
-#define MAX_CONVERTION 2000
-#define MAX_FORMAT 50
+#define MAX_FORMAT 200
 
 
 static void str_format (void)
 static void str_format (void)
 {
 {
   int arg = 1;
   int arg = 1;
   char *strfrmt = lua_check_string(arg++, "format");
   char *strfrmt = lua_check_string(arg++, "format");
   luaI_addchar(0);  /* initialize */
   luaI_addchar(0);  /* initialize */
-  while (*strfrmt)
-  {
+  while (*strfrmt) {
     if (*strfrmt != '%')
     if (*strfrmt != '%')
       luaI_addchar(*strfrmt++);
       luaI_addchar(*strfrmt++);
     else if (*++strfrmt == '%')
     else if (*++strfrmt == '%')
       luaI_addchar(*strfrmt++);  /* %% */
       luaI_addchar(*strfrmt++);  /* %% */
-    else
-    { /* format item */
+    else { /* format item */
       char form[MAX_FORMAT];      /* store the format ('%...') */
       char form[MAX_FORMAT];      /* store the format ('%...') */
-      char buff[MAX_CONVERTION];  /* store the formated value */
-      int size = 0;
-      int i = 0;
-      form[i++] = '%';
-      form[i] = *strfrmt++;
-      while (!isalpha(form[i]))
-      {
-        if (isdigit(form[i]))
-        {
-          size = size*10 + form[i]-'0';
-          if (size >= MAX_CONVERTION)
-            lua_error("format size/precision too long in function `format'");
-        }
-        else if (form[i] == '.')
-          size = 0;  /* re-start */
-        if (++i >= MAX_FORMAT)
-            lua_error("bad format in function `format'");
-        form[i] = *strfrmt++;
-      }
-      form[i+1] = 0;  /* ends string */
-      switch (form[i])
-      {
+      char *buff;
+      char *initf = strfrmt-1;  /* -1 to include % */
+      strfrmt = match(strfrmt, "[-+ #]*(%d*)%.?(%d*)", 0);
+      if (capture[0].len > 3 || capture[1].len > 3)  /* < 1000? */
+        lua_error("invalid format (width/precision too long)");
+      strncpy(form, initf, strfrmt-initf+1); /* +1 to include convertion */
+      form[strfrmt-initf+1] = 0;
+      buff = openspace(1000);  /* to store the formated value */
+      switch (*strfrmt++) {
         case 'q':
         case 'q':
           luaI_addquoted(lua_check_string(arg++, "format"));
           luaI_addquoted(lua_check_string(arg++, "format"));
-          buff[0] = '\0';  /* addchar already done */
           break;
           break;
-        case 's':
-        {
+        case 's': {
           char *s = lua_check_string(arg++, "format");
           char *s = lua_check_string(arg++, "format");
-          if (strlen(s) >= MAX_CONVERTION)
-            lua_error("string argument too long in function `format'");
-          sprintf(buff, form, s);
+          buff = openspace(strlen(s));
+          buff_size += sprintf(buff, form, s);
           break;
           break;
         }
         }
         case 'c':  case 'd':  case 'i': case 'o':
         case 'c':  case 'd':  case 'i': case 'o':
         case 'u':  case 'x':  case 'X':
         case 'u':  case 'x':  case 'X':
-          sprintf(buff, form, (int)lua_check_number(arg++, "format"));
+          buff_size += sprintf(buff, form,
+                               (int)lua_check_number(arg++, "format"));
           break;
           break;
         case 'e':  case 'E': case 'f': case 'g':
         case 'e':  case 'E': case 'f': case 'g':
-          sprintf(buff, form, lua_check_number(arg++, "format"));
+          buff_size += sprintf(buff, form, lua_check_number(arg++, "format"));
           break;
           break;
         default:  /* also treat cases 'pnLlh' */
         default:  /* also treat cases 'pnLlh' */
           lua_error("invalid format option in function `format'");
           lua_error("invalid format option in function `format'");
       }
       }
-      for (i=0; buff[i]; i++)  /* move formated value to result */
-        luaI_addchar(buff[i]);
     }
     }
   }
   }
   lua_pushstring(luaI_addchar(0));  /* push the result */
   lua_pushstring(luaI_addchar(0));  /* push the result */
@@ -256,14 +477,17 @@ void luaI_openlib (struct lua_reg *l, int n)
     lua_register(l[i].name, l[i].func);
     lua_register(l[i].name, l[i].func);
 }
 }
 
 
+
 static struct lua_reg strlib[] = {
 static struct lua_reg strlib[] = {
-{"strfind", str_find},
+{"strtok", str_tok},
 {"strlen", str_len},
 {"strlen", str_len},
 {"strsub", str_sub},
 {"strsub", str_sub},
-{"strlower", str_lower},
-{"strupper", str_upper},
+{"strset", str_set},
+{"strmap", str_map},
 {"ascii", str_ascii},
 {"ascii", str_ascii},
-{"format",    str_format}
+{"format", str_format},
+{"strfind", str_find},
+{"s",    str_s}
 };
 };