Browse Source

analyze; boost; unboost; beam width; alpha/beta;

bjorn 5 years ago
parent
commit
6f01089d42
2 changed files with 171 additions and 39 deletions
  1. 66 12
      README.md
  2. 105 27
      lua_deepspeech.c

+ 66 - 12
README.md

@@ -34,6 +34,14 @@ DeepSpeech Setup
 `gpu` flavor runs on the GPU with CUDA, and the `tflite` flavor can use the smaller tflite model
 instead of the pbmm one.
 
+### Scorer
+
+You can also optionally create a thing called a "scorer package".  The scorer acts as the grammar
+or vocabulary for the recognition, allowing it to recognize a custom set of words or phrases.  This
+can improve accuracy and speed by a lot, and is useful if you only have a few words or commands that
+need to be detected.  See [here](https://deepspeech.readthedocs.io/en/v0.9.3/Scorer.html) for
+instructions on generating a scorer.
+
 Building
 ---
 
@@ -68,34 +76,62 @@ local speech = require 'lua-deepspeech'
 It returns a table with the library's functionality.
 
 ```lua
-success = speech.init(options)
+success, sampleRate = speech.init(options)
 ```
 
 The library must be initialized with an options table.  The table can contain the following options:
 
 - `options.model` should be a full path to the deepspeech model file (pbmm).  If this file is stored
   in a zip archive fused to the executable it will need to be written to disk first.
-- `options.grammar` TODO
+- `options.scorer` is an optional path to the scorer package.
+- `options.beamWidth` is an optional beam width number.  A higher beam width increases accuracy at
+  the cost of performance.
+- `options.alpha` and `options.beta` are optional parameters for the scorer.  Usually the defaults
+  are fine.
+
+The function either returns false plus an error message or true and the audio sample rate that the
+model was trained against.  All audio must be provided as **signed 16 bit mono** samples at this
+sample rate.  It's almost always 16000Hz.
 
 ```lua
-sampleRate = speech.getSampleRate()
+text = speech.decode(table)
+text = speech.decode(pointer, count)
 ```
 
-Returns the sample rate the model was trained on, in Hz.  This is usually 16000Hz.  Audio
-information passed to the library should use this sample rate.
+This function performs speech-to-text.  A table of audio samples can be provided, or a lightuserdata
+pointer with a sample count.
+
+In all cases the audio data must be formatted as **signed 16 bit mono** samples at the model's
+sample rate.
+
+Returns a string with the decoded text.
 
 ```lua
-text = speech.decode(table)
-text = speech.decode(pointer, count)
+transcripts = speech.analyze(table, limit)
+transcripts = speech.analyze(pointer, count, limit)
 ```
 
-This functions performs speech-to-text.  A table of audio samples can be provided, or a
-lightuserdata pointer with a sample count.
+This is the same as `decode`, but returns extra metadata about the result.  The return value is a
+list of transcripts.  Each transcript is a table with:
+
+- `confidence` is the confidence level.  May be negative.  Transcripts are sorted by confidence.
+- `tokens` a list of tokens (i.e. letters) that were decoded.
+- `times` a list of timestamps for each token, in seconds.
 
-In all cases the audio data must be formatted as **signed 16 bit mono** samples at the appropriate
-sample rate (usually 16,000Hz, use `speech.getSampleRate` to check).
+`limit` can optionally be used to limit the number of transcripts returned, defaulting to 3.
+
+```lua
+speech.boost(word, amount)
+```
 
-A string is returned with the decoded text.
+Boosts a word.
+
+```lua
+speech.unboost(word)
+speech.unboost()
+```
+
+Unboosts a word, or unboosts all words if no arguments are provided.
 
 ### Streams
 
@@ -122,6 +158,13 @@ text = Stream:decode()
 Performs an intermediate decode on the audio data fed to the Stream, returning the decoded text.
 Additional audio can continue to be fed to the Stream after this function is called.
 
+```lua
+transcripts = Stream:analyze()
+```
+
+Performs an intermediate analysis on the audio data fed to the Stream.  See `speech.analyze`.
+Additional audio can continue to be fed to the Stream after this function is called.
+
 ```lua
 text = Stream:finish()
 ```
@@ -134,6 +177,17 @@ Stream:clear()
 
 Resets the Stream, erasing all audio that has been fed to it.
 
+Tips
+---
+
+- Although DeepSpeech performs at realtime speeds, it's still a good idea to offload the decoding
+  to a separate thread, especially when rendering realtime graphics alongside speech recognition.
+- If you are getting garbage results, ensure you're using the correct sample rate and audio format.
+  DeepSpeech is also somewhat sensitive to background noise and low volume levels.  To improve
+  accuracy further, consider using a custom scorer.
+- When feeding audio to a stream, varying the size of the chunks of audio you feed can be used to
+  trade off latency for performance.
+
 License
 ---
 

+ 105 - 27
lua_deepspeech.c

@@ -24,15 +24,6 @@ typedef struct {
   StreamingState* handle;
 } lds_Stream;
 
-static const char* stringifyError(int e) {
-#define CASE(name, id, desc) case name: return desc;
-  switch (e) {
-    DS_FOR_EACH_ERROR(CASE)
-  }
-  return NULL;
-#undef CASE
-}
-
 static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
   if (lua_istable(L, index)) {
     *count = lua_objlen(L, index);
@@ -63,6 +54,33 @@ static const short* lds_checksamples(lua_State* L, int index, size_t* count) {
   return NULL;
 }
 
+static void lds_pushmetadata(lua_State* L, Metadata* metadata) {
+  lua_createtable(L, metadata->num_transcripts, 0);
+  for (int i = 0; i < metadata->num_transcripts; i++) {
+    const CandidateTranscript* transcript = &metadata->transcripts[i];
+    lua_createtable(L, 0, 3);
+
+    lua_pushnumber(L, transcript->confidence);
+    lua_setfield(L, -2, "confidence");
+
+    lua_createtable(L, transcript->num_tokens, 0);
+    for (int j = 0; j < transcript->num_tokens; j++) {
+      lua_pushnumber(L, transcript->tokens[j].start_time);
+      lua_rawseti(L, -2, j + 1);
+    }
+    lua_setfield(L, -2, "times");
+
+    lua_createtable(L, transcript->num_tokens, 0);
+    for (int j = 0; j < transcript->num_tokens; j++) {
+      lua_pushstring(L, transcript->tokens[j].text);
+      lua_rawseti(L, -2, j + 1);
+    }
+    lua_setfield(L, -2, "tokens");
+
+    lua_rawseti(L, -2, i + 1);
+  }
+}
+
 static int lds_init(lua_State* L) {
   luaL_argcheck(L, lua_istable(L, 1), 1, "Expected config to be a table");
 
@@ -72,32 +90,53 @@ static int lds_init(lua_State* L) {
   }
 
   const char* model = NULL;
-  const char* grammar = NULL;
-
-  int type;
+  const char* scorer = NULL;
 
   lua_getfield(L, 1, "model");
   CHECK(lua_type(L, -1) == LUA_TSTRING, "config.model should be a string containing a path to the pbmm file");
   model = lua_tostring(L, -1);
   lua_pop(L, 1);
 
-  lua_getfield(L, 1, "grammar");
-  type = lua_type(L, -1);
-  CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.grammar should be nil or a string");
-  grammar = lua_tostring(L, -1);
+  lua_getfield(L, 1, "scorer");
+  int type = lua_type(L, -1);
+  CHECK(type == LUA_TNIL || type == LUA_TSTRING, "config.scorer should be nil or a string");
+  scorer = lua_tostring(L, -1);
   lua_pop(L, 1);
 
   int err = DS_CreateModel(model, &state.modelState);
   if (err) {
-    return luaL_error(L, "Failed to initialize DeepSpeech: %s", stringifyError(err));
+    lua_pushboolean(L, false);
+    char* message = DS_ErrorCodeToErrorMessage(err);
+    lua_pushstring(L, message);
+    DS_FreeString(message);
+    return 2;
   }
 
-  if (grammar) {
-    CHECK(DS_EnableExternalScorer(state.modelState, grammar) == 0, "Failed to set grammar");
+  lua_getfield(L, 1, "beamWidth");
+  if (!lua_isnil(L, -1)) {
+    DS_SetModelBeamWidth(state.modelState, luaL_checkinteger(L, -1));
+  }
+  lua_pop(L, 1);
+
+  if (scorer) {
+    CHECK(DS_EnableExternalScorer(state.modelState, scorer) == 0, "Failed to set scorer");
+
+    lua_getfield(L, 1, "alpha");
+    float alpha = lua_tonumber(L, -1);
+    lua_pop(L, 1);
+
+    lua_getfield(L, 1, "beta");
+    float beta = lua_tonumber(L, -1);
+    lua_pop(L, 1);
+
+    if (alpha != 0.f || beta != 0.f) {
+      CHECK(DS_SetScorerAlphaBeta(state.modelState, alpha, beta) == 0, "Failed to set scorer alpha/beta");
+    }
   }
 
   lua_pushboolean(L, true);
-  return 1;
+  lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
+  return 2;
 }
 
 static int lds_destroy(lua_State* L) {
@@ -105,15 +144,11 @@ static int lds_destroy(lua_State* L) {
     DS_FreeModel(state.modelState);
     state.modelState = NULL;
   }
+  state.bufferSize = 0;
+  free(state.buffer);
   return 0;
 }
 
-static int lds_getSampleRate(lua_State* L) {
-  CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
-  lua_pushinteger(L, DS_GetModelSampleRate(state.modelState));
-  return 1;
-}
-
 static int lds_decode(lua_State* L) {
   size_t sampleCount;
   CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
@@ -125,6 +160,37 @@ static int lds_decode(lua_State* L) {
   return 1;
 }
 
+static int lds_analyze(lua_State* L) {
+  size_t sampleCount;
+  CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
+  const short* samples = lds_checksamples(L, 1, &sampleCount);
+  CHECK(samples != NULL, "Expected a table or lightuserdata pointer for audio sample data");
+  uint32_t limit = luaL_optinteger(L, lua_istable(L, 1) ? 2 : 3, 3);
+  Metadata* metadata = DS_SpeechToTextWithMetadata(state.modelState, samples, sampleCount, limit);
+  lds_pushmetadata(L, metadata);
+  DS_FreeMetadata(metadata);
+  return 1;
+}
+
+static int lds_boost(lua_State* L) {
+  CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
+  const char* word = luaL_checkstring(L, 1);
+  float boost = luaL_checknumber(L, 2);
+  DS_AddHotWord(state.modelState, word, boost);
+  return 0;
+}
+
+static int lds_unboost(lua_State* L) {
+  CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
+  const char* word = lua_tostring(L, 1);
+  if (word) {
+    DS_EraseHotWord(state.modelState, word);
+  } else {
+    DS_ClearHotWords(state.modelState);
+  }
+  return 0;
+}
+
 static int lds_newStream(lua_State* L) {
   CHECK(state.modelState != NULL, "DeepSpeech is not initialized");
   lds_Stream* stream = (lds_Stream*) lua_newuserdata(L, sizeof(lds_Stream));
@@ -151,6 +217,15 @@ static int lds_stream_decode(lua_State* L) {
   return 1;
 }
 
+static int lds_stream_analyze(lua_State* L) {
+  lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
+  uint32_t limit = luaL_optinteger(L, 2, 3);
+  Metadata* metadata = DS_IntermediateDecodeWithMetadata(stream->handle, limit);
+  lds_pushmetadata(L, metadata);
+  DS_FreeMetadata(metadata);
+  return 1;
+}
+
 static int lds_stream_finish(lua_State* L) {
   lds_Stream* stream = (lds_Stream*) luaL_checkudata(L, 1, "lds_Stream");
   char* text = DS_FinishStream(stream->handle);
@@ -175,8 +250,10 @@ static int lds_stream_destroy(lua_State* L) {
 
 static const luaL_Reg lds_api[] = {
   { "init", lds_init },
-  { "getSampleRate", lds_getSampleRate },
   { "decode", lds_decode },
+  { "analyze", lds_analyze },
+  { "boost", lds_boost },
+  { "unboost", lds_unboost },
   { "newStream", lds_newStream },
   { NULL, NULL },
 };
@@ -184,6 +261,7 @@ static const luaL_Reg lds_api[] = {
 static const luaL_Reg lds_stream_api[] = {
   { "feed", lds_stream_feed },
   { "decode", lds_stream_decode },
+  { "analyze", lds_stream_analyze },
   { "finish", lds_stream_finish },
   { "clear", lds_stream_clear },
   { "__gc", lds_stream_destroy },