Procházet zdrojové kódy

Refactored TransformRegExp: fail early, avoid allocation and copy if not necessary.

Dmitry Panov před 4 roky
rodič
revize
ab0d68a9e9
5 změnil soubory, kde provedl 248 přidání a 156 odebrání
  1. 1 1
      builtin_regexp.go
  2. 0 19
      parser/lexer.go
  3. 161 121
      parser/regexp.go
  4. 67 13
      parser/regexp_test.go
  5. 19 2
      regexp_test.go

+ 1 - 1
builtin_regexp.go

@@ -252,7 +252,7 @@ func compileRegexp(patternStr, flags string) (p *regexpPattern, err error) {
 		}
 		wrapper = (*regexpWrapper)(pattern)
 	} else {
-		if re2Str == "" {
+		if _, incompat := err1.(parser.RegexpErrorIncompatible); !incompat {
 			err = err1
 			return
 		}

+ 0 - 19
parser/lexer.go

@@ -464,25 +464,6 @@ func (self *_parser) read() {
 	}
 }
 
-// This is here since the functions are so similar
-func (self *_RegExp_parser) read() {
-	if self.offset < self.length {
-		self.chrOffset = self.offset
-		chr, width := rune(self.str[self.offset]), 1
-		if chr >= utf8.RuneSelf { // !ASCII
-			chr, width = utf8.DecodeRuneInString(self.str[self.offset:])
-			if chr == utf8.RuneError && width == 1 {
-				self.error(self.chrOffset, "Invalid UTF-8 character")
-			}
-		}
-		self.offset += width
-		self.chr = chr
-	} else {
-		self.chrOffset = self.length
-		self.chr = -1 // EOF
-	}
-}
-
 func (self *_parser) skipSingleLineComment() {
 	for self.chr != -1 {
 		self.read()

+ 161 - 121
parser/regexp.go

@@ -1,7 +1,6 @@
 package parser
 
 import (
-	"bytes"
 	"fmt"
 	"strconv"
 	"strings"
@@ -12,6 +11,22 @@ const (
 	WhitespaceChars = " \f\n\r\t\v\u00a0\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a\u2028\u2029\u202f\u205f\u3000\ufeff"
 )
 
+type regexpParseError struct {
+	offset int
+	err    string
+}
+
+type RegexpErrorIncompatible struct {
+	regexpParseError
+}
+type RegexpSyntaxError struct {
+	regexpParseError
+}
+
+func (s regexpParseError) Error() string {
+	return s.err
+}
+
 type _RegExp_parser struct {
 	str    string
 	length int
@@ -20,10 +35,10 @@ type _RegExp_parser struct {
 	chrOffset int  // The offset of current character
 	offset    int  // The offset after current character (may be greater than 1)
 
-	errors  []error
-	invalid bool // The input is an invalid JavaScript RegExp
+	err error
 
-	goRegexp *bytes.Buffer
+	goRegexp   strings.Builder
+	passOffset int
 }
 
 // TransformRegExp transforms a JavaScript pattern into  a Go "regexp" pattern.
@@ -34,36 +49,86 @@ type _RegExp_parser struct {
 // re2 (Go) has a different definition for \s: [\t\n\f\r ].
 // The JavaScript definition, on the other hand, also includes \v, Unicode "Separator, Space", etc.
 //
-// If the pattern is invalid (not valid even in JavaScript), then this function
-// returns the empty string and an error.
-//
 // If the pattern is valid, but incompatible (contains a lookahead or backreference),
-// then this function returns the transformation (a non-empty string) AND an error.
-func TransformRegExp(pattern string) (string, error) {
+// then this function returns an empty string and an error of type RegexpErrorIncompatible.
+//
+// If the pattern is invalid (not valid even in JavaScript), then this function
+// returns an empty string and a generic error.
+func TransformRegExp(pattern string) (transformed string, err error) {
 
 	if pattern == "" {
 		return "", nil
 	}
 
-	// TODO If without \, if without (?=, (?!, then another shortcut
-
 	parser := _RegExp_parser{
-		str:      pattern,
-		length:   len(pattern),
-		goRegexp: bytes.NewBuffer(make([]byte, 0, 3*len(pattern)/2)),
-	}
-	parser.read() // Pull in the first character
-	parser.scan()
-	var err error
-	if len(parser.errors) > 0 {
-		err = parser.errors[0]
+		str:    pattern,
+		length: len(pattern),
 	}
-	if parser.invalid {
+	err = parser.parse()
+	if err != nil {
 		return "", err
 	}
 
-	// Might not be re2 compatible, but is still a valid JavaScript RegExp
-	return parser.goRegexp.String(), err
+	return parser.ResultString(), nil
+}
+
+func (self *_RegExp_parser) ResultString() string {
+	if self.passOffset != -1 {
+		return self.str[:self.passOffset]
+	}
+	return self.goRegexp.String()
+}
+
+func (self *_RegExp_parser) parse() (err error) {
+	self.read() // Pull in the first character
+	self.scan()
+	return self.err
+}
+
+func (self *_RegExp_parser) read() {
+	if self.offset < self.length {
+		self.chrOffset = self.offset
+		chr, width := rune(self.str[self.offset]), 1
+		if chr >= utf8.RuneSelf { // !ASCII
+			chr, width = utf8.DecodeRuneInString(self.str[self.offset:])
+			if chr == utf8.RuneError && width == 1 {
+				self.error(true, "Invalid UTF-8 character")
+				return
+			}
+		}
+		self.offset += width
+		self.chr = chr
+	} else {
+		self.chrOffset = self.length
+		self.chr = -1 // EOF
+	}
+}
+
+func (self *_RegExp_parser) stopPassing() {
+	self.goRegexp.Grow(3 * len(self.str) / 2)
+	self.goRegexp.WriteString(self.str[:self.passOffset])
+	self.passOffset = -1
+}
+
+func (self *_RegExp_parser) write(p []byte) {
+	if self.passOffset != -1 {
+		self.stopPassing()
+	}
+	self.goRegexp.Write(p)
+}
+
+func (self *_RegExp_parser) writeByte(b byte) {
+	if self.passOffset != -1 {
+		self.stopPassing()
+	}
+	self.goRegexp.WriteByte(b)
+}
+
+func (self *_RegExp_parser) writeString(s string) {
+	if self.passOffset != -1 {
+		self.stopPassing()
+	}
+	self.goRegexp.WriteString(s)
 }
 
 func (self *_RegExp_parser) scan() {
@@ -78,11 +143,10 @@ func (self *_RegExp_parser) scan() {
 		case '[':
 			self.scanBracket()
 		case ')':
-			self.error(-1, "Unmatched ')'")
-			self.invalid = true
-			self.pass()
+			self.error(true, "Unmatched ')'")
+			return
 		case '.':
-			self.goRegexp.WriteString("[^\\r\\n]")
+			self.writeString("[^\\r\\n]")
 			self.read()
 		default:
 			self.pass()
@@ -98,12 +162,14 @@ func (self *_RegExp_parser) scanGroup() {
 			ch := str[1]
 			switch {
 			case ch == '=' || ch == '!':
-				self.error(-1, "re2: Invalid (%s) <lookahead>", self.str[self.chrOffset:self.chrOffset+2])
+				self.error(false, "re2: Invalid (%s) <lookahead>", self.str[self.chrOffset:self.chrOffset+2])
+				return
 			case ch == '<':
-				self.error(-1, "re2: Invalid (%s) <lookbehind>", self.str[self.chrOffset:self.chrOffset+2])
+				self.error(false, "re2: Invalid (%s) <lookbehind>", self.str[self.chrOffset:self.chrOffset+2])
+				return
 			case ch != ':':
-				self.error(-1, "Invalid group")
-				self.invalid = true
+				self.error(true, "Invalid group")
+				return
 			}
 		}
 	}
@@ -118,7 +184,7 @@ func (self *_RegExp_parser) scanGroup() {
 		case '[':
 			self.scanBracket()
 		case '.':
-			self.goRegexp.WriteString("[^\\r\\n]")
+			self.writeString("[^\\r\\n]")
 			self.read()
 		default:
 			self.pass()
@@ -126,8 +192,7 @@ func (self *_RegExp_parser) scanGroup() {
 		}
 	}
 	if self.chr != ')' {
-		self.error(-1, "Unterminated group")
-		self.invalid = true
+		self.error(true, "Unterminated group")
 		return
 	}
 	self.pass()
@@ -138,14 +203,14 @@ func (self *_RegExp_parser) scanBracket() {
 	str := self.str[self.chrOffset:]
 	if strings.HasPrefix(str, "[]") {
 		// [] -- Empty character class
-		self.goRegexp.WriteString("[^\u0000-uffff]")
+		self.writeString("[^\u0000-\U0001FFFF]")
 		self.offset += 1
 		self.read()
 		return
 	}
 
 	if strings.HasPrefix(str, "[^]") {
-		self.goRegexp.WriteString("[\u0000-\uffff]")
+		self.writeString("[\u0000-\U0001FFFF]")
 		self.offset += 2
 		self.read()
 		return
@@ -163,8 +228,7 @@ func (self *_RegExp_parser) scanBracket() {
 		self.pass()
 	}
 	if self.chr != ']' {
-		self.error(-1, "Unterminated character class")
-		self.invalid = true
+		self.error(true, "Unterminated character class")
 		return
 	}
 	self.pass()
@@ -191,14 +255,12 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 			size += 1
 		}
 		if size == 1 { // The number of characters read
-			_, err := self.goRegexp.Write([]byte{'\\', byte(value) + '0'})
-			if err != nil {
-				self.errors = append(self.errors, err)
-			}
 			if value != 0 {
 				// An invalid backreference
-				self.error(-1, "re2: Invalid \\%d <backreference>", value)
+				self.error(false, "re2: Invalid \\%d <backreference>", value)
+				return
 			}
+			self.passString(offset-1, self.chrOffset)
 			return
 		}
 		tmp := []byte{'\\', 'x', '0', 0}
@@ -208,32 +270,12 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 			tmp = tmp[0:3]
 		}
 		tmp = strconv.AppendInt(tmp, value, 16)
-		_, err := self.goRegexp.Write(tmp)
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
+		self.write(tmp)
 		return
 
 	case '8', '9':
-		size := 0
-		for {
-			digit := digitValue(self.chr)
-			if digit >= 10 {
-				// Not a valid digit
-				break
-			}
-			self.read()
-			size += 1
-		}
-		err := self.goRegexp.WriteByte('\\')
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
-		_, err = self.goRegexp.WriteString(self.str[offset:self.chrOffset])
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
-		self.error(-1, "re2: Invalid \\%s <backreference>", self.str[offset:self.chrOffset])
+		self.read()
+		self.error(false, "re2: Invalid \\%s <backreference>", self.str[offset:self.chrOffset])
 		return
 
 	case 'x':
@@ -246,10 +288,7 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 
 	case 'b':
 		if inClass {
-			_, err := self.goRegexp.Write([]byte{'\\', 'x', '0', '8'})
-			if err != nil {
-				self.errors = append(self.errors, err)
-			}
+			self.write([]byte{'\\', 'x', '0', '8'})
 			self.read()
 			return
 		}
@@ -267,25 +306,19 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 		fallthrough
 
 	case 'f', 'n', 'r', 't', 'v':
-		err := self.goRegexp.WriteByte('\\')
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
-		self.pass()
+		self.passString(offset-1, self.offset)
+		self.read()
 		return
 
 	case 'c':
 		self.read()
 		var value int64
 		if 'a' <= self.chr && self.chr <= 'z' {
-			value = int64(self.chr) - 'a' + 1
+			value = int64(self.chr - 'a' + 1)
 		} else if 'A' <= self.chr && self.chr <= 'Z' {
-			value = int64(self.chr) - 'A' + 1
+			value = int64(self.chr - 'A' + 1)
 		} else {
-			err := self.goRegexp.WriteByte('c')
-			if err != nil {
-				self.errors = append(self.errors, err)
-			}
+			self.writeByte('c')
 			return
 		}
 		tmp := []byte{'\\', 'x', '0', 0}
@@ -295,26 +328,23 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 			tmp = tmp[0:3]
 		}
 		tmp = strconv.AppendInt(tmp, value, 16)
-		_, err := self.goRegexp.Write(tmp)
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
+		self.write(tmp)
 		self.read()
 		return
 	case 's':
 		if inClass {
-			self.goRegexp.WriteString(WhitespaceChars)
+			self.writeString(WhitespaceChars)
 		} else {
-			self.goRegexp.WriteString("[" + WhitespaceChars + "]")
+			self.writeString("[" + WhitespaceChars + "]")
 		}
 		self.read()
 		return
 	case 'S':
 		if inClass {
-			self.error(self.chrOffset, "S in class")
+			self.error(false, "S in class")
 			return
 		} else {
-			self.goRegexp.WriteString("[^" + WhitespaceChars + "]")
+			self.writeString("[^" + WhitespaceChars + "]")
 		}
 		self.read()
 		return
@@ -323,10 +353,9 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 		// a special case for it here
 		if self.chr == '$' || self.chr < utf8.RuneSelf && !isIdentifierPart(self.chr) {
 			// A non-identifier character needs escaping
-			err := self.goRegexp.WriteByte('\\')
-			if err != nil {
-				self.errors = append(self.errors, err)
-			}
+			self.passString(offset-1, self.offset)
+			self.read()
+			return
 		}
 		// Unescape the character for re2
 		self.pass()
@@ -351,7 +380,7 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 	}
 
 	if length == 4 {
-		_, err := self.goRegexp.Write([]byte{
+		self.write([]byte{
 			'\\',
 			'x',
 			'{',
@@ -361,47 +390,58 @@ func (self *_RegExp_parser) scanEscape(inClass bool) {
 			self.str[valueOffset+3],
 			'}',
 		})
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
 	} else if length == 2 {
-		_, err := self.goRegexp.Write([]byte{
-			'\\',
-			'x',
-			self.str[valueOffset+0],
-			self.str[valueOffset+1],
-		})
-		if err != nil {
-			self.errors = append(self.errors, err)
-		}
+		self.passString(offset-1, valueOffset+2)
 	} else {
 		// Should never, ever get here...
-		self.error(-1, "re2: Illegal branch in scanEscape")
-		goto skip
+		self.error(true, "re2: Illegal branch in scanEscape")
+		return
 	}
 
 	return
 
 skip:
-	_, err := self.goRegexp.WriteString(self.str[offset:self.chrOffset])
-	if err != nil {
-		self.errors = append(self.errors, err)
-	}
+	self.passString(offset, self.chrOffset)
 }
 
 func (self *_RegExp_parser) pass() {
-	if self.chr != -1 {
-		_, err := self.goRegexp.WriteRune(self.chr)
-		if err != nil {
-			self.errors = append(self.errors, err)
+	if self.passOffset == self.chrOffset {
+		self.passOffset = self.offset
+	} else {
+		if self.passOffset != -1 {
+			self.stopPassing()
+		}
+		if self.chr != -1 {
+			self.goRegexp.WriteRune(self.chr)
 		}
 	}
 	self.read()
 }
 
-// TODO Better error reporting, use the offset, etc.
-func (self *_RegExp_parser) error(offset int, msg string, msgValues ...interface{}) error {
-	err := fmt.Errorf(msg, msgValues...)
-	self.errors = append(self.errors, err)
-	return err
+func (self *_RegExp_parser) passString(start, end int) {
+	if self.passOffset == start {
+		self.passOffset = end
+		return
+	}
+	if self.passOffset != -1 {
+		self.stopPassing()
+	}
+	self.goRegexp.WriteString(self.str[start:end])
+}
+
+func (self *_RegExp_parser) error(fatal bool, msg string, msgValues ...interface{}) {
+	if self.err != nil {
+		return
+	}
+	e := regexpParseError{
+		offset: self.offset,
+		err:    fmt.Sprintf(msg, msgValues...),
+	}
+	if fatal {
+		self.err = RegexpSyntaxError{e}
+	} else {
+		self.err = RegexpErrorIncompatible{e}
+	}
+	self.offset = self.length
+	self.chr = -1
 }

+ 67 - 13
parser/regexp_test.go

@@ -11,6 +11,8 @@ func TestRegExp(t *testing.T) {
 			// err
 			test := func(input string, expect interface{}) {
 				_, err := TransformRegExp(input)
+				_, incompat := err.(RegexpErrorIncompatible)
+				is(incompat, false)
 				is(err, expect)
 			}
 
@@ -21,29 +23,32 @@ func TestRegExp(t *testing.T) {
 			test("\\(?=)", "Unmatched ')'")
 
 			test(")", "Unmatched ')'")
+			test("0:(?)", "Invalid group")
+			test("(?)", "Invalid group")
+			test("(?U)", "Invalid group")
+			test("(?)|(?i)", "Invalid group")
+			test("(?P<w>)(?P<w>)(?P<D>)", "Invalid group")
 		}
 
 		{
-			// err
-			test := func(input, expect string, expectErr interface{}) {
-				output, err := TransformRegExp(input)
-				is(output, expect)
+			// incompatible
+			test := func(input string, expectErr interface{}) {
+				_, err := TransformRegExp(input)
+				_, incompat := err.(RegexpErrorIncompatible)
+				is(incompat, true)
 				is(err, expectErr)
 			}
 
-			test(")", "", "Unmatched ')'")
+			test(`<%([\s\S]+?)%>`, "S in class")
+
+			test("(?<=y)x", "re2: Invalid (?<) <lookbehind>")
 
-			test("\\0", "\\0", nil)
+			test(`(?!test)`, "re2: Invalid (?!) <lookahead>")
 
-			test("0:(?)", "", "Invalid group")
-			test("(?)", "", "Invalid group")
-			test("(?U)", "", "Invalid group")
-			test("(?)|(?i)", "", "Invalid group")
-			test("(?P<w>)(?P<w>)(?P<D>)", "", "Invalid group")
+			test(`\1`, "re2: Invalid \\1 <backreference>")
 
-			test(`<%([\s\S]+?)%>`, `<%([`+WhitespaceChars+`S]+?)%>`, "S in class")
+			test(`\8`, "re2: Invalid \\8 <backreference>")
 
-			test("(?<=y)x", "(?<=y)x", "re2: Invalid (?<) <lookbehind>")
 		}
 
 		{
@@ -51,6 +56,8 @@ func TestRegExp(t *testing.T) {
 			test := func(input string, expect string) {
 				result, err := TransformRegExp(input)
 				is(err, nil)
+				_, incompat := err.(RegexpErrorIncompatible)
+				is(incompat, false)
 				is(result, expect)
 				_, err = regexp.Compile(result)
 				is(err, nil)
@@ -106,6 +113,8 @@ func TestRegExp(t *testing.T) {
 
 			test("\\175", "\\x7d")
 
+			test("\\0", "\\0")
+
 			test("\\04", "\\x04")
 
 			test(`(.)^`, "([^\\r\\n])^")
@@ -115,6 +124,27 @@ func TestRegExp(t *testing.T) {
 			test(`[G-b]`, `[G-b]`)
 
 			test(`[G-b\0]`, `[G-b\0]`)
+
+			test(`\k`, `k`)
+
+			test(`\x20`, `\x20`)
+
+			test(`😊`, `😊`)
+
+			test(`^.*`, `^[^\r\n]*`)
+
+			test(`(\n)`, `(\n)`)
+
+			test(`(a(bc))`, `(a(bc))`)
+
+			test(`[]`, "[^\u0000-\U0001FFFF]")
+
+			test(`[^]`, "[\u0000-\U0001FFFF]")
+
+			test(`\s+`, "["+WhitespaceChars+"]+")
+
+			test(`\S+`, "[^"+WhitespaceChars+"]+")
+
 		}
 	})
 }
@@ -123,7 +153,31 @@ func TestTransformRegExp(t *testing.T) {
 	tt(t, func() {
 		pattern, err := TransformRegExp(`\s+abc\s+`)
 		is(err, nil)
+		_, incompat := err.(RegexpErrorIncompatible)
+		is(incompat, false)
 		is(pattern, `[`+WhitespaceChars+`]+abc[`+WhitespaceChars+`]+`)
 		is(regexp.MustCompile(pattern).MatchString("\t abc def"), true)
 	})
 }
+
+func BenchmarkTransformRegExp(b *testing.B) {
+	f := func(reStr string, b *testing.B) {
+		b.ResetTimer()
+		b.ReportAllocs()
+		for i := 0; i < b.N; i++ {
+			_, _ = TransformRegExp(reStr)
+		}
+	}
+
+	b.Run("Re", func(b *testing.B) {
+		f(`^(([^<>()\[\]\\.,;:\s@"]+(\.[^<>()\[\]\\.,;:\s@"]+)*)|(".+"))@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}])|(([a-zA-Z\-0-9]+\.)+[a-zA-Z]{2,}))$`, b)
+	})
+
+	b.Run("Re2-1", func(b *testing.B) {
+		f(`(?=)^(([^<>()\[\]\\.,;:\s@"]+(\.[^<>()\[\]\\.,;:\s@"]+)*)|(".+"))@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}])|(([a-zA-Z\-0-9]+\.)+[a-zA-Z]{2,}))$`, b)
+	})
+
+	b.Run("Re2-1", func(b *testing.B) {
+		f(`^(([^<>()\[\]\\.,;:\s@"]+(\.[^<>()\[\]\\.,;:\s@"]+)*)|(".+"))@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}])|(([a-zA-Z\-0-9]+\.)+[a-zA-Z]{2,}))$(?=)`, b)
+	})
+}

+ 19 - 2
regexp_test.go

@@ -158,7 +158,7 @@ func TestRegexpSInClass(t *testing.T) {
 	testScript1(SCRIPT, valueFalse, t)
 }
 
-func TestRegexpDotMatchSlashR(t *testing.T) {
+func TestRegexpDotMatchCR(t *testing.T) {
 	const SCRIPT = `
 	/./.test("\r");
 	`
@@ -166,7 +166,7 @@ func TestRegexpDotMatchSlashR(t *testing.T) {
 	testScript1(SCRIPT, valueFalse, t)
 }
 
-func TestRegexpDotMatchSlashRInGroup(t *testing.T) {
+func TestRegexpDotMatchCRInGroup(t *testing.T) {
 	const SCRIPT = `
 	/(.)/.test("\r");
 	`
@@ -174,6 +174,14 @@ func TestRegexpDotMatchSlashRInGroup(t *testing.T) {
 	testScript1(SCRIPT, valueFalse, t)
 }
 
+func TestRegexpDotMatchLF(t *testing.T) {
+	const SCRIPT = `
+	/./.test("\n");
+	`
+
+	testScript1(SCRIPT, valueFalse, t)
+}
+
 func TestRegexpSplitWithBackRef(t *testing.T) {
 	const SCRIPT = `
 	"a++b+-c".split(/([+-])\1/).join(" $$ ")
@@ -511,6 +519,15 @@ func TestRegexpLookbehindAssertion(t *testing.T) {
 	testScript1(TESTLIB+SCRIPT, _undefined, t)
 }
 
+func TestRegexpInvalidUTF8(t *testing.T) {
+	vm := New()
+	// Note that normally vm.ToValue() would replace invalid UTF-8 sequences with RuneError
+	_, err := vm.New(vm.Get("RegExp"), asciiString([]byte{0xAD}))
+	if err == nil {
+		t.Fatal("Expected error")
+	}
+}
+
 func BenchmarkRegexpSplitWithBackRef(b *testing.B) {
 	const SCRIPT = `
 	"aaaaaaaaaaaaaaaaaaaaaaaaa++bbbbbbbbbbbbbbbbbbbbbb+-ccccccccccccccccccccccc".split(/([+-])\1/)