Sfoglia il codice sorgente

Optimize regex match iterator.

Reuse virtual machine and capture groups between matches.
Jeroen van Rijn 5 mesi fa
parent
commit
a5e513567b

+ 40 - 16
core/text/regex/regex.odin

@@ -70,9 +70,9 @@ An iterator to repeatedly match a pattern against a string, to be used with `*_i
 */
 Match_Iterator :: struct {
 	haystack: string, // NOTE(review): this change removes the `result.haystack = str` assignment from create_iterator; confirm whether this field is still populated or read anywhere
-	offset:   int,
 	regex:    Regular_Expression, // compiled pattern, produced by `create` in create_iterator
 	capture:  Capture, // pre-allocated in create_iterator; match_iterator writes each match's groups and positions into it
+	vm:       virtual_machine.Machine, // created once in create_iterator and run again on every match_iterator call
 	idx:      int, // incremented by match_iterator after each successful match; set back to 0 by reset
 	temp:     runtime.Allocator, // the temporary_allocator given to create_iterator; installed as context.allocator while the VM runs
 }
@@ -283,10 +283,11 @@ create_iterator :: proc(
 	flags := flags
 	flags += {.Global} // We're iterating over a string, so the next match could start anywhere
 
-	result.haystack = str
-	result.regex    = create(pattern, flags, permanent_allocator, temporary_allocator) or_return
-	result.capture  = preallocate_capture()
-	result.temp     = temporary_allocator
+	result.regex         = create(pattern, flags, permanent_allocator, temporary_allocator) or_return
+	result.capture       = preallocate_capture()
+	result.temp          = temporary_allocator
+	result.vm            = virtual_machine.create(result.regex.program, str)
+	result.vm.class_data = result.regex.class_data
 
 	return
 }
@@ -444,24 +445,47 @@ Returns:
 - ok:     A bool indicating if there was a match, stopping the iteration on `false`.
 */
 match_iterator :: proc(it: ^Match_Iterator) -> (result: Capture, index: int, ok: bool) {
 	// Runs the iterator's own VM (`it.vm`) once more and copies the capture pairs it
 	// reports into the iterator's pre-allocated `it.capture`, instead of going through
 	// match_with_preallocated_capture on a re-sliced haystack.
+	assert(len(it.capture.groups) >= common.MAX_CAPTURE_GROUPS,
+		"Pre-allocated RegEx capture `groups` must be at least 10 elements long.")
+	assert(len(it.capture.pos) >= common.MAX_CAPTURE_GROUPS,
+		"Pre-allocated RegEx capture `pos` must be at least 10 elements long.")
+
 	runtime.DEFAULT_TEMP_ALLOCATOR_TEMP_GUARD()
+
+	saved: ^[2 * common.MAX_CAPTURE_GROUPS]int // start/end index pairs written by the VM run; stays nil if run returns nil
+	{
+		context.allocator = it.temp // scoped to this block: only the VM run sees the iterator's temp allocator
+		if .Unicode in it.regex.flags {
+			saved, ok = virtual_machine.run(&it.vm, true)
+		} else {
+			saved, ok = virtual_machine.run(&it.vm, false)
+		}
+	}
+
+	str := string(it.vm.memory) // the whole string given to virtual_machine.create, not a slice from an offset
 	num_groups: int
-	num_groups, ok = match_with_preallocated_capture(
-		it.regex,
-		it.haystack[it.offset:],
-		&it.capture,
-		it.temp,
-	)
+
+	if saved != nil {
+		n := 0
+
+		#no_bounds_check for i := 0; i < len(saved); i += 2 { // at most MAX_CAPTURE_GROUPS pairs, and the asserts above guarantee that many slots
+			a, b := saved[i], saved[i + 1]
+			if a == -1 || b == -1 { // pair with a -1 endpoint is skipped, so the stored groups are compacted
+				continue
+			}
+
+			it.capture.groups[n] = str[a:b]
+			it.capture.pos[n]    = {a, b}
+			n += 1
+		}
+		num_groups = n
+	}
 
 	// Deferred: the match counter is bumped when the procedure exits, and only on a successful match.
 	defer if ok {
 		it.idx += 1
 	}
 
 	if num_groups > 0 {
 		// Positions are already indices into `str`, so the removed offset fix-up below is no longer applied.
-		for i in 0..<num_groups {
-			it.capture.pos[i] += it.offset
-		}
-		it.offset = it.capture.pos[0][1]
 		result = {it.capture.pos[:num_groups], it.capture.groups[:num_groups]}
 	}
 	return result, it.idx, ok
@@ -474,7 +498,6 @@ match :: proc {
 }
 
 reset :: proc(it: ^Match_Iterator) {
 	// Sets the iterator's match counter back to zero.
 	// NOTE(review): with `offset` removed, this procedure no longer rewinds any position
 	// and does not touch `it.vm`. match_iterator now re-runs `it.vm` directly, so the scan
 	// position presumably lives in the VM; confirm whether VM state must be reset here for
 	// iteration to restart from the beginning of the string.
-	it.offset = 0
 	it.idx    = 0
 }
 
@@ -544,6 +567,7 @@ destroy_iterator :: proc(it: Match_Iterator, allocator := context.allocator) {
 	context.allocator = allocator
 	destroy(it.regex)
 	destroy(it.capture)
+	virtual_machine.destroy(it.vm)
 }
 
 destroy :: proc {

+ 10 - 1
core/text/regex/virtual_machine/virtual_machine.odin

@@ -627,8 +627,9 @@ opcode_count :: proc(code: Program) -> (opcodes: int) {
 	return
 }
 
-create :: proc(code: Program, str: string) -> (vm: Machine) {
+create :: proc(code: Program, str: string, allocator := context.allocator) -> (vm: Machine) {
 	assert(len(code) > 0, "RegEx VM has no instructions.")
+	context.allocator = allocator
 
 	vm.memory = str
 	vm.code = code
@@ -644,3 +645,11 @@ create :: proc(code: Program, str: string) -> (vm: Machine) {
 
 	return
 }
+
+destroy :: proc(vm: Machine, allocator := context.allocator) { // releases the VM's busy_map, threads and next_threads
+	context.allocator = allocator // presumably so the delete/free calls below release through `allocator`; it should match what `create` allocated with (TODO confirm)
+
+	delete(vm.busy_map)
+	free(vm.threads)
+	free(vm.next_threads)
+}

+ 1 - 2
tests/core/text/regex/test_core_text_regex.odin

@@ -1126,9 +1126,8 @@ test_match_iterator :: proc(t: ^testing.T) {
 		testing.expect_value(t, err, nil)
 		(err == nil) or_continue
 
-		count: int
 		for capture, idx in regex.match(&it) {
-			if count > len(test.expected) {
+			if idx >= len(test.expected) {
 				break
 			}
 			check_capture(t, capture, test.expected[idx])