lua_deepspeech.c 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #include <lua.h>
  2. #include <lualib.h>
  3. #include <lauxlib.h>
  4. #include <deepspeech.h>
  5. #include <stdbool.h>
  6. #include <stdint.h>
  7. #include <stdlib.h>
  8. #define CHECK(c, ...) if (!(c)) { return luaL_error(L, __VA_ARGS__); }
  9. #ifdef _WIN32
  10. #define LDS_EXPORT __declspec(dllexport)
  11. #else
  12. #define LDS_EXPORT
  13. #endif
  14. static struct {
  15. ModelState* modelState;
  16. size_t bufferSize;
  17. short* buffer;
  18. } state;
  19. typedef struct {
  20. StreamingState* handle;
  21. } lds_Stream;
  22. static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
  23. if (lua_istable(L, index)) {
  24. *count = lua_objlen(L, index);
  25. if (state.bufferSize < *count) {
  26. state.bufferSize += !state.bufferSize;
  27. do { state.bufferSize <<= 1; } while (state.bufferSize < *count);
  28. state.buffer = realloc(state.buffer, state.bufferSize);
  29. }
  30. for (size_t i = 0; i < *count; i++) {
  31. lua_rawgeti(L, index, i + 1);
  32. lua_Integer x = lua_tointeger(L, -1);
  33. lua_pop(L, 1);
  34. if (x < INT16_MIN || x > INT16_MAX) {
  35. luaL_error(L, "Sample #%d (%d) is out of range [%d,%d]", i + 1, x, INT16_MIN, INT16_MAX);
  36. }
  37. state.buffer[i] = x;
  38. }
  39. return state.buffer;
  40. } else if (lua_type(L, index) == LUA_TLIGHTUSERDATA) {
  41. return *count = luaL_checkinteger(L, index + 1), lua_touserdata(L, index);
  42. }
  43. return NULL;
  44. }
  45. static void lds_pushmetadata(lua_State* L, Metadata* metadata) {
  46. lua_createtable(L, metadata->num_transcripts, 0);
  47. for (int i = 0; i < metadata->num_transcripts; i++) {
  48. const CandidateTranscript* transcript = &metadata->transcripts[i];
  49. lua_createtable(L, 0, 3);
  50. lua_pushnumber(L, transcript->confidence);
  51. lua_setfield(L, -2, "confidence");
  52. lua_createtable(L, transcript->num_tokens, 0);
  53. for (int j = 0; j < transcript->num_tokens; j++) {
  54. lua_pushnumber(L, transcript->tokens[j].start_time);
  55. lua_rawseti(L, -2, j + 1);
  56. }
  57. lua_setfield(L, -2, "times");
  58. lua_createtable(L, transcript->num_tokens, 0);
  59. for (int j = 0; j < transcript->num_tokens; j++) {
  60. lua_pushstring(L, transcript->tokens[j].text);
  61. lua_rawseti(L, -2, j + 1);
  62. }
  63. lua_setfield(L, -2, "tokens");
  64. lua_rawseti(L, -2, i + 1);
  65. }
  66. }
  67. static int lds_init(lua_State* L) {
  68. luaL_argcheck(L, lua_istable(L, 1), 1, "Expected config to be a table");
  69. if (state.modelState) {
  70. DS_FreeModel(state.modelState);
  71. state.modelState = NULL;
  72. }
  73. const char* model = NULL;
  74. const char* scorer = NULL;
  75. lua_getfield(L, 1, "model");
  76. CHECK(lua_type(L, -1) == LUA_TSTRING, "config.model should be a string containing a path to the pbmm file");
  77. model = lua_tostring(L, -1);
  78. lua_pop(L, 1);
  79. lua_getfield(L, 1, "scorer");
  80. int type = lua_type(L, -1);
  81. CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.scorer should be nil or a string");
  82. scorer = lua_tostring(L, -1);
  83. lua_pop(L, 1);
  84. int err = DS_CreateModel(model, &state.modelState);
  85. if (err) {
  86. lua_pushboolean(L, false);
  87. char* message = DS_ErrorCodeToErrorMessage(err);
  88. lua_pushstring(L, message);
  89. DS_FreeString(message);
  90. return 2;
  91. }
  92. lua_getfield(L, 1, "beamWidth");
  93. if (!lua_isnil(L, -1)) {
  94. DS_SetModelBeamWidth(state.modelState, luaL_checkinteger(L, -1));
  95. }
  96. lua_pop(L, 1);
  97. if (scorer) {
  98. CHECK(DS_EnableExternalScorer(state.modelState, scorer) == 0, "Failed to set scorer");
  99. lua_getfield(L, 1, "alpha");
  100. float alpha = lua_tonumber(L, -1);
  101. lua_pop(L, 1);
  102. lua_getfield(L, 1, "beta");
  103. float beta = lua_tonumber(L, -1);
  104. lua_pop(L, 1);
  105. if (alpha != 0.f || beta != 0.f) {
  106. CHECK(DS_SetScorerAlphaBeta(state.modelState, alpha, beta) == 0, "Failed to set scorer alpha/beta");
  107. }
  108. }
  109. lua_pushboolean(L, true);
  110. lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
  111. return 2;
  112. }
  113. static int lds_destroy(lua_State* L) {
  114. if (state.modelState) {
  115. DS_FreeModel(state.modelState);
  116. state.modelState = NULL;
  117. }
  118. state.bufferSize = 0;
  119. free(state.buffer);
  120. return 0;
  121. }
  122. static int lds_decode(lua_State* L) {
  123. size_t sampleCount;
  124. CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
  125. const short* samples = lds_checksamples(L, 1, &sampleCount);
  126. CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
  127. char* text = DS_SpeechToText(state.modelState, samples, sampleCount);
  128. lua_pushstring(L, text);
  129. DS_FreeString(text);
  130. return 1;
  131. }
  132. static int lds_analyze(lua_State* L) {
  133. size_t sampleCount;
  134. CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
  135. const short* samples = lds_checksamples(L, 1, &sampleCount);
  136. CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
  137. uint32_t limit = luaL_optinteger(L, lua_istable(L, 1) ? 2 : 3, 3);
  138. Metadata* metadata = DS_SpeechToTextWithMetadata(state.modelState, samples, sampleCount, limit);
  139. lds_pushmetadata(L, metadata);
  140. DS_FreeMetadata(metadata);
  141. return 1;
  142. }
  143. static int lds_boost(lua_State* L) {
  144. CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
  145. const char* word = luaL_checkstring(L, 1);
  146. float boost = luaL_checknumber(L, 2);
  147. DS_AddHotWord(state.modelState, word, boost);
  148. return 0;
  149. }
  150. static int lds_unboost(lua_State* L) {
  151. CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
  152. const char* word = lua_tostring(L, 1);
  153. if (word) {
  154. DS_EraseHotWord(state.modelState, word);
  155. } else {
  156. DS_ClearHotWords(state.modelState);
  157. }
  158. return 0;
  159. }
  160. static int lds_newStream(lua_State* L) {
  161. CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
  162. lds_Stream* stream = (lds_Stream*) lua_newuserdata(L, sizeof(lds_Stream));
  163. CHECK(DS_CreateStream(state.modelState, &stream->handle) == 0, "Could not create stream");
  164. luaL_getmetatable(L, "lds_Stream");
  165. lua_setmetatable(L, -2);
  166. return 1;
  167. }
  168. static int lds_stream_feed(lua_State* L) {
  169. size_t sampleCount;
  170. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  171. const short* samples = lds_checksamples(L, 2, &sampleCount);
  172. CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
  173. DS_FeedAudioContent(stream->handle, samples, sampleCount);
  174. return 0;
  175. }
  176. static int lds_stream_decode(lua_State* L) {
  177. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  178. char* text = DS_IntermediateDecode(stream->handle);
  179. lua_pushstring(L, text);
  180. DS_FreeString(text);
  181. return 1;
  182. }
  183. static int lds_stream_analyze(lua_State* L) {
  184. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  185. uint32_t limit = luaL_optinteger(L, 2, 3);
  186. Metadata* metadata = DS_IntermediateDecodeWithMetadata(stream->handle, limit);
  187. lds_pushmetadata(L, metadata);
  188. DS_FreeMetadata(metadata);
  189. return 1;
  190. }
  191. static int lds_stream_finish(lua_State* L) {
  192. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  193. char* text = DS_FinishStream(stream->handle);
  194. lua_pushstring(L, text);
  195. DS_FreeString(text);
  196. DS_CreateStream(state.modelState, &stream->handle);
  197. return 1;
  198. }
  199. static int lds_stream_clear(lua_State* L) {
  200. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  201. DS_FreeStream(stream->handle);
  202. DS_CreateStream(state.modelState, &stream->handle);
  203. return 0;
  204. }
  205. static int lds_stream_destroy(lua_State* L) {
  206. lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
  207. DS_FreeStream(stream->handle);
  208. return 0;
  209. }
  210. static const luaL_Reg lds_api[] = {
  211. { "init", lds_init },
  212. { "decode", lds_decode },
  213. { "analyze", lds_analyze },
  214. { "boost", lds_boost },
  215. { "unboost", lds_unboost },
  216. { "newStream", lds_newStream },
  217. { NULL, NULL },
  218. };
  219. static const luaL_Reg lds_stream_api[] = {
  220. { "feed", lds_stream_feed },
  221. { "decode", lds_stream_decode },
  222. { "analyze", lds_stream_analyze },
  223. { "finish", lds_stream_finish },
  224. { "clear", lds_stream_clear },
  225. { "__gc", lds_stream_destroy },
  226. { NULL, NULL }
  227. };
  228. LDS_EXPORT int luaopen_deepspeech(lua_State* L) {
  229. lua_newtable(L);
  230. luaL_register(L, NULL, lds_api);
  231. // Add sentinel userdata to free the model state on GC
  232. lua_newuserdata(L, sizeof(void*));
  233. lua_createtable(L, 0, 1);
  234. lua_pushcfunction(L, lds_destroy);
  235. lua_setfield(L, -2, "__gc");
  236. lua_setmetatable(L, -2);
  237. lua_setfield(L, -2, "");
  238. if (luaL_newmetatable(L, "lds_Stream")) {
  239. lua_pushvalue(L, -1);
  240. lua_setfield(L, -2, "__index");
  241. luaL_register(L, NULL, lds_stream_api);
  242. lua_pop(L, 1);
  243. } else {
  244. return luaL_error(L, "Could not register lds_Stream metatable!");
  245. }
  246. return 1;
  247. }