| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- #include <lua.h>
- #include <lualib.h>
- #include <lauxlib.h>
- #include <deepspeech.h>
- #include <stdbool.h>
- #include <stdint.h>
- #include <stdlib.h>
- #define CHECK(c, ...) if (!(c)) { return luaL_error(L, __VA_ARGS__); }
- #ifdef _WIN32
- #define LDS_EXPORT __declspec(dllexport)
- #else
- #define LDS_EXPORT
- #endif
- static struct {
- ModelState* modelState;
- size_t bufferSize;
- short* buffer;
- } state;
- typedef struct {
- StreamingState* handle;
- } lds_Stream;
- static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
- if (lua_istable(L, index)) {
- *count = lua_objlen(L, index);
- if (state.bufferSize < *count) {
- state.bufferSize += !state.bufferSize;
- do { state.bufferSize <<= 1; } while (state.bufferSize < *count);
- state.buffer = realloc(state.buffer, state.bufferSize);
- }
- for (size_t i = 0; i < *count; i++) {
- lua_rawgeti(L, index, i + 1);
- lua_Integer x = lua_tointeger(L, -1);
- lua_pop(L, 1);
- if (x < INT16_MIN || x > INT16_MAX) {
- luaL_error(L, "Sample #%d (%d) is out of range [%d,%d]", i + 1, x, INT16_MIN, INT16_MAX);
- }
- state.buffer[i] = x;
- }
- return state.buffer;
- } else if (lua_type(L, index) == LUA_TLIGHTUSERDATA) {
- return *count = luaL_checkinteger(L, index + 1), lua_touserdata(L, index);
- }
- return NULL;
- }
- static void lds_pushmetadata(lua_State* L, Metadata* metadata) {
- lua_createtable(L, metadata->num_transcripts, 0);
- for (int i = 0; i < metadata->num_transcripts; i++) {
- const CandidateTranscript* transcript = &metadata->transcripts[i];
- lua_createtable(L, 0, 3);
- lua_pushnumber(L, transcript->confidence);
- lua_setfield(L, -2, "confidence");
- lua_createtable(L, transcript->num_tokens, 0);
- for (int j = 0; j < transcript->num_tokens; j++) {
- lua_pushnumber(L, transcript->tokens[j].start_time);
- lua_rawseti(L, -2, j + 1);
- }
- lua_setfield(L, -2, "times");
- lua_createtable(L, transcript->num_tokens, 0);
- for (int j = 0; j < transcript->num_tokens; j++) {
- lua_pushstring(L, transcript->tokens[j].text);
- lua_rawseti(L, -2, j + 1);
- }
- lua_setfield(L, -2, "tokens");
- lua_rawseti(L, -2, i + 1);
- }
- }
- static int lds_init(lua_State* L) {
- luaL_argcheck(L, lua_istable(L, 1), 1, "Expected config to be a table");
- if (state.modelState) {
- DS_FreeModel(state.modelState);
- state.modelState = NULL;
- }
- const char* model = NULL;
- const char* scorer = NULL;
- lua_getfield(L, 1, "model");
- CHECK(lua_type(L, -1) == LUA_TSTRING, "config.model should be a string containing a path to the pbmm file");
- model = lua_tostring(L, -1);
- lua_pop(L, 1);
- lua_getfield(L, 1, "scorer");
- int type = lua_type(L, -1);
- CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.scorer should be nil or a string");
- scorer = lua_tostring(L, -1);
- lua_pop(L, 1);
- int err = DS_CreateModel(model, &state.modelState);
- if (err) {
- lua_pushboolean(L, false);
- char* message = DS_ErrorCodeToErrorMessage(err);
- lua_pushstring(L, message);
- DS_FreeString(message);
- return 2;
- }
- lua_getfield(L, 1, "beamWidth");
- if (!lua_isnil(L, -1)) {
- DS_SetModelBeamWidth(state.modelState, luaL_checkinteger(L, -1));
- }
- lua_pop(L, 1);
- if (scorer) {
- CHECK(DS_EnableExternalScorer(state.modelState, scorer) == 0, "Failed to set scorer");
- lua_getfield(L, 1, "alpha");
- float alpha = lua_tonumber(L, -1);
- lua_pop(L, 1);
- lua_getfield(L, 1, "beta");
- float beta = lua_tonumber(L, -1);
- lua_pop(L, 1);
- if (alpha != 0.f || beta != 0.f) {
- CHECK(DS_SetScorerAlphaBeta(state.modelState, alpha, beta) == 0, "Failed to set scorer alpha/beta");
- }
- }
- lua_pushboolean(L, true);
- lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
- return 2;
- }
- static int lds_destroy(lua_State* L) {
- if (state.modelState) {
- DS_FreeModel(state.modelState);
- state.modelState = NULL;
- }
- state.bufferSize = 0;
- free(state.buffer);
- return 0;
- }
- static int lds_decode(lua_State* L) {
- size_t sampleCount;
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
- const short* samples = lds_checksamples(L, 1, &sampleCount);
- CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
- char* text = DS_SpeechToText(state.modelState, samples, sampleCount);
- lua_pushstring(L, text);
- DS_FreeString(text);
- return 1;
- }
- static int lds_analyze(lua_State* L) {
- size_t sampleCount;
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
- const short* samples = lds_checksamples(L, 1, &sampleCount);
- CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
- uint32_t limit = luaL_optinteger(L, lua_istable(L, 1) ? 2 : 3, 3);
- Metadata* metadata = DS_SpeechToTextWithMetadata(state.modelState, samples, sampleCount, limit);
- lds_pushmetadata(L, metadata);
- DS_FreeMetadata(metadata);
- return 1;
- }
- static int lds_boost(lua_State* L) {
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
- const char* word = luaL_checkstring(L, 1);
- float boost = luaL_checknumber(L, 2);
- DS_AddHotWord(state.modelState, word, boost);
- return 0;
- }
- static int lds_unboost(lua_State* L) {
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
- const char* word = lua_tostring(L, 1);
- if (word) {
- DS_EraseHotWord(state.modelState, word);
- } else {
- DS_ClearHotWords(state.modelState);
- }
- return 0;
- }
- static int lds_newStream(lua_State* L) {
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
- lds_Stream* stream = (lds_Stream*) lua_newuserdata(L, sizeof(lds_Stream));
- CHECK(DS_CreateStream(state.modelState, &stream->handle) == 0, "Could not create stream");
- luaL_getmetatable(L, "lds_Stream");
- lua_setmetatable(L, -2);
- return 1;
- }
- static int lds_stream_feed(lua_State* L) {
- size_t sampleCount;
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- const short* samples = lds_checksamples(L, 2, &sampleCount);
- CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
- DS_FeedAudioContent(stream->handle, samples, sampleCount);
- return 0;
- }
- static int lds_stream_decode(lua_State* L) {
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- char* text = DS_IntermediateDecode(stream->handle);
- lua_pushstring(L, text);
- DS_FreeString(text);
- return 1;
- }
- static int lds_stream_analyze(lua_State* L) {
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- uint32_t limit = luaL_optinteger(L, 2, 3);
- Metadata* metadata = DS_IntermediateDecodeWithMetadata(stream->handle, limit);
- lds_pushmetadata(L, metadata);
- DS_FreeMetadata(metadata);
- return 1;
- }
- static int lds_stream_finish(lua_State* L) {
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- char* text = DS_FinishStream(stream->handle);
- lua_pushstring(L, text);
- DS_FreeString(text);
- DS_CreateStream(state.modelState, &stream->handle);
- return 1;
- }
- static int lds_stream_clear(lua_State* L) {
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- DS_FreeStream(stream->handle);
- DS_CreateStream(state.modelState, &stream->handle);
- return 0;
- }
- static int lds_stream_destroy(lua_State* L) {
- lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
- DS_FreeStream(stream->handle);
- return 0;
- }
- static const luaL_Reg lds_api[] = {
- { "init", lds_init },
- { "decode", lds_decode },
- { "analyze", lds_analyze },
- { "boost", lds_boost },
- { "unboost", lds_unboost },
- { "newStream", lds_newStream },
- { NULL, NULL },
- };
- static const luaL_Reg lds_stream_api[] = {
- { "feed", lds_stream_feed },
- { "decode", lds_stream_decode },
- { "analyze", lds_stream_analyze },
- { "finish", lds_stream_finish },
- { "clear", lds_stream_clear },
- { "__gc", lds_stream_destroy },
- { NULL, NULL }
- };
- LDS_EXPORT int luaopen_deepspeech(lua_State* L) {
- lua_newtable(L);
- luaL_register(L, NULL, lds_api);
- // Add sentinel userdata to free the model state on GC
- lua_newuserdata(L, sizeof(void*));
- lua_createtable(L, 0, 1);
- lua_pushcfunction(L, lds_destroy);
- lua_setfield(L, -2, "__gc");
- lua_setmetatable(L, -2);
- lua_setfield(L, -2, "");
- if (luaL_newmetatable(L, "lds_Stream")) {
- lua_pushvalue(L, -1);
- lua_setfield(L, -2, "__index");
- luaL_register(L, NULL, lds_stream_api);
- lua_pop(L, 1);
- } else {
- return luaL_error(L, "Could not register lds_Stream metatable!");
- }
- return 1;
- }
|