123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646 |
- package regex_vm
- /*
- (c) Copyright 2024 Feoramund <[email protected]>.
- Made available under Odin's BSD-3 license.
- List of contributors:
- Feoramund: Initial implementation.
- */
- import "base:intrinsics"
- @require import "core:io"
- import "core:slice"
- import "core:text/regex/common"
- import "core:text/regex/parser"
- import "core:unicode/utf8"
// Alias for the parser's inclusive rune range type, so VM code can name it
// without the `parser.` qualifier.
Rune_Class_Range :: parser.Rune_Class_Range

// NOTE: This structure differs intentionally from the one in `regex/parser`,
// as this data doesn't need to be a dynamic array once it hits the VM.
//
// The members of one rune class, as consulted by the `Rune_Class*` and
// `Wait_For_Rune_Class*` opcodes. A rune belongs to the class if it equals
// one of `runes`, or if `lower <= r && r <= upper` for one of `ranges`.
Rune_Class_Data :: struct {
	runes: []rune,
	ranges: []Rune_Class_Range,
}
// The VM instruction set. Each opcode occupies one byte in a `Program` and is
// immediately followed by the inline operands listed in the right-hand column.
// Multi-byte operands are read with `intrinsics.unaligned_load`.
//
// Operand meanings, as interpreted by `add_thread` and `run`:
// - `Byte`, `Wait_For_Byte`: a byte value compared against a rune.
// - `Rune`, `Wait_For_Rune`: a rune value.
// - `Rune_Class*`, `Wait_For_Rune_Class*`: an index into `Machine.class_data`.
// - `Jump`: an absolute byte offset into the program.
// - `Split`: two absolute byte offsets; the first is resolved first.
// - `Save`: an index into a thread's capture slot array.
//
// The non-`Wait_For_*` matchers test the rune being dispatched this frame.
// The `Wait_For_*` forms test the upcoming rune and keep their own thread
// alive every frame. `Multiline_Open` is expected to be directly followed by
// `Multiline_Close`.
Opcode :: enum u8 {
	                                     // | [ operands ]
	Match                       = 0x00, // |
	Match_And_Exit              = 0x01, // |
	Byte                        = 0x02, // | u8
	Rune                        = 0x03, // | i32
	Rune_Class                  = 0x04, // | u8
	Rune_Class_Negated          = 0x05, // | u8
	Wildcard                    = 0x06, // |
	Jump                        = 0x07, // | u16
	Split                       = 0x08, // | u16, u16
	Save                        = 0x09, // | u8
	Assert_Start                = 0x0A, // |
	Assert_End                  = 0x0B, // |
	Assert_Word_Boundary        = 0x0C, // |
	Assert_Non_Word_Boundary    = 0x0D, // |
	Multiline_Open              = 0x0E, // |
	Multiline_Close             = 0x0F, // |
	Wait_For_Byte               = 0x10, // | u8
	Wait_For_Rune               = 0x11, // | i32
	Wait_For_Rune_Class         = 0x12, // | u8
	Wait_For_Rune_Class_Negated = 0x13, // | u8
	Match_All_And_Escape        = 0x14, // |
}
// One VM thread: a position in the program plus the capture positions
// recorded along the path that reached it.
Thread :: struct {
	// Byte offset into `Machine.code` of the opcode this thread waits on.
	pc: int,
	// Capture slots holding byte offsets into `Machine.memory`; -1 means unset
	// (the initial thread starts with every slot at -1). Presumably two slots
	// (start, end) per capture group -- TODO confirm against the compiler.
	// The array may be shared by several threads; `add_thread` copies it
	// before a `Save` writes to it.
	saved: ^[2 * common.MAX_CAPTURE_GROUPS]int,
}

// Compiled bytecode: a flat byte stream of opcodes with inline operands.
Program :: []Opcode
// Complete state of one regex VM run, as built by `create` and driven by `run`.
Machine :: struct {
	// Program state
	// The input being matched.
	memory: string,
	// Rune class tables, indexed by the operand of the `*Rune_Class*` opcodes.
	class_data: []Rune_Class_Data,
	// The bytecode being executed.
	code: Program,

	// Thread state
	// Number of threads queued so far (in `next_threads` while a frame is
	// running, in `threads` once the buffers have been swapped).
	top_thread: int,
	// Threads being stepped in the current frame.
	threads: [^]Thread,
	// Threads queued for the next frame; swapped with `threads` after each frame.
	next_threads: [^]Thread,

	// The busy map is used to merge threads based on their program counters.
	// It holds one bit per byte offset of `code` and is cleared at the start of every frame.
	busy_map: []u64,

	// Global state
	// Byte offset of `current_rune` within `memory`.
	string_pointer: int,
	// The rune dispatched to threads this frame, and its encoded size in bytes.
	current_rune: rune,
	current_rune_size: int,
	// Lookahead: the rune immediately after `current_rune` (0 size at end of input).
	next_rune: rune,
	next_rune_size: int,
}
// @MetaCharacter
// NOTE: This must be kept in sync with the compiler & tokenizer.
//
// Reports whether `r` is a "word" character for the purpose of word-boundary
// assertions: an ASCII digit, an ASCII letter of either case, or `_`.
// Every non-ASCII rune is treated as a non-word character.
is_word_class :: #force_inline proc "contextless" (r: rune) -> bool {
	is_digit := '0' <= r && r <= '9'
	is_upper := 'A' <= r && r <= 'Z'
	is_lower := 'a' <= r && r <= 'z'
	return is_digit || is_upper || is_lower || r == '_'
}
// Marks `pc` as busy for the current frame.
// Returns `true` if the bit was newly set, or `false` if `pc` was already busy
// (the map is left unchanged in that case).
set_busy_map :: #force_inline proc "contextless" (vm: ^Machine, pc: int) -> bool #no_bounds_check {
	word := &vm.busy_map[u64(pc) >> 6]
	mask := u64(1) << (u64(pc) & 0x3F)
	if word^ & mask != 0 {
		return false
	}
	word^ |= mask
	return true
}
// Reports whether `pc` has already been marked busy in the current frame.
// Read-only counterpart of `set_busy_map`.
check_busy_map :: #force_inline proc "contextless" (vm: ^Machine, pc: int) -> bool #no_bounds_check {
	position := u64(pc)
	word := vm.busy_map[position >> 6]
	return (word >> (position & 0x3F)) & 1 != 0
}
// Schedules a thread starting at `pc` for the next frame. Before queuing, it
// resolves every opcode that can be decided without consuming input.
//
// Starting at `pc`, control flow is followed immediately:
// - `Jump` and `Split` are taken.
// - `Save` records the current position.
// - The `Assert_*` opcodes either continue (the assertion holds) or end the
//   thread (it does not).
// Once an opcode is reached that has to wait for input, a `Thread` is written
// to `vm.next_threads[vm.top_thread]` and `vm.top_thread` is incremented.
//
// The busy map allows each `pc` to be resolved at most once per frame. A later
// arrival at an already-busy `pc` is dropped, so the first thread to reach an
// instruction keeps its `saved` captures.
//
// Parameters:
// - saved: capture slots inherited by the new thread. This array is never
//   written in place; `Save` makes a copy first.
// - pc:    byte offset into `vm.code` to start resolving from.
//
// Throughout, `sp` is `vm.string_pointer + vm.current_rune_size`: the byte
// offset just past `vm.current_rune`. `vm.next_rune` is the rune at `sp`.
add_thread :: proc(vm: ^Machine, saved: ^[2 * common.MAX_CAPTURE_GROUPS]int, pc: int) #no_bounds_check {
	if check_busy_map(vm, pc) {
		return
	}

	saved := saved
	pc := pc

	resolution_loop: for {
		// Stop if this pc was already resolved earlier in the current frame.
		if !set_busy_map(vm, pc) {
			return
		}

		when common.ODIN_DEBUG_REGEX {
			io.write_string(common.debug_stream, "Thread [PC:")
			common.write_padded_hex(common.debug_stream, pc, 4)
			io.write_string(common.debug_stream, "] thinking about ")
			io.write_string(common.debug_stream, opcode_to_name(vm.code[pc]))
			io.write_rune(common.debug_stream, '\n')
		}

		#partial switch vm.code[pc] {
		case .Jump:
			// Operand is an absolute target offset.
			pc = cast(int)intrinsics.unaligned_load(cast(^u16)&vm.code[pc + size_of(Opcode)])
			continue

		case .Split:
			// The first target is resolved (recursively) before the second, so
			// threads reached through `jmp_x` are queued ahead of `jmp_y`'s.
			jmp_x := cast(int)intrinsics.unaligned_load(cast(^u16)&vm.code[pc + size_of(Opcode)])
			jmp_y := cast(int)intrinsics.unaligned_load(cast(^u16)&vm.code[pc + size_of(Opcode) + size_of(u16)])

			add_thread(vm, saved, jmp_x)
			pc = jmp_y
			continue

		case .Save:
			// Copy the capture slots before writing, because `saved` may be shared
			// with other threads. `new_saved ^= saved^` assigns through the
			// pointer (`new_saved^ = saved^`); Odin's XOR-assign is spelled `~=`.
			new_saved := new([2 * common.MAX_CAPTURE_GROUPS]int)
			new_saved ^= saved^
			saved = new_saved

			index := vm.code[pc + size_of(Opcode)]
			sp := vm.string_pointer+vm.current_rune_size
			saved[index] = sp

			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "Thread [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "] saving state: (slot ")
				io.write_int(common.debug_stream, cast(int)index)
				io.write_string(common.debug_stream, " = ")
				io.write_int(common.debug_stream, sp)
				io.write_string(common.debug_stream, ")\n")
			}

			pc += size_of(Opcode) + size_of(u8)
			continue

		case .Assert_Start:
			// Holds only at offset 0; otherwise falls out of the switch and the
			// thread is dropped.
			sp := vm.string_pointer+vm.current_rune_size
			if sp == 0 {
				pc += size_of(Opcode)
				continue
			}
		case .Assert_End:
			// Holds only at the end of `memory`; otherwise the thread is dropped.
			sp := vm.string_pointer+vm.current_rune_size
			if sp == len(vm.memory) {
				pc += size_of(Opcode)
				continue
			}
		case .Multiline_Open:
			sp := vm.string_pointer+vm.current_rune_size
			if sp == 0 || sp == len(vm.memory) {
				if vm.next_rune == '\r' || vm.next_rune == '\n' {
					// The VM is currently on a newline at the string boundary,
					// so consume the newline next frame.
					when common.ODIN_DEBUG_REGEX {
						io.write_string(common.debug_stream, "*** New thread added [PC:")
						common.write_padded_hex(common.debug_stream, pc, 4)
						io.write_string(common.debug_stream, "]\n")
					}
					vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
					vm.top_thread += 1
				} else {
					// Skip the `Multiline_Close` opcode.
					pc += 2 * size_of(Opcode)
					continue
				}
			} else {
				// Not on a string boundary.
				// Try to consume a newline next frame in the other opcode loop.
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "*** New thread added [PC:")
					common.write_padded_hex(common.debug_stream, pc, 4)
					io.write_string(common.debug_stream, "]\n")
				}
				vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
				vm.top_thread += 1
			}
		case .Assert_Word_Boundary:
			// NOTE(review): at either end of the input this succeeds
			// unconditionally, even when the adjacent rune is not a word
			// character. Many engines (e.g. PCRE, Go) require a word character
			// next to the edge. Confirm this is the intended semantic.
			sp := vm.string_pointer+vm.current_rune_size
			if sp == 0 || sp == len(vm.memory) {
				pc += size_of(Opcode)
				continue
			} else {
				// A boundary exists when exactly one side is a word character.
				last_rune_is_wc := is_word_class(vm.current_rune)
				this_rune_is_wc := is_word_class(vm.next_rune)

				if last_rune_is_wc && !this_rune_is_wc || !last_rune_is_wc && this_rune_is_wc {
					pc += size_of(Opcode)
					continue
				}
			}
		case .Assert_Non_Word_Boundary:
			// Never holds at either end of the input (mirror of the note above).
			sp := vm.string_pointer+vm.current_rune_size
			if sp != 0 && sp != len(vm.memory) {
				// Not a boundary when both sides agree on word-ness.
				last_rune_is_wc := is_word_class(vm.current_rune)
				this_rune_is_wc := is_word_class(vm.next_rune)

				if last_rune_is_wc && this_rune_is_wc || !last_rune_is_wc && !this_rune_is_wc {
					pc += size_of(Opcode)
					continue
				}
			}

		// The `Wait_For_*` opcodes always re-queue themselves. When the upcoming
		// rune (`next_rune`) matches, the following instruction is also resolved
		// so that it can consume that rune next frame.
		case .Wait_For_Byte:
			operand := cast(rune)vm.code[pc + size_of(Opcode)]
			if vm.next_rune == operand {
				add_thread(vm, saved, pc + size_of(Opcode) + size_of(u8))
			}

			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "*** New thread added [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "]\n")
			}
			vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
			vm.top_thread += 1

		case .Wait_For_Rune:
			operand := intrinsics.unaligned_load(cast(^rune)&vm.code[pc + size_of(Opcode)])
			if vm.next_rune == operand {
				add_thread(vm, saved, pc + size_of(Opcode) + size_of(rune))
			}

			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "*** New thread added [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "]\n")
			}
			vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
			vm.top_thread += 1

		case .Wait_For_Rune_Class:
			operand := cast(u8)vm.code[pc + size_of(Opcode)]
			class_data := vm.class_data[operand]
			next_rune := vm.next_rune

			check: {
				for r in class_data.runes {
					if next_rune == r {
						add_thread(vm, saved, pc + size_of(Opcode) + size_of(u8))
						break check
					}
				}
				for range in class_data.ranges {
					if range.lower <= next_rune && next_rune <= range.upper {
						add_thread(vm, saved, pc + size_of(Opcode) + size_of(u8))
						break check
					}
				}
			}
			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "*** New thread added [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "]\n")
			}
			vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
			vm.top_thread += 1

		case .Wait_For_Rune_Class_Negated:
			operand := cast(u8)vm.code[pc + size_of(Opcode)]
			class_data := vm.class_data[operand]
			next_rune := vm.next_rune

			// Advance only if `next_rune` is in neither the rune list nor any range.
			check_negated: {
				for r in class_data.runes {
					if next_rune == r {
						break check_negated
					}
				}
				for range in class_data.ranges {
					if range.lower <= next_rune && next_rune <= range.upper {
						break check_negated
					}
				}
				add_thread(vm, saved, pc + size_of(Opcode) + size_of(u8))
			}
			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "*** New thread added [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "]\n")
			}
			vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
			vm.top_thread += 1

		case:
			// Any other opcode (input-consuming matchers, `Match`, ...) is queued
			// as-is and will be stepped by `run` next frame.
			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "*** New thread added [PC:")
				common.write_padded_hex(common.debug_stream, pc, 4)
				io.write_string(common.debug_stream, "]\n")
			}
			vm.next_threads[vm.top_thread] = Thread{ pc = pc, saved = saved }
			vm.top_thread += 1
		}

		// Reached either after queuing a thread or after a failed assertion.
		break resolution_loop
	}

	return
}
// Executes the program in `vm.code` against `vm.memory`. The simulation runs in
// lock-step (Pike VM-style): every queued thread is stepped on the same rune
// in each frame, and the survivors are collected into `next_threads` for the
// following frame.
//
// `UNICODE_MODE` selects UTF-8 decoding of the input (`true`) or reading one
// byte per step (`false`).
//
// Returns:
// - saved: the capture slots of the matching thread. A `Match` records them and
//   abandons the remaining threads of that frame; threads already queued keep
//   running, so a later `Match` overrides an earlier one. `Match_And_Exit`
//   returns `nil` immediately.
// - ok:    `true` if any thread matched.
run :: proc(vm: ^Machine, $UNICODE_MODE: bool) -> (saved: ^[2 * common.MAX_CAPTURE_GROUPS]int, ok: bool) #no_bounds_check {
	// Prime the lookahead with the first rune (or byte) of the input.
	when UNICODE_MODE {
		vm.next_rune, vm.next_rune_size = utf8.decode_rune_in_string(vm.memory)
	} else {
		if len(vm.memory) > 0 {
			vm.next_rune = cast(rune)vm.memory[0]
			vm.next_rune_size = 1
		}
	}

	when common.ODIN_DEBUG_REGEX {
		io.write_string(common.debug_stream, "### Adding initial thread.\n")
	}

	{
		// Every capture slot starts out unset (-1).
		starter_saved := new([2 * common.MAX_CAPTURE_GROUPS]int)
		starter_saved^ = -1

		add_thread(vm, starter_saved, 0)
	}

	// `add_thread` adds to `next_threads` by default, but we need to put this
	// thread in the current thread buffer.
	vm.threads, vm.next_threads = vm.next_threads, vm.threads

	when common.ODIN_DEBUG_REGEX {
		io.write_string(common.debug_stream, "### VM starting.\n")
		defer io.write_string(common.debug_stream, "### VM finished.\n")
	}

	for {
		// Every pc may be scheduled once per frame.
		slice.zero(vm.busy_map[:])

		assert(vm.string_pointer <= len(vm.memory), "VM string pointer went out of bounds.")

		// Shift the lookahead: the previous `next_rune` is dispatched this frame,
		// and the rune after it becomes the new `next_rune`.
		current_rune := vm.next_rune
		vm.current_rune = current_rune
		vm.current_rune_size = vm.next_rune_size
		when UNICODE_MODE {
			vm.next_rune, vm.next_rune_size = utf8.decode_rune_in_string(vm.memory[vm.string_pointer+vm.current_rune_size:])
		} else {
			if vm.string_pointer+size_of(u8) < len(vm.memory) {
				vm.next_rune = cast(rune)vm.memory[vm.string_pointer+size_of(u8)]
				vm.next_rune_size = size_of(u8)
			} else {
				vm.next_rune = 0
				vm.next_rune_size = 0
			}
		}

		when common.ODIN_DEBUG_REGEX {
			io.write_string(common.debug_stream, ">>> Dispatching rune: ")
			io.write_encoded_rune(common.debug_stream, current_rune)
			io.write_byte(common.debug_stream, '\n')
		}

		thread_count := vm.top_thread
		vm.top_thread = 0
		thread_loop: for i := 0; i < thread_count; i += 1 {
			// `t` is a local copy; re-pointing `t.saved` does not affect other threads.
			t := vm.threads[i]

			when common.ODIN_DEBUG_REGEX {
				io.write_string(common.debug_stream, "Thread [PC:")
				common.write_padded_hex(common.debug_stream, t.pc, 4)
				io.write_string(common.debug_stream, "] stepping on ")
				io.write_string(common.debug_stream, opcode_to_name(vm.code[t.pc]))
				io.write_byte(common.debug_stream, '\n')
			}

			#partial opcode: switch vm.code[t.pc] {
			case .Match:
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "Thread matched!\n")
				}
				saved = t.saved
				ok = true
				break thread_loop

			case .Match_And_Exit:
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "Thread matched! (Exiting)\n")
				}
				return nil, true

			case .Byte:
				operand := cast(rune)vm.code[t.pc + size_of(Opcode)]
				if current_rune == operand {
					add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
				}

			case .Rune:
				operand := intrinsics.unaligned_load(cast(^rune)&vm.code[t.pc + size_of(Opcode)])
				if current_rune == operand {
					add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(rune))
				}

			case .Rune_Class:
				operand := cast(u8)vm.code[t.pc + size_of(Opcode)]
				class_data := vm.class_data[operand]

				for r in class_data.runes {
					if current_rune == r {
						add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
						break opcode
					}
				}
				for range in class_data.ranges {
					if range.lower <= current_rune && current_rune <= range.upper {
						add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
						break opcode
					}
				}

			case .Rune_Class_Negated:
				operand := cast(u8)vm.code[t.pc + size_of(Opcode)]
				class_data := vm.class_data[operand]
				for r in class_data.runes {
					if current_rune == r {
						break opcode
					}
				}
				for range in class_data.ranges {
					if range.lower <= current_rune && current_rune <= range.upper {
						break opcode
					}
				}
				add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))

			case .Wildcard:
				add_thread(vm, t.saved, t.pc + size_of(Opcode))

			case .Multiline_Open:
				if current_rune == '\n' {
					// UNIX newline.
					add_thread(vm, t.saved, t.pc + 2 * size_of(Opcode))
				} else if current_rune == '\r' {
					if vm.next_rune == '\n' {
						// Windows newline. (1/2)
						add_thread(vm, t.saved, t.pc + size_of(Opcode))
					} else {
						// Mac newline.
						add_thread(vm, t.saved, t.pc + 2 * size_of(Opcode))
					}
				}
			case .Multiline_Close:
				if current_rune == '\n' {
					// Windows newline. (2/2)
					add_thread(vm, t.saved, t.pc + size_of(Opcode))
				}

			// The `Wait_For_*` opcodes test the upcoming rune and always keep
			// their own thread alive for the next frame.
			case .Wait_For_Byte:
				operand := cast(rune)vm.code[t.pc + size_of(Opcode)]
				if vm.next_rune == operand {
					add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
				}

				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "*** New thread added [PC:")
					common.write_padded_hex(common.debug_stream, t.pc, 4)
					io.write_string(common.debug_stream, "]\n")
				}
				vm.next_threads[vm.top_thread] = Thread{ pc = t.pc, saved = t.saved }
				vm.top_thread += 1

			case .Wait_For_Rune:
				operand := intrinsics.unaligned_load(cast(^rune)&vm.code[t.pc + size_of(Opcode)])
				if vm.next_rune == operand {
					add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(rune))
				}

				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "*** New thread added [PC:")
					common.write_padded_hex(common.debug_stream, t.pc, 4)
					io.write_string(common.debug_stream, "]\n")
				}
				vm.next_threads[vm.top_thread] = Thread{ pc = t.pc, saved = t.saved }
				vm.top_thread += 1

			case .Wait_For_Rune_Class:
				operand := cast(u8)vm.code[t.pc + size_of(Opcode)]
				class_data := vm.class_data[operand]
				next_rune := vm.next_rune

				check: {
					for r in class_data.runes {
						if next_rune == r {
							add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
							break check
						}
					}
					for range in class_data.ranges {
						if range.lower <= next_rune && next_rune <= range.upper {
							add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
							break check
						}
					}
				}
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "*** New thread added [PC:")
					common.write_padded_hex(common.debug_stream, t.pc, 4)
					io.write_string(common.debug_stream, "]\n")
				}
				vm.next_threads[vm.top_thread] = Thread{ pc = t.pc, saved = t.saved }
				vm.top_thread += 1

			case .Wait_For_Rune_Class_Negated:
				operand := cast(u8)vm.code[t.pc + size_of(Opcode)]
				class_data := vm.class_data[operand]
				next_rune := vm.next_rune

				check_negated: {
					for r in class_data.runes {
						if next_rune == r {
							break check_negated
						}
					}
					for range in class_data.ranges {
						if range.lower <= next_rune && next_rune <= range.upper {
							break check_negated
						}
					}
					add_thread(vm, t.saved, t.pc + size_of(Opcode) + size_of(u8))
				}
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "*** New thread added [PC:")
					common.write_padded_hex(common.debug_stream, t.pc, 4)
					io.write_string(common.debug_stream, "]\n")
				}
				vm.next_threads[vm.top_thread] = Thread{ pc = t.pc, saved = t.saved }
				vm.top_thread += 1

			case .Match_All_And_Escape:
				t.pc += size_of(Opcode)

				// The point of this loop is to walk out of wherever this
				// opcode lives to the end of the program, while saving the
				// index to the length of the string at each pass on the way.
				//
				// `t.saved` may be shared with other live threads, because
				// `add_thread` only copies it when it resolves a `Save`. So
				// copy it before the first write here. Otherwise an escape
				// that fails part-way (the default case below) would leave
				// end-of-input positions in capture slots that sibling threads
				// still report.
				saved_is_private := false
				escape_loop: for {
					#partial switch vm.code[t.pc] {
					case .Match, .Match_And_Exit:
						break escape_loop

					case .Jump:
						t.pc = cast(int)intrinsics.unaligned_load(cast(^u16)&vm.code[t.pc + size_of(Opcode)])

					case .Save:
						if !saved_is_private {
							private_saved := new([2 * common.MAX_CAPTURE_GROUPS]int)
							private_saved^ = t.saved^
							t.saved = private_saved
							saved_is_private = true
						}
						index := vm.code[t.pc + size_of(Opcode)]
						t.saved[index] = len(vm.memory)
						t.pc += size_of(Opcode) + size_of(u8)

					case .Match_All_And_Escape:
						// Layering these is fine.
						t.pc += size_of(Opcode)

					case:
						// If the loop has to process any opcode not listed above,
						// it means someone did something odd like `a(.*$)b`, in
						// which case, just fail. Technically, the expression makes
						// no sense.
						break opcode
					}
				}

				saved = t.saved
				ok = true
				return

			case:
				when common.ODIN_DEBUG_REGEX {
					io.write_string(common.debug_stream, "Opcode: ")
					io.write_int(common.debug_stream, cast(int)vm.code[t.pc])
					io.write_string(common.debug_stream, "\n")
				}
				panic("Invalid opcode in RegEx thread loop.")
			}
		}

		// Threads queued during this frame become the current set.
		vm.threads, vm.next_threads = vm.next_threads, vm.threads

		when common.ODIN_DEBUG_REGEX {
			io.write_string(common.debug_stream, "<<< Frame ended. (Threads: ")
			io.write_int(common.debug_stream, vm.top_thread)
			io.write_string(common.debug_stream, ")\n")
		}

		// Stop once the end of input has been dispatched or no thread survived.
		if vm.string_pointer == len(vm.memory) || vm.top_thread == 0 {
			break
		}

		vm.string_pointer += vm.current_rune_size
	}

	return
}
// Returns how many opcodes `iterate_opcodes` yields when walking `code`.
// This is an instruction count, not the byte length of `code`.
opcode_count :: proc(code: Program) -> (opcodes: int) {
	count := 0
	walker := Opcode_Iterator{ code, 0 }
	for _ in iterate_opcodes(&walker) {
		count += 1
	}
	return count
}
// Builds a `Machine` ready to run `code` against `str`.
// It allocates, with the context allocator:
// - the busy map: one bit per byte of `code`, rounded up to whole `u64` words;
// - two thread buffers, each sized to the opcode count minus one (at least 1).
// Panics (via `assert`) if `code` is empty.
create :: proc(code: Program, str: string) -> (vm: Machine) {
	assert(len(code) > 0, "RegEx VM has no instructions.")

	vm.memory = str
	vm.code = code

	sizing := len(code) / 64
	if len(code) % 64 != 0 {
		sizing += 1
	}
	assert(sizing > 0)
	vm.busy_map = make([]u64, sizing)

	thread_capacity := opcode_count(vm.code) - 1
	if thread_capacity < 1 {
		thread_capacity = 1
	}
	vm.threads      = make([^]Thread, thread_capacity)
	vm.next_threads = make([^]Thread, thread_capacity)

	return
}
|