|
|
@@ -24,15 +24,6 @@ typedef struct {
|
|
|
StreamingState* handle;
|
|
|
} lds_Stream;
|
|
|
|
|
|
-static const char* stringifyError(int e) {
|
|
|
-#define CASE(name, id, desc) case name: return desc;
|
|
|
- switch (e) {
|
|
|
- DS_FOR_EACH_ERROR(CASE)
|
|
|
- }
|
|
|
- return NULL;
|
|
|
-#undef CASE
|
|
|
-}
|
|
|
-
|
|
|
static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
|
|
|
if (lua_istable(L, index)) {
|
|
|
*count = lua_objlen(L, index);
|
|
|
@@ -63,6 +54,33 @@ static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
|
|
|
return NULL;
|
|
|
}
|
|
|
|
|
|
+static void lds_pushmetadata(lua_State* L, Metadata* metadata) {
|
|
|
+ lua_createtable(L, metadata->num_transcripts, 0);
|
|
|
+ for (int i = 0; i < metadata->num_transcripts; i++) {
|
|
|
+ const CandidateTranscript* transcript = &metadata->transcripts[i];
|
|
|
+ lua_createtable(L, 0, 3);
|
|
|
+
|
|
|
+ lua_pushnumber(L, transcript->confidence);
|
|
|
+ lua_setfield(L, -2, "confidence");
|
|
|
+
|
|
|
+ lua_createtable(L, transcript->num_tokens, 0);
|
|
|
+ for (int j = 0; j < transcript->num_tokens; j++) {
|
|
|
+ lua_pushnumber(L, transcript->tokens[j].start_time);
|
|
|
+ lua_rawseti(L, -2, j + 1);
|
|
|
+ }
|
|
|
+ lua_setfield(L, -2, "times");
|
|
|
+
|
|
|
+ lua_createtable(L, transcript->num_tokens, 0);
|
|
|
+ for (int j = 0; j < transcript->num_tokens; j++) {
|
|
|
+ lua_pushstring(L, transcript->tokens[j].text);
|
|
|
+ lua_rawseti(L, -2, j + 1);
|
|
|
+ }
|
|
|
+ lua_setfield(L, -2, "tokens");
|
|
|
+
|
|
|
+ lua_rawseti(L, -2, i + 1);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
static int lds_init(lua_State* L) {
|
|
|
luaL_argcheck(L, lua_istable(L, 1), 1, "Expected config to be a table");
|
|
|
|
|
|
@@ -72,32 +90,53 @@ static int lds_init(lua_State* L) {
|
|
|
}
|
|
|
|
|
|
const char* model = NULL;
|
|
|
- const char* grammar = NULL;
|
|
|
-
|
|
|
- int type;
|
|
|
+ const char* scorer = NULL;
|
|
|
|
|
|
lua_getfield(L, 1, "model");
|
|
|
CHECK(lua_type(L, -1) == LUA_TSTRING, "config.model should be a string containing a path to the pbmm file");
|
|
|
model = lua_tostring(L, -1);
|
|
|
lua_pop(L, 1);
|
|
|
|
|
|
- lua_getfield(L, 1, "grammar");
|
|
|
- type = lua_type(L, -1);
|
|
|
- CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.grammar should be nil or a string");
|
|
|
- grammar = lua_tostring(L, -1);
|
|
|
+ lua_getfield(L, 1, "scorer");
|
|
|
+ int type = lua_type(L, -1);
|
|
|
+ CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.scorer should be nil or a string");
|
|
|
+ scorer = lua_tostring(L, -1);
|
|
|
lua_pop(L, 1);
|
|
|
|
|
|
int err = DS_CreateModel(model, &state.modelState);
|
|
|
if (err) {
|
|
|
- return luaL_error(L, "Failed to initialize DeepSpeech: %s", stringifyError(err));
|
|
|
+ lua_pushboolean(L, false);
|
|
|
+ char* message = DS_ErrorCodeToErrorMessage(err);
|
|
|
+ lua_pushstring(L, message);
|
|
|
+ DS_FreeString(message);
|
|
|
+ return 2;
|
|
|
}
|
|
|
|
|
|
- if (grammar) {
|
|
|
- CHECK(DS_EnableExternalScorer(state.modelState, grammar) == 0, "Failed to set grammar");
|
|
|
+ lua_getfield(L, 1, "beamWidth");
|
|
|
+ if (!lua_isnil(L, -1)) {
|
|
|
+ DS_SetModelBeamWidth(state.modelState, luaL_checkinteger(L, -1));
|
|
|
+ }
|
|
|
+ lua_pop(L, 1);
|
|
|
+
|
|
|
+ if (scorer) {
|
|
|
+ CHECK(DS_EnableExternalScorer(state.modelState, scorer) == 0, "Failed to set scorer");
|
|
|
+
|
|
|
+ lua_getfield(L, 1, "alpha");
|
|
|
+ float alpha = lua_tonumber(L, -1);
|
|
|
+ lua_pop(L, 1);
|
|
|
+
|
|
|
+ lua_getfield(L, 1, "beta");
|
|
|
+ float beta = lua_tonumber(L, -1);
|
|
|
+ lua_pop(L, 1);
|
|
|
+
|
|
|
+ if (alpha != 0.f || beta != 0.f) {
|
|
|
+ CHECK(DS_SetScorerAlphaBeta(state.modelState, alpha, beta) == 0, "Failed to set scorer alpha/beta");
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
lua_pushboolean(L, true);
|
|
|
- return 1;
|
|
|
+ lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
|
|
|
+ return 2;
|
|
|
}
|
|
|
|
|
|
static int lds_destroy(lua_State* L) {
|
|
|
@@ -105,15 +144,11 @@ static int lds_destroy(lua_State* L) {
|
|
|
DS_FreeModel(state.modelState);
|
|
|
state.modelState = NULL;
|
|
|
}
|
|
|
+ state.bufferSize = 0;
|
|
|
+ free(state.buffer);
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
-static int lds_getSampleRate(lua_State* L) {
|
|
|
- CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
- lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
|
|
|
- return 1;
|
|
|
-}
|
|
|
-
|
|
|
static int lds_decode(lua_State* L) {
|
|
|
size_t sampleCount;
|
|
|
CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
@@ -125,6 +160,37 @@ static int lds_decode(lua_State* L) {
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+static int lds_analyze(lua_State* L) {
|
|
|
+ size_t sampleCount;
|
|
|
+ CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
+ const short* samples = lds_checksamples(L, 1, &sampleCount);
|
|
|
+ CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
|
|
|
+ uint32_t limit = luaL_optinteger(L, lua_istable(L, 1) ? 2 : 3, 3);
|
|
|
+ Metadata* metadata = DS_SpeechToTextWithMetadata(state.modelState, samples, sampleCount, limit);
|
|
|
+ lds_pushmetadata(L, metadata);
|
|
|
+ DS_FreeMetadata(metadata);
|
|
|
+ return 1;
|
|
|
+}
|
|
|
+
|
|
|
+static int lds_boost(lua_State* L) {
|
|
|
+ CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
+ const char* word = luaL_checkstring(L, 1);
|
|
|
+ float boost = luaL_checknumber(L, 2);
|
|
|
+ DS_AddHotWord(state.modelState, word, boost);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+static int lds_unboost(lua_State* L) {
|
|
|
+ CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
+ const char* word = lua_tostring(L, 1);
|
|
|
+ if (word) {
|
|
|
+ DS_EraseHotWord(state.modelState, word);
|
|
|
+ } else {
|
|
|
+ DS_ClearHotWords(state.modelState);
|
|
|
+ }
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
static int lds_newStream(lua_State* L) {
|
|
|
CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
|
|
|
lds_Stream* stream = (lds_Stream*) lua_newuserdata(L, sizeof(lds_Stream));
|
|
|
@@ -151,6 +217,15 @@ static int lds_stream_decode(lua_State* L) {
|
|
|
return 1;
|
|
|
}
|
|
|
|
|
|
+static int lds_stream_analyze(lua_State* L) {
|
|
|
+ lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
|
|
|
+ uint32_t limit = luaL_optinteger(L, 2, 3);
|
|
|
+ Metadata* metadata = DS_IntermediateDecodeWithMetadata(stream->handle, limit);
|
|
|
+ lds_pushmetadata(L, metadata);
|
|
|
+ DS_FreeMetadata(metadata);
|
|
|
+ return 1;
|
|
|
+}
|
|
|
+
|
|
|
static int lds_stream_finish(lua_State* L) {
|
|
|
lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
|
|
|
char* text = DS_FinishStream(stream->handle);
|
|
|
@@ -175,8 +250,10 @@ static int lds_stream_destroy(lua_State* L) {
|
|
|
|
|
|
static const luaL_Reg lds_api[] = {
|
|
|
{ "init", lds_init },
|
|
|
- { "getSampleRate", lds_getSampleRate },
|
|
|
{ "decode", lds_decode },
|
|
|
+ { "analyze", lds_analyze },
|
|
|
+ { "boost", lds_boost },
|
|
|
+ { "unboost", lds_unboost },
|
|
|
{ "newStream", lds_newStream },
|
|
|
{ NULL, NULL },
|
|
|
};
|
|
|
@@ -184,6 +261,7 @@ static const luaL_Reg lds_api[] = {
|
|
|
static const luaL_Reg lds_stream_api[] = {
|
|
|
{ "feed", lds_stream_feed },
|
|
|
{ "decode", lds_stream_decode },
|
|
|
+ { "analyze", lds_stream_analyze },
|
|
|
{ "finish", lds_stream_finish },
|
|
|
{ "clear", lds_stream_clear },
|
|
|
{ "__gc", lds_stream_destroy },
|