sq_nn.cpp

#include <stdio.h>
#include "squirrel.h"
#include <string.h>
#include <inttypes.h>
#include <math.h>
#include <stdlib.h>
#include <sys/time.h>
//#include <pthread.h>

SQ_OPT_STRING_STRLEN();

extern "C" {
#include "nn.h"
void *ann_malloc(size_t sz)
{
    return sq_malloc(sz);
}
void ann_free(void *p)
{
    sq_free(p, 0);
}
}

#define NR_FLAG_NONE 0
#define NR_FLAG_TRAINING (1<<0)    /* NN is training in a thread. */
#define NR_FLAG_REGRESSOR (1<<1)   /* NN will be used for regression. */
#define NR_FLAG_CLASSIFIER (1<<2)  /* NN will be used for classification. */
#define NR_FLAG_NORMALIZE (1<<3)   /* Perform input/output normalization. */
#define NR_FLAG_AUTO_STOP (1<<4)   /* Auto stop on training. */
#define NR_FLAG_OF_DETECTED (1<<5) /* Auto stopped on overfitting. */
#define NR_FLAG_BACKTRACK (1<<6)   /* Auto stop with backtracking. */

/* Flags to persist when saving the NN. */
#define NR_FLAG_TO_PERSIST (NR_FLAG_REGRESSOR| \
                            NR_FLAG_CLASSIFIER| \
                            NR_FLAG_NORMALIZE| \
                            NR_FLAG_OF_DETECTED)

/* Flags to transfer after training. */
#define NR_FLAG_TO_TRANSFER (NR_FLAG_OF_DETECTED)

#define NR_MAX_LAYERS 32
#define NR_RDB_ENC_VER 2
typedef struct {
    uint32_t len, maxlen;
    float *inputs, *outputs;
} NRDataset;

typedef struct {
    uint64_t id;                   /* Neural network unique ID. */
    uint64_t training_total_steps; /* How many steps of training the network
                                      received. A step is a single input/output
                                      pattern presented to the net (counting
                                      the same pattern multiple times). */
    uint64_t training_total_ms;    /* Total milliseconds time of training. */
    uint64_t training_max_cycles;  /* Max cycles of a single training. */
    uint64_t training_max_ms;      /* Max time of a single training. */
    uint32_t flags;                /* NR_FLAG_... */
    uint32_t epochs;               /* Number of training epochs so far. */
    AnnRprop *nn;                  /* Neural network structure. */
    NRDataset dataset;             /* Training dataset. */
    NRDataset test;                /* Testing dataset. */
    float dataset_error;           /* Average error in the training dataset. */
    float test_error;              /* Average error in the test dataset. */
    float test_class_error;        /* Percentage of wrong classifications in the
                                      test dataset. Only applicable to nets flagged
                                      with NR_FLAG_CLASSIFIER. */
    /* For normalized (NR_FLAG_NORMALIZE) networks. */
    float *inorm;                  /* Inputs normalization factors. */
    float *onorm;                  /* Outputs normalization factors. */
} NRTypeObject;
#if 0
typedef struct {
    //RedisModuleString *key; /* Key name of the NN we are training.
    //                           Set to NULL for unused slots. */
    int db_id;              /* DB ID where the key is. */
    pthread_t tid;          /* Thread ID of the trainer. */
    int in_progress;        /* 0 if training terminated. */
    NRTypeObject *nr;       /* A copy of the NN we are training. */
    float dataset_error;    /* Dataset error in the last cycle. */
    float test_error;       /* Test error in the last cycle. */
    float class_error;      /* Percentage of wrong classifications. */
    int curcycle;           /* Current cycle. */
} NRPendingTraining;
#endif

/* We take an array with NNs currently training in other threads.
 * Every time an NN command is called, we try to see if there are
 * finished trainings, in order to update the weights of the original
 * NN stored into the key (we work on a copy in the other thread). */
#define NR_PENDING_TRAINING_MAX_LEN 32
#if 0
#define REDISMODULE_ERR -1
#define REDISMODULE_OK 0
#define RedisModuleCtx void
#define RedisModule_Log(ctx, log_level, msg)
#define UNUSED(V) ((void) V)
typedef SQString RedisModuleString;

static uint64_t NRNextId = 1; /* Next neural network unique ID. */

#define RedisModule_Alloc(x) sq_malloc(x)

static void *RedisModule_Calloc(size_t nelm, size_t sz)
{
    size_t malloc_size = nelm * sz;
    void *ptr = sq_malloc(malloc_size);
    if(ptr) memset(ptr, 0, malloc_size);
    return ptr;
}

static void *RedisModule_Realloc(void *oldPtr, size_t sz)
{
    void *ptr = sq_realloc(oldPtr, 0, sz);
    return ptr;
}

#define RedisModule_Free(x) sq_free(x, 0)

static pthread_mutex_t NRPendingTrainingMutex = PTHREAD_MUTEX_INITIALIZER;
/* All the following must be accessed after acquiring the mutex. */
static NRPendingTraining NRTrainings[NR_PENDING_TRAINING_MAX_LEN];
static int NRPendingTrainingCount = 0; /* Number of pending trainings. */

/* ========================== Low level object API ========================== */

long long NRMilliseconds(void) {
    struct timeval tv;
    long long ust;

    gettimeofday(&tv, NULL);
    ust = ((long long)tv.tv_sec)*1000000;
    ust += tv.tv_usec;
    return ust/1000;
}
/* Create a network with the specified parameters. Note that the layers
 * must be specified from the output layer[0] to the input
 * layer[N]. Each element in the integer array 'layers' specifies how many
 * units there are in the corresponding layer. */
static NRTypeObject *createNRTypeObject(int flags, int *layers, int numlayers, int dset_len, int test_len) {
    NRTypeObject *o;
    o = (NRTypeObject*)RedisModule_Calloc(1,sizeof(*o));
    o->id = NRNextId++;
    o->flags = flags;
    o->nn = AnnCreateNet(numlayers,layers);
    o->dataset.maxlen = dset_len;
    o->test.maxlen = test_len;
    int ilen = ANN_INPUT_UNITS(o->nn);
    int olen = ANN_OUTPUT_UNITS(o->nn);
    o->inorm = (float*)RedisModule_Calloc(1,sizeof(float)*ilen);
    o->onorm = (float*)RedisModule_Calloc(1,sizeof(float)*olen);
    for (int j = 0; j < ilen; j++) o->inorm[j] = 1;
    for (int j = 0; j < olen; j++) o->onorm[j] = 1;
    return o;
}
/* Insert data (observations needed to train and test the NN) into the
 * NN object. While the learning and testing datasets are not yet full,
 * the observed pattern is inserted into one or the other in order to
 * make sure the two datasets are populated evenly. When both are
 * already full, a random element from one or the other (doing a
 * random weighted choice depending on the length) is substituted with
 * the new item. */
#define NR_INSERT_NO_TARGET 0 /* Auto select where to insert. */
#define NR_INSERT_TRAIN 1     /* Insert in training dataset. */
#define NR_INSERT_TEST 2      /* Insert in testing dataset. */
static void NRTypeInsertData(NRTypeObject *o, float *inputs, float *outputs,
                             int target_ds) {
    NRDataset *target = NULL;

    /* Check if there is no dataset at all. This may be a valid setup
     * with online learning, sample by sample. */
    if (o->dataset.maxlen == 0 && o->test.maxlen == 0) return;

    /* If the user specified a target, select it. */
    if (target_ds == NR_INSERT_TRAIN) target = &o->dataset;
    else if (target_ds == NR_INSERT_TEST) target = &o->test;

    /* If no target is specified, but there is only one possible
     * target, select it ASAP. */
    if (o->dataset.maxlen == 0) {
        target = &o->test;
    } else if (o->test.maxlen == 0) {
        target = &o->dataset;
    }

    /* Otherwise choose as the target to populate the one with less data
     * relative to its size. */
    if (target == NULL) {
        /* If one of the two datasets is still not full, pick
         * based on fill percentage. Otherwise pick a random
         * target proportionally to their sizes. */
        if (o->dataset.len != o->dataset.maxlen ||
            o->test.len != o->test.maxlen)
        {
            float fill_a = (float)o->dataset.len / o->dataset.maxlen;
            float fill_b = (float)o->test.len / o->test.maxlen;
            target = (fill_a <= fill_b) ? &o->dataset : &o->test;
        } else {
            double r = (double)rand()/RAND_MAX;
            double sumlen = o->dataset.maxlen + o->test.maxlen;
            if (r < (double)o->dataset.maxlen/sumlen) {
                target = &o->dataset;
            } else {
                target = &o->test;
            }
        }
    }

    /* Append if there is room or substitute with a random entry. */
    size_t idx;
    int j, numin = ANN_INPUT_UNITS(o->nn),
           numout = ANN_OUTPUT_UNITS(o->nn);
    if (target->maxlen == target->len) {
        idx = rand() % target->maxlen;
    } else {
        idx = target->len;
        target->len++;
        target->inputs = (float*)RedisModule_Realloc(target->inputs,
            sizeof(float)*numin*target->len);
        target->outputs = (float*)RedisModule_Realloc(target->outputs,
            sizeof(float)*numout*target->len);
    }

    /* Finally store the values at the selected position. */
    for (j = 0; j < numin; j++)
        target->inputs[idx*numin+j] = inputs[j];
    for (j = 0; j < numout; j++)
        target->outputs[idx*numout+j] = outputs[j];
}
/* Free the specified dataset. */
void NRDatasetFree(NRDataset *dset) {
    RedisModule_Free(dset->inputs);
    RedisModule_Free(dset->outputs);
}

/* Free a whole NN object. */
void NRTypeReleaseObject(NRTypeObject *o) {
    AnnFree(o->nn);
    NRDatasetFree(&o->dataset);
    NRDatasetFree(&o->test);
    RedisModule_Free(o->inorm);
    RedisModule_Free(o->onorm);
    RedisModule_Free(o);
}
/* ================================ Training =============================== */

/* Clone a neural network object, including the training and test dataset.
 * We use cloning in order to train in a different thread, and later
 * copy the weights back into the original NN.
 *
 * Note that when 'newid' is 0, the copied object keeps the NN unique ID of
 * the original, as normally this is what we want, in order to later match
 * the trained network with the object stored at the specified key in the
 * pending training structure.
 *
 * However if the copy is performed with other goals, 'newid' should
 * be set to non-zero in order to create a net with a different ID. */
NRTypeObject *NRClone(NRTypeObject *o, int newid) {
    NRTypeObject *copy;
    copy = (NRTypeObject*)RedisModule_Calloc(1,sizeof(*o));
    *copy = *o;
    if (newid) copy->id = NRNextId++;

    copy->nn = AnnClone(o->nn);
    copy->dataset = o->dataset;
    copy->test = o->test;

    int ilen = ANN_INPUT_UNITS(o->nn);
    int olen = ANN_OUTPUT_UNITS(o->nn);
    copy->dataset.inputs = (float*)RedisModule_Alloc(sizeof(float)*ilen*o->dataset.len);
    copy->dataset.outputs = (float*)RedisModule_Alloc(sizeof(float)*olen*o->dataset.len);
    copy->test.inputs = (float*)RedisModule_Alloc(sizeof(float)*ilen*o->test.len);
    copy->test.outputs = (float*)RedisModule_Alloc(sizeof(float)*olen*o->test.len);
    memcpy(copy->dataset.inputs,o->dataset.inputs,sizeof(float)*ilen*o->dataset.len);
    memcpy(copy->dataset.outputs,o->dataset.outputs,sizeof(float)*olen*o->dataset.len);
    memcpy(copy->test.inputs,o->test.inputs,sizeof(float)*ilen*o->test.len);
    memcpy(copy->test.outputs,o->test.outputs,sizeof(float)*olen*o->test.len);
    copy->inorm = (float*)RedisModule_Alloc(sizeof(float)*ilen);
    copy->onorm = (float*)RedisModule_Alloc(sizeof(float)*olen);
    memcpy(copy->inorm,o->inorm,sizeof(float)*ilen);
    memcpy(copy->onorm,o->onorm,sizeof(float)*olen);
    return copy;
}
/* Transfer the weights from the source to the destination NN.
 * This is used after the learning process finished in a different
 * thread in order to transfer the learning back to the original
 * NN. */
static void NRTransferWeights(RedisModuleCtx *ctx, NRTypeObject *dst, NRTypeObject *src) {
    if (dst->id != src->id) {
        RedisModule_Log(ctx,"warning",
            "NRTransferWeights(): source and destination neural network IDs "
            "don't match. This is unexpected, probably a bug inside the "
            "module. Weights not transferred back to the original NN.");
        return;
    }

    /* It would be faster to memcpy just the weight array for each layer,
     * however this way we access the NN in a more abstract way, and it should
     * be fast enough in most cases. We can always optimize it later. */
    AnnFree(dst->nn);
    dst->nn = AnnClone(src->nn);
    dst->training_total_steps = src->training_total_steps;
    dst->training_total_ms = src->training_total_ms;
    dst->dataset_error = src->dataset_error;
    dst->test_error = src->test_error;
    dst->test_class_error = src->test_class_error;
    dst->flags |= src->flags & NR_FLAG_TO_TRANSFER;

    int ilen = ANN_INPUT_UNITS(src->nn);
    int olen = ANN_OUTPUT_UNITS(src->nn);
    memcpy(dst->inorm,src->inorm,sizeof(float)*ilen);
    memcpy(dst->onorm,src->onorm,sizeof(float)*olen);
}
/* Threaded training entry point.
 *
 * To get some clue about the overfitting detection behavior:
 * #define NR_TRAINING_DEBUG 1
 */
void *NRTrainingThreadMain(void *arg) {
    NRPendingTraining *pt = (NRPendingTraining*)arg;
    NRTypeObject *nr = pt->nr;
    int training_iterations = 1;
    float train_error = 0;
    float test_error = 0;
    float class_error = 0;
    float past_train_error = 1.0/0.0;
    float past_test_error = 1.0/0.0;
    int auto_stop = nr->flags & NR_FLAG_AUTO_STOP;
    int backtrack = nr->flags & NR_FLAG_BACKTRACK;
    uint64_t cycles = 0;
    long long start = NRMilliseconds();
    long long cycle_time;
    int overfitting_count = 0;
    int overfitting_limit = 5;
    float best_test_error = 1.0/0.0;

    nr->flags &= ~NR_FLAG_TO_TRANSFER;

    /* If the network is auto normalized, we need to transform the inputs
     * in a way that's acceptable for the NN. We just find the maximum
     * absolute value, and divide by it, to get a -1,1 range. There
     * are more advanced transformations that are usually performed that
     * could be implemented in the future.
     *
     * Note that we compute the normalization vectors for all the inputs
     * and outputs, however if the network is a classifier, flagged with
     * (NR_FLAG_CLASSIFIER), no output normalization will be done since
     * the data is already in 0/1 format. */
    if ((nr->flags & NR_FLAG_NORMALIZE) && nr->dataset.len) {
        int ilen = ANN_INPUT_UNITS(nr->nn);
        int olen = ANN_OUTPUT_UNITS(nr->nn);
        float *imax = nr->inorm;
        float *omax = nr->onorm;
        float *inputs = nr->dataset.inputs;
        float *outputs = nr->dataset.outputs;
        for (int i = 0; i < ilen; i++) imax[i] = 1;
        for (int i = 0; i < olen; i++) omax[i] = 1;

        /* Compute the max values vectors. */
        for (uint32_t j = 0; j < nr->dataset.len; j++) {
            for (int i = 0; i < ilen; i++)
                if (fabs(inputs[i]) > imax[i]) imax[i] = fabs(inputs[i]);
            for (int i = 0; i < olen; i++)
                if (fabs(outputs[i]) > omax[i]) omax[i] = fabs(outputs[i]);
            inputs += ilen;
            outputs += olen;
        }

        /* Likely we are not yet seeing the true input/output maximum
         * values, so we multiply the maximum values found by a constant.
         * However if the max is exactly "1" we assume it's a classification
         * input and don't alter it. */
        for (int i = 0; i < ilen; i++) if (imax[i] != 1) imax[i] *= 1.2;
        for (int i = 0; i < olen; i++) if (omax[i] != 1) omax[i] *= 1.2;

        /* We can normalize the dataset directly: after the training it will
         * be discarded anyway. */
        inputs = nr->dataset.inputs;
        outputs = nr->dataset.outputs;
        for (uint32_t j = 0; j < nr->dataset.len; j++) {
            for (int i = 0; i < ilen; i++) inputs[i] /= nr->inorm[i];
            if (!(nr->flags & NR_FLAG_CLASSIFIER))
                for (int i = 0; i < olen; i++) outputs[i] /= nr->onorm[i];
            inputs += ilen;
            outputs += olen;
        }
        inputs = nr->test.inputs;
        outputs = nr->test.outputs;
        for (uint32_t j = 0; j < nr->test.len; j++) {
            for (int i = 0; i < ilen; i++) inputs[i] /= nr->inorm[i];
            if (!(nr->flags & NR_FLAG_CLASSIFIER))
                for (int i = 0; i < olen; i++) outputs[i] /= nr->onorm[i];
            inputs += ilen;
            outputs += olen;
        }
    }
    AnnRprop *saved = NULL;  /* Saved to recover on overfitting. */
    float saved_error;       /* The test error of the saved NN. */
    float saved_train_error; /* The training dataset error of the saved NN. */
    float saved_class_error; /* The % of classification errors of the saved NN. */

    while(1) {
        long long cycle_start = NRMilliseconds();

        train_error = AnnTrain(nr->nn,
                               nr->dataset.inputs,
                               nr->dataset.outputs,
                               0,
                               training_iterations,
                               nr->dataset.len,
                               ANN_ALGO_BPROP);
        cycle_time = NRMilliseconds() - cycle_start;
        nr->training_total_steps += nr->dataset.len*training_iterations;

        /* Evaluate the error in the case of auto stop training: stop it
         * once we see that the error in the training set is decreasing
         * while the one in the test set is not. */
        if (auto_stop) {
            AnnTestError(nr->nn,
                         nr->test.inputs,
                         nr->test.outputs,
                         nr->test.len, &test_error, &class_error);

            if (train_error < past_train_error &&
                test_error > past_test_error)
            {
                overfitting_count++;
#ifdef NR_TRAINING_DEBUG
                printf("+CYCLE %lld: [%d] %f VS %f\n", (long long)cycles,
                    overfitting_count, train_error, test_error);
#endif
                if (overfitting_count == overfitting_limit) {
                    nr->flags |= NR_FLAG_OF_DETECTED;
                    break;
                }
            } else if (overfitting_count > 0) {
#ifdef NR_TRAINING_DEBUG
                printf("-CYCLE %lld: [%d] %f VS %f\n", (long long)cycles,
                    overfitting_count, train_error, test_error);
#endif
                overfitting_count--;
            }

            /* Save every network with a score better than the currently
             * saved network. This can be a bit costly, but is safe: one
             * more cycle of training and overfitting can ruin it all. */
            if (backtrack && (saved == NULL || test_error < saved_error)) {
#ifdef NR_TRAINING_DEBUG
                printf("SAVED! %f < %f\n", test_error, saved_error);
#endif
                saved_error = test_error;
                saved_train_error = train_error;
                saved_class_error = class_error;
                if (saved) AnnFree(saved);
                saved = AnnClone(nr->nn);
            }

            /* Best network found? Reset the overfitting hints counter. */
            if (test_error < best_test_error) {
                overfitting_count = 0;
                best_test_error = test_error;
#ifdef NR_TRAINING_DEBUG
                printf("BEST! %lld: <%d> %f VS %f\n", (long long)cycles,
                    overfitting_limit,train_error, test_error);
#endif
            }

            /* Also stop if the loss is zero in both datasets. */
            if (train_error < 0.000000000000001 &&
                test_error < 0.000000000000001) break;
        }

        cycles++;
        long long total_time = NRMilliseconds()-start;

        /* Cycles and milliseconds stop conditions. */
        if (nr->training_max_cycles && cycles == nr->training_max_cycles)
            break;
        if (nr->training_max_ms && total_time > (long long)nr->training_max_ms)
            break;

        /* If this is a long training, doing just a single training iteration
         * per cycle is not optimal: tune the number of iterations so that a
         * cycle takes at least 100 milliseconds. */
        if (total_time > 10000 && cycle_time < 100) training_iterations++;

        past_train_error = train_error;
        past_test_error = test_error;

        /* Update stats for NR.THREADS to show progress. */
        pthread_mutex_lock(&NRPendingTrainingMutex);
        pt->dataset_error = train_error;
        pt->test_error = test_error;
        if (nr->flags & NR_FLAG_CLASSIFIER) pt->class_error = class_error;
        pt->curcycle = cycles;
        pthread_mutex_unlock(&NRPendingTrainingMutex);
    }
    /* If auto stop is disabled, we still need to compute the test error
     * in order to return this information to the main thread. */
    if (!auto_stop) {
        AnnTestError(nr->nn,
                     nr->test.inputs,
                     nr->test.outputs,
                     nr->test.len, &test_error, &class_error);
    }

    /* If both autostop and backtracking are enabled, we may have
     * a better network saved! */
    if (auto_stop && backtrack) {
        if (saved && saved_error < test_error) {
#ifdef NR_TRAINING_DEBUG
            printf("BACKTRACK: Saved network used!\n");
#endif
            AnnFree(nr->nn);
            nr->nn = saved;
            test_error = saved_error;
            train_error = saved_train_error;
            class_error = saved_class_error;
        } else if (saved) {
            AnnFree(saved);
        }
    }

    if (nr->flags & NR_FLAG_CLASSIFIER) nr->test_class_error = class_error;
    nr->dataset_error = train_error;
    nr->test_error = test_error;
    nr->training_total_ms += NRMilliseconds()-start;

    /* Signal that the training process has finished: it's up to the main
     * thread to clean up this training slot, copying the weights to the
     * original neural network and reclaiming memory for the copy we
     * worked on. */
    pthread_mutex_lock(&NRPendingTrainingMutex);
    pt->in_progress = 0;
    pthread_mutex_unlock(&NRPendingTrainingMutex);
    return NULL;
}
/* Start a background training in another thread. Return REDISMODULE_ERR if
 * there is no free slot for training, as we already reached the maximum
 * number of networks we can train in parallel.
 *
 * The 'flags' argument specifies the additional NN flags to pass to the
 * training routine:
 *
 * NR_FLAG_AUTO_STOP -- Automatically stop training on overtraining.
 * NR_FLAG_BACKTRACK -- Save current NN state when overfitting is likely.
 */
int NRStartTraining(RedisModuleCtx *ctx, RedisModuleString *key, int dbid, NRTypeObject *nr) {
    pthread_mutex_lock(&NRPendingTrainingMutex);
    if (NRPendingTrainingCount == NR_PENDING_TRAINING_MAX_LEN) {
        pthread_mutex_unlock(&NRPendingTrainingMutex);
        return REDISMODULE_ERR;
    }

    /* Setup our training data. */
    NRPendingTraining *pt = &NRTrainings[NRPendingTrainingCount];
    //pt->key = RedisModule_CreateStringFromString(ctx,key);
    //RedisModule_RetainString(ctx,pt->key);
    pt->db_id = dbid;
    pt->in_progress = 1;
    pt->nr = NRClone(nr,0);
    pt->dataset_error = 0;
    pt->test_error = 0;
    pt->class_error = 0;
    pt->curcycle = 0;

    if (pthread_create(&pt->tid,NULL,NRTrainingThreadMain,pt) != 0) {
        RedisModule_Log(ctx,"warning","Unable to create a new pthread in NRStartTraining()");
        //RedisModule_FreeString(ctx,pt->key);
        //pt->key = NULL;
        NRTypeReleaseObject(pt->nr);
        pthread_mutex_unlock(&NRPendingTrainingMutex);
        return REDISMODULE_ERR;
    }
    NRPendingTrainingCount++;
    nr->flags |= NR_FLAG_TRAINING;
    nr->flags &= ~NR_FLAG_TO_TRANSFER;
    pthread_mutex_unlock(&NRPendingTrainingMutex);
    return REDISMODULE_OK;
}
/* Check if there are threads that terminated the NN training, and
 * collect the info they computed (that is the new NN). */
int NRCollectThreads(RedisModuleCtx *ctx) {
    int collected = 0;
    pthread_mutex_lock(&NRPendingTrainingMutex);
    for (int j = 0; j < NRPendingTrainingCount; j++) {
        NRPendingTraining *pt = &NRTrainings[j];
        if (pt->in_progress == 0) {
            /* Training terminated. Let's see if the key
             * is still there and NN ID matches. */
            int orig_id = RedisModule_GetSelectedDb(ctx);
            if (orig_id != pt->db_id) RedisModule_SelectDb(ctx,pt->db_id);
            RedisModuleKey *key = RedisModule_OpenKey(ctx,pt->key,
                REDISMODULE_READ|REDISMODULE_WRITE);
            if (RedisModule_ModuleTypeGetType(key) == NRType) {
                NRTypeObject *nr = RedisModule_ModuleTypeGetValue(key);
                if (nr->id == pt->nr->id) {
                    NRTransferWeights(ctx,nr,pt->nr);
                    nr->flags &= ~NR_FLAG_TRAINING;
                }
                RedisModule_FreeString(ctx,pt->key);
                pt->key = NULL;
                NRTypeReleaseObject(pt->nr);
                NRPendingTrainingCount--;
                memcpy(&NRTrainings[j],&NRTrainings[j+1],
                    (NRPendingTrainingCount-j)*sizeof(NRTrainings[0]));
            }
            if (orig_id != pt->db_id) RedisModule_SelectDb(ctx,orig_id);
            collected++;
        }
    }
    pthread_mutex_unlock(&NRPendingTrainingMutex);
    return collected;
}
#endif // 0
#define RedisModule_Free(x) sq_free(x, 0)

static void *RedisModule_Calloc(size_t nelm, size_t sz)
{
    size_t malloc_size = nelm * sz;
    void *ptr = sq_malloc(malloc_size);
    if(ptr) memset(ptr, 0, malloc_size);
    return ptr;
}

static void *RedisModule_Realloc(void *oldPtr, size_t sz)
{
    void *ptr = sq_realloc(oldPtr, 0, sz);
    return ptr;
}

static uint64_t NRNextId = 1; /* Next neural network unique ID. */

long long NRMilliseconds(void) {
    struct timeval tv;
    long long ust;

    gettimeofday(&tv, NULL);
    ust = ((long long)tv.tv_sec)*1000000;
    ust += tv.tv_usec;
    return ust/1000;
}
/* Create a network with the specified parameters. Note that the layers
 * must be specified from the output layer[0] to the input
 * layer[N]. Each element in the integer array 'layers' specifies how many
 * units there are in the corresponding layer. */
static NRTypeObject *createNRTypeObject(int flags, int *layers, int numlayers, int dset_len, int test_len) {
    NRTypeObject *o;
    o = (NRTypeObject*)RedisModule_Calloc(1,sizeof(*o));
    o->id = NRNextId++;
    o->flags = flags;
    o->nn = AnnCreateNet(numlayers,layers);
    o->dataset.maxlen = dset_len;
    o->test.maxlen = test_len;
    int ilen = ANN_INPUT_UNITS(o->nn);
    int olen = ANN_OUTPUT_UNITS(o->nn);
    o->inorm = (float*)RedisModule_Calloc(1,sizeof(float)*ilen);
    o->onorm = (float*)RedisModule_Calloc(1,sizeof(float)*olen);
    for (int j = 0; j < ilen; j++) o->inorm[j] = 1;
    for (int j = 0; j < olen; j++) o->onorm[j] = 1;
    return o;
}
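/* Example (illustration only): for a network with 2 inputs, one hidden layer
 * of 3 units and 1 output, the caller passes layers = {1, 3, 2} and
 * numlayers = 3, i.e. layers[0] is the output layer and layers[numlayers-1]
 * is the input layer. */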
/* Insert data (observations needed to train and test the NN) into the
 * NN object. While the learning and testing datasets are not yet full,
 * the observed pattern is inserted into one or the other in order to
 * make sure the two datasets are populated evenly. When both are
 * already full, a random element from one or the other (doing a
 * random weighted choice depending on the length) is substituted with
 * the new item. An illustration of the weighted choice follows the
 * function. */
#define NR_INSERT_NO_TARGET 0 /* Auto select where to insert. */
#define NR_INSERT_TRAIN 1     /* Insert in training dataset. */
#define NR_INSERT_TEST 2      /* Insert in testing dataset. */
static void NRTypeInsertData(NRTypeObject *o, float *inputs, float *outputs,
                             int target_ds) {
    NRDataset *target = NULL;

    /* Check if there is no dataset at all. This may be a valid setup
     * with online learning, sample by sample. */
    if (o->dataset.maxlen == 0 && o->test.maxlen == 0) return;

    /* If the user specified a target, select it. */
    if (target_ds == NR_INSERT_TRAIN) target = &o->dataset;
    else if (target_ds == NR_INSERT_TEST) target = &o->test;

    /* If no target is specified, but there is only one possible
     * target, select it ASAP. */
    if (o->dataset.maxlen == 0) {
        target = &o->test;
    } else if (o->test.maxlen == 0) {
        target = &o->dataset;
    }

    /* Otherwise choose as the target to populate the one with less data
     * relative to its size. */
    if (target == NULL) {
        /* If one of the two datasets is still not full, pick
         * based on fill percentage. Otherwise pick a random
         * target proportionally to their sizes. */
        if (o->dataset.len != o->dataset.maxlen ||
            o->test.len != o->test.maxlen)
        {
            float fill_a = (float)o->dataset.len / o->dataset.maxlen;
            float fill_b = (float)o->test.len / o->test.maxlen;
            target = (fill_a <= fill_b) ? &o->dataset : &o->test;
        } else {
            double r = (double)rand()/RAND_MAX;
            double sumlen = o->dataset.maxlen + o->test.maxlen;
            if (r < (double)o->dataset.maxlen/sumlen) {
                target = &o->dataset;
            } else {
                target = &o->test;
            }
        }
    }

    /* Append if there is room or substitute with a random entry. */
    size_t idx;
    int j, numin = ANN_INPUT_UNITS(o->nn),
           numout = ANN_OUTPUT_UNITS(o->nn);
    if (target->maxlen == target->len) {
        idx = rand() % target->maxlen;
    } else {
        idx = target->len;
        target->len++;
        target->inputs = (float*)RedisModule_Realloc(target->inputs,
            sizeof(float)*numin*target->len);
        target->outputs = (float*)RedisModule_Realloc(target->outputs,
            sizeof(float)*numout*target->len);
    }

    /* Finally store the values at the selected position. */
    for (j = 0; j < numin; j++)
        target->inputs[idx*numin+j] = inputs[j];
    for (j = 0; j < numout; j++)
        target->outputs[idx*numout+j] = outputs[j];
}
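/* Worked example of the weighted choice above (illustration only): with
 * dataset.maxlen = 100 and test.maxlen = 50, once both datasets are full a
 * new observation replaces a random training entry with probability
 * 100/150 (about 0.67) and a random test entry with probability 50/150
 * (about 0.33), so the two sets keep their relative sizes over time. */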
/* Free the specified dataset. */
void NRDatasetFree(NRDataset *dset) {
    RedisModule_Free(dset->inputs);
    RedisModule_Free(dset->outputs);
}

/* Free a whole NN object. */
void NRTypeReleaseObject(NRTypeObject *o) {
    AnnFree(o->nn);
    NRDatasetFree(&o->dataset);
    NRDatasetFree(&o->test);
    RedisModule_Free(o->inorm);
    RedisModule_Free(o->onorm);
    RedisModule_Free(o);
}
static const SQChar sq_nn_TAG[] = _SC("AnnRprop");

static SQRESULT sq_nn_release_hook(SQUserPointer p, SQInteger size, void */*ep*/) {
    NRTypeObject *self = (NRTypeObject *)p;
    if(self) NRTypeReleaseObject(self);
    return 0;
}
/*
** Creates a new AnnRprop.
*/
static SQRESULT sq_nn_constructor (HSQUIRRELVM v) {
    SQ_FUNC_VARS(v);
    SQ_GET_INTEGER(v, 2, flags);
    SQ_GET_INTEGER(v, 3, ninputs);
    const SQInteger nhidden_pos = 4;
    SQ_GET_INTEGER(v, 5, noutputs);
    SQ_OPT_INTEGER(v, 6, ndata, 0);
    SQ_OPT_INTEGER(v, 7, ntest, 0);

    if(!(
        ((flags & NR_FLAG_CLASSIFIER) && !(flags & NR_FLAG_REGRESSOR))
        || (!(flags & NR_FLAG_CLASSIFIER) && (flags & NR_FLAG_REGRESSOR))
        )
    )
        return sq_throwerror(v, _SC("invalid neural network type. Must be "
                                    "CLASSIFIER or REGRESSOR"));

    int layers[NR_MAX_LAYERS], num_layers=0;
    layers[num_layers++] = noutputs;
    /* Our NN library takes the definition of layers in the opposite
     * order, so the hidden layers array is traversed in reverse. */
    SQInteger asize = sq_getsize(v, nhidden_pos);
    if(asize + 2 > NR_MAX_LAYERS)
        return sq_throwerror(v, _SC("too many hidden layers"));
    for(int i=asize-1; i >= 0; --i)
    {
        sq_pushinteger(v, i);
        sq_get(v, nhidden_pos);
        SQInteger nhidden;
        SQRESULT rc = sq_getinteger(v, -1, &nhidden);
        if(rc != SQ_OK) return sq_throwerror(v, _SC("only integers expected on hidden layers array"));
        layers[num_layers++] = nhidden;
        sq_poptop(v);
    }
    layers[num_layers++] = ninputs;
    //for(int i=0; i < num_layers; ++i) printf("layers %d : %d\n", i, layers[i]);

    NRTypeObject *self = createNRTypeObject(flags, layers, num_layers, ndata, ntest);
    if(self){
        self->flags = flags;
        sq_setinstanceup(v, 1, self);
        sq_setreleasehook(v, 1, sq_nn_release_hook);
        return 1;
    }
    return sq_throwerror(v, _SC("failed to create AnnRprop"));
}
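/* Squirrel-side usage sketch (illustration only, assuming sqext_register_nn()
 * below was called with the root table on the stack, so the class is
 * reachable as AnnRprop and the flag values are readable as class constants):
 *
 *   // 2 inputs, one hidden layer of 3 units, 2 outputs,
 *   // up to 100 training and 50 test samples.
 *   local nn = AnnRprop(AnnRprop.CLASSIFIER | AnnRprop.NORMALIZE, 2, [3], 2, 100, 50);
 */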
#define SQ_GET_NN_INSTANCE(v, at) SQ_GET_INSTANCE_VAR(v, at, NRTypeObject, self, sq_nn_TAG)

static SQRESULT sq_nn_observe(HSQUIRRELVM v)
{
    SQ_FUNC_VARS(v);
    SQ_GET_NN_INSTANCE(v, 1);
    SQ_OPT_INTEGER(v, 4, target, NR_INSERT_NO_TARGET);

    SQInteger ilen = ANN_INPUT_UNITS(self->nn);
    SQInteger olen = ANN_OUTPUT_UNITS(self->nn);
    SQInteger oargs = (self->flags & NR_FLAG_CLASSIFIER) ? 1 : olen;
    const SQInteger inputs_pos = 2;
    const SQInteger outputs_pos = 3;
    SQInteger asize_inputs = sq_getsize(v, inputs_pos);
    SQInteger asize_outputs = sq_getsize(v, outputs_pos);

    if((ilen != asize_inputs) || (oargs != asize_outputs))
        return sq_throwerror(v, _SC("number of arguments does not "
            "match the number of " _PRINT_INT_FMT " inputs and " _PRINT_INT_FMT " outputs in the neural network"),
            ilen, oargs);

    const SQInteger inputs_alloc_size = sizeof(float)*ilen;
    const SQInteger outputs_alloc_size = sizeof(float)*olen;

    float *inputs = (float*)sq_malloc(inputs_alloc_size);
    for(SQInteger i=0; i < ilen; ++i)
    {
        sq_pushinteger(v, i);
        sq_get(v, inputs_pos);
        SQFloat fnum;
        SQRESULT rc = sq_getfloat(v, -1, &fnum);
        if(rc != SQ_OK)
        {
            sq_free(inputs, inputs_alloc_size);
            return sq_throwerror(v, _SC("only numbers expected on input array"));
        }
        inputs[i] = fnum;
        sq_poptop(v);
    }

    float *outputs = (float*)sq_malloc(outputs_alloc_size);
    for(SQInteger i=0; i < oargs; ++i)
    {
        sq_pushinteger(v, i);
        sq_get(v, outputs_pos);
        SQFloat fnum;
        SQRESULT rc = sq_getfloat(v, -1, &fnum);
        if(rc != SQ_OK)
        {
            sq_free(inputs, inputs_alloc_size);
            sq_free(outputs, outputs_alloc_size);
            return sq_throwerror(v, _SC("only numbers expected on output array"));
        }
        if (self->flags & NR_FLAG_CLASSIFIER) {
            int classid = fnum;
            if (classid != fnum || fnum >= olen || fnum < 0) {
                sq_free(inputs, inputs_alloc_size);
                sq_free(outputs, outputs_alloc_size);
                return sq_throwerror(v, _SC("classifier network output must be an integer "
                                            "in the range from 0 to outputs-1."));
            }
            memset(outputs,0, outputs_alloc_size);
            outputs[classid] = 1;
        } else {
            outputs[i] = fnum;
        }
        sq_poptop(v);
    }

    NRTypeInsertData(self,inputs,outputs,target);
    sq_free(inputs, inputs_alloc_size);
    sq_free(outputs, outputs_alloc_size);
    return 0;
}
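/* Usage sketch (illustration only): for a classifier the output array holds a
 * single class id, for a regressor one value per output unit. The optional
 * third argument selects the target dataset:
 *
 *   nn.observe([0.5, 1.0], [1]);                              // classifier, class id 1
 *   nn.observe([0.5, 1.0], [0.2, 0.8], AnnRprop.INSERT_TEST); // regressor, test set
 */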
static SQRESULT sq_nn_train(HSQUIRRELVM v)
{
    SQ_FUNC_VARS(v);
    SQ_GET_NN_INSTANCE(v, 1);
    SQ_OPT_INTEGER(v, 2, opt_max_cycles, 0);
    SQ_OPT_INTEGER(v, 3, opt_max_ms, 10000);
    SQ_OPT_INTEGER(v, 4, opt_flags, 0);

    NRTypeObject *nr = self;
    nr->training_max_cycles = opt_max_cycles;
    nr->training_max_ms = opt_max_ms;
    if(opt_flags & NR_FLAG_AUTO_STOP) nr->flags |= NR_FLAG_AUTO_STOP;
    if(opt_flags & NR_FLAG_BACKTRACK) nr->flags |= NR_FLAG_BACKTRACK;

    /* Overfitting detection compares the error rate in the testing/training
     * data, so it does not work without entries in the testing dataset. */
    if (nr->flags & NR_FLAG_AUTO_STOP && nr->test.len == 0) {
        return sq_throwerror(v, _SC("Can't start training with the AUTO_STOP option: "
            "overfitting detection requires a non zero length testing dataset"));
    }

    int training_iterations = 1;
    float train_error = 0;
    float test_error = 0;
    float class_error = 0;
    float past_train_error = 1.0/0.0;
    float past_test_error = 1.0/0.0;
    int auto_stop = nr->flags & NR_FLAG_AUTO_STOP;
    int backtrack = nr->flags & NR_FLAG_BACKTRACK;
    uint64_t cycles = 0;
    long long start = NRMilliseconds();
    long long cycle_time;
    int overfitting_count = 0;
    int overfitting_limit = 5;
    float best_test_error = 1.0/0.0;

    nr->flags &= ~NR_FLAG_TO_TRANSFER;

    /* If the network is auto normalized, we need to transform the inputs
     * in a way that's acceptable for the NN. We just find the maximum
     * absolute value, and divide by it, to get a -1,1 range. There
     * are more advanced transformations that are usually performed that
     * could be implemented in the future.
     *
     * Note that we compute the normalization vectors for all the inputs
     * and outputs, however if the network is a classifier, flagged with
     * (NR_FLAG_CLASSIFIER), no output normalization will be done since
     * the data is already in 0/1 format. */
    if ((nr->flags & NR_FLAG_NORMALIZE) && nr->dataset.len) {
        int ilen = ANN_INPUT_UNITS(nr->nn);
        int olen = ANN_OUTPUT_UNITS(nr->nn);
        float *imax = nr->inorm;
        float *omax = nr->onorm;
        float *inputs = nr->dataset.inputs;
        float *outputs = nr->dataset.outputs;
        for (int i = 0; i < ilen; i++) imax[i] = 1;
        for (int i = 0; i < olen; i++) omax[i] = 1;

        /* Compute the max values vectors. */
        for (uint32_t j = 0; j < nr->dataset.len; j++) {
            for (int i = 0; i < ilen; i++)
                if (fabs(inputs[i]) > imax[i]) imax[i] = fabs(inputs[i]);
            for (int i = 0; i < olen; i++)
                if (fabs(outputs[i]) > omax[i]) omax[i] = fabs(outputs[i]);
            inputs += ilen;
            outputs += olen;
        }

        /* Likely we are not yet seeing the true input/output maximum
         * values, so we multiply the maximum values found by a constant.
         * However if the max is exactly "1" we assume it's a classification
         * input and don't alter it. */
        for (int i = 0; i < ilen; i++) if (imax[i] != 1) imax[i] *= 1.2;
        for (int i = 0; i < olen; i++) if (omax[i] != 1) omax[i] *= 1.2;

        /* We can normalize the dataset directly: after the training it will
         * be discarded anyway. */
        inputs = nr->dataset.inputs;
        outputs = nr->dataset.outputs;
        for (uint32_t j = 0; j < nr->dataset.len; j++) {
            for (int i = 0; i < ilen; i++) inputs[i] /= nr->inorm[i];
            if (!(nr->flags & NR_FLAG_CLASSIFIER))
                for (int i = 0; i < olen; i++) outputs[i] /= nr->onorm[i];
            inputs += ilen;
            outputs += olen;
        }
        inputs = nr->test.inputs;
        outputs = nr->test.outputs;
        for (uint32_t j = 0; j < nr->test.len; j++) {
            for (int i = 0; i < ilen; i++) inputs[i] /= nr->inorm[i];
            if (!(nr->flags & NR_FLAG_CLASSIFIER))
                for (int i = 0; i < olen; i++) outputs[i] /= nr->onorm[i];
            inputs += ilen;
            outputs += olen;
        }
    }

    AnnRprop *saved = NULL;  /* Saved to recover on overfitting. */
    float saved_error;       /* The test error of the saved NN. */
    float saved_train_error; /* The training dataset error of the saved NN. */
    float saved_class_error; /* The % of classification errors of the saved NN. */

    while(1) {
        long long cycle_start = NRMilliseconds();

        train_error = AnnTrain(nr->nn,
                               nr->dataset.inputs,
                               nr->dataset.outputs,
                               0,
                               training_iterations,
                               nr->dataset.len,
                               ANN_ALGO_BPROP);
        cycle_time = NRMilliseconds() - cycle_start;
        nr->training_total_steps += nr->dataset.len*training_iterations;

        /* Evaluate the error in the case of auto stop training: stop it
         * once we see that the error in the training set is decreasing
         * while the one in the test set is not. */
        if (auto_stop) {
            AnnTestError(nr->nn,
                         nr->test.inputs,
                         nr->test.outputs,
                         nr->test.len, &test_error, &class_error);

            if (train_error < past_train_error &&
                test_error > past_test_error)
            {
                overfitting_count++;
#ifdef NR_TRAINING_DEBUG
                printf("+CYCLE %lld: [%d] %f VS %f\n", (long long)cycles,
                    overfitting_count, train_error, test_error);
#endif
                if (overfitting_count == overfitting_limit) {
                    nr->flags |= NR_FLAG_OF_DETECTED;
                    break;
                }
            } else if (overfitting_count > 0) {
#ifdef NR_TRAINING_DEBUG
                printf("-CYCLE %lld: [%d] %f VS %f\n", (long long)cycles,
                    overfitting_count, train_error, test_error);
#endif
                overfitting_count--;
            }

            /* Save every network with a score better than the currently
             * saved network. This can be a bit costly, but is safe: one
             * more cycle of training and overfitting can ruin it all. */
            if (backtrack && (saved == NULL || test_error < saved_error)) {
#ifdef NR_TRAINING_DEBUG
                printf("SAVED! %f < %f\n", test_error, saved_error);
#endif
                saved_error = test_error;
                saved_train_error = train_error;
                saved_class_error = class_error;
                if (saved) AnnFree(saved);
                saved = AnnClone(nr->nn);
            }

            /* Best network found? Reset the overfitting hints counter. */
            if (test_error < best_test_error) {
                overfitting_count = 0;
                best_test_error = test_error;
#ifdef NR_TRAINING_DEBUG
                printf("BEST! %lld: <%d> %f VS %f\n", (long long)cycles,
                    overfitting_limit,train_error, test_error);
#endif
            }

            /* Also stop if the loss is zero in both datasets. */
            if (train_error < 0.000000000000001 &&
                test_error < 0.000000000000001) break;
        }

        cycles++;
        long long total_time = NRMilliseconds()-start;

        /* Cycles and milliseconds stop conditions. */
        if (nr->training_max_cycles && cycles == nr->training_max_cycles)
            break;
        if (nr->training_max_ms && total_time > (long long)nr->training_max_ms)
            break;

        /* If this is a long training, doing just a single training iteration
         * per cycle is not optimal: tune the number of iterations so that a
         * cycle takes at least 100 milliseconds. */
        if (total_time > 10000 && cycle_time < 100) training_iterations++;

        past_train_error = train_error;
        past_test_error = test_error;
    }

    /* If auto stop is disabled, we still need to compute the test error
     * in order to return this information to the caller. */
    if (!auto_stop) {
        AnnTestError(nr->nn,
                     nr->test.inputs,
                     nr->test.outputs,
                     nr->test.len, &test_error, &class_error);
    }

    /* If both autostop and backtracking are enabled, we may have
     * a better network saved! */
    if (auto_stop && backtrack) {
        if (saved && saved_error < test_error) {
#ifdef NR_TRAINING_DEBUG
            printf("BACKTRACK: Saved network used!\n");
#endif
            AnnFree(nr->nn);
            nr->nn = saved;
            test_error = saved_error;
            train_error = saved_train_error;
            class_error = saved_class_error;
        } else if (saved) {
            AnnFree(saved);
        }
    }

    if (nr->flags & NR_FLAG_CLASSIFIER) nr->test_class_error = class_error;
    nr->dataset_error = train_error;
    nr->test_error = test_error;
    nr->training_total_ms += NRMilliseconds()-start;
    return 0;
}
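/* Usage sketch (illustration only): train for at most 1000 cycles or 5
 * seconds, stopping early (and backtracking to the best saved network) when
 * overfitting is detected on the test dataset:
 *
 *   nn.train(1000, 5000, AnnRprop.AUTO_STOP | AnnRprop.BACKTRACK);
 */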
static SQRESULT sq_nn_run(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);

    SQInteger asize_inputs = sq_getsize(v, 2);
    SQInteger ilen = ANN_INPUT_UNITS(self->nn);
    if(ilen != asize_inputs)
        return sq_throwerror(v, _SC("wrong number of inputs " _PRINT_INT_FMT ", expected " _PRINT_INT_FMT), asize_inputs, ilen);

    for(SQInteger i=0; i < ilen; ++i)
    {
        sq_pushinteger(v, i);
        sq_get(v, 2);
        SQFloat fnum;
        SQRESULT rc = sq_getfloat(v, -1, &fnum);
        if(rc != SQ_OK)
        {
            return sq_throwerror(v, _SC("only numbers expected on input array"));
        }
        if (self->flags & NR_FLAG_NORMALIZE) fnum /= self->inorm[i];
        ANN_INPUT_NODE(self->nn,i) = fnum;
        sq_poptop(v);
    }
    AnnSimulate(self->nn);

    /* Return the raw net outputs, de-normalized if the network
     * performs input/output normalization. */
    int olen = ANN_OUTPUT_UNITS(self->nn);
    sq_newarray(v, olen);
    for(int j = 0; j < olen; j++) {
        float output = ANN_OUTPUT_NODE(self->nn,j);
        if (!(self->flags & NR_FLAG_CLASSIFIER) &&
            (self->flags & NR_FLAG_NORMALIZE))
        {
            output *= self->onorm[j];
        }
        sq_pushfloat(v, output);
        sq_arrayset(v, -2, j);
    }
    return 1;
}
static SQRESULT sq_nn_classify(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    if (!(self->flags & NR_FLAG_CLASSIFIER))
        return sq_throwerror(v, _SC("you can't call classify with a regressor network."));

    SQInteger asize_inputs = sq_getsize(v, 2);
    SQInteger ilen = ANN_INPUT_UNITS(self->nn);
    if(ilen != asize_inputs)
        return sq_throwerror(v, _SC("wrong number of inputs %d, expected %d"), (int)asize_inputs, (int)ilen);

    for(SQInteger i=0; i < ilen; ++i)
    {
        sq_pushinteger(v, i);
        sq_get(v, 2);
        SQFloat fnum;
        SQRESULT rc = sq_getfloat(v, -1, &fnum);
        if(rc != SQ_OK)
        {
            return sq_throwerror(v, _SC("only numbers expected on input array"));
        }
        if (self->flags & NR_FLAG_NORMALIZE) fnum /= self->inorm[i];
        ANN_INPUT_NODE(self->nn,i) = fnum;
        sq_poptop(v);
    }
    AnnSimulate(self->nn);

    /* Return the class ID: the index of the output unit with the
     * highest activation. */
    int olen = ANN_OUTPUT_UNITS(self->nn);
    float fmax = ANN_OUTPUT_NODE(self->nn,0);
    int max_class = 0;
    for(int j = 1; j < olen; j++) {
        float output = ANN_OUTPUT_NODE(self->nn,j);
        if (output > fmax) {
            fmax = output;
            max_class = j;
        }
    }
    sq_pushinteger(v, max_class);
    return 1;
}
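/* Usage sketch (illustration only):
 *
 *   local outputs = nn.run([0.5, 1.0]);       // array with one value per output unit
 *   local classid = nn.classify([0.5, 1.0]);  // index of the strongest output unit
 */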
#define ADD_T_TABLE_STR(sk, sv) \
    sq_pushstring(v, sk, -1); \
    sq_pushstring(v, sv, -1); \
    sq_rawset(v, -3);

#define ADD_T_TABLE_INT(sk, sv) \
    sq_pushstring(v, sk, -1); \
    sq_pushinteger(v, sv); \
    sq_rawset(v, -3);

#define ADD_T_TABLE_FLOAT(sk, sv) \
    sq_pushstring(v, sk, -1); \
    sq_pushfloat(v, sv); \
    sq_rawset(v, -3);

static SQRESULT sq_nn_info(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);

    sq_newtable(v);
    ADD_T_TABLE_INT("id", self->id);
    ADD_T_TABLE_STR("type", (self->flags & NR_FLAG_CLASSIFIER) ? "classifier" : "regressor");
    ADD_T_TABLE_INT("auto-normalization", !!(self->flags & NR_FLAG_NORMALIZE));
    ADD_T_TABLE_INT("training", !!(self->flags & NR_FLAG_TRAINING));

    sq_pushliteral(v, _SC("layout"));
    sq_newarray(v, ANN_LAYERS(self->nn));
    for (int ai=0, i = ANN_LAYERS(self->nn)-1; i >= 0; i--, ++ai) {
        int units = ANN_UNITS(self->nn,i);
        if (i != 0) units--; /* Don't count the bias unit. */
        sq_pushinteger(v, units);
        sq_arrayset(v, -2, ai);
    }
    sq_rawset(v, -3);

    ADD_T_TABLE_INT("training-dataset-maxlen", self->dataset.maxlen);
    ADD_T_TABLE_INT("training-dataset-len", self->dataset.len);
    ADD_T_TABLE_INT("test-dataset-maxlen", self->test.maxlen);
    ADD_T_TABLE_INT("test-dataset-len", self->test.len);
    ADD_T_TABLE_INT("training-total-steps", self->training_total_steps);
    ADD_T_TABLE_INT("training-total-cycles", self->dataset.len ?
        (self->training_total_steps / self->dataset.len) : 0);
    float tms = (float)self->training_total_ms/1000;
    ADD_T_TABLE_FLOAT("training-total-seconds", tms);
    ADD_T_TABLE_FLOAT("dataset-error", self->dataset_error);
    ADD_T_TABLE_FLOAT("test-error", self->test_error);
    if (self->flags & NR_FLAG_CLASSIFIER) {
        ADD_T_TABLE_FLOAT("classification-errors-perc", self->test_class_error);
    }
    ADD_T_TABLE_STR("overfitting-detected", (self->flags & NR_FLAG_OF_DETECTED) ? "yes" : "no");
    return 1;
}
static SQRESULT sq_nn_clone(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    AnnRprop *clone = AnnClone(self->nn);
    if(clone)
    {
        /* Wrap the cloned network in a fresh NRTypeObject: the release hook
         * and all the methods expect an NRTypeObject, not a bare AnnRprop.
         * The clone starts with empty training/test datasets. */
        NRTypeObject *copy = (NRTypeObject*)RedisModule_Calloc(1,sizeof(*copy));
        copy->id = NRNextId++;
        copy->flags = self->flags;
        copy->nn = clone;
        copy->dataset.maxlen = self->dataset.maxlen;
        copy->test.maxlen = self->test.maxlen;
        int ilen = ANN_INPUT_UNITS(self->nn);
        int olen = ANN_OUTPUT_UNITS(self->nn);
        copy->inorm = (float*)RedisModule_Calloc(1,sizeof(float)*ilen);
        copy->onorm = (float*)RedisModule_Calloc(1,sizeof(float)*olen);
        memcpy(copy->inorm,self->inorm,sizeof(float)*ilen);
        memcpy(copy->onorm,self->onorm,sizeof(float)*olen);
        sq_pushstring(v, sq_nn_TAG, -1);
        if(sq_getonregistrytable(v) == SQ_ERROR) return SQ_ERROR;
        sq_createinstance(v, -1);
        sq_setinstanceup(v, -1, copy);
        sq_setreleasehook(v, -1, sq_nn_release_hook);
    }
    else sq_pushnull(v);
    return 1;
}
#define SQ_NN_GET_SET_FLOAT(func_name) \
static SQRESULT sq_nn_##func_name(HSQUIRRELVM v)\
{\
    SQ_FUNC_VARS(v);\
    SQ_GET_NN_INSTANCE(v, 1);\
    if(_top_ == 1)\
    {\
        sq_pushfloat(v, self->nn->func_name);\
        return 1;\
    }\
    SQ_GET_FLOAT(v, 2, func_name);\
    self->nn->func_name = func_name;\
    return 0;\
}

SQ_NN_GET_SET_FLOAT(learn_rate);
SQ_NN_GET_SET_FLOAT(rprop_nminus);
SQ_NN_GET_SET_FLOAT(rprop_nplus);
SQ_NN_GET_SET_FLOAT(rprop_maxupdate);
SQ_NN_GET_SET_FLOAT(rprop_minupdate);

static SQRESULT sq_nn_flags(HSQUIRRELVM v)
{
    SQ_FUNC_VARS(v);
    SQ_GET_NN_INSTANCE(v, 1);
    if(_top_ == 1)
    {
        sq_pushinteger(v, self->nn->flags);
        return 1;
    }
    SQ_GET_INTEGER(v, 2, flags);
    self->nn->flags = flags;
    return 0;
}
static SQRESULT sq_nn_weights(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    sq_pushfloat(v, AnnCountWeights(self->nn));
    return 1;
}

static SQRESULT sq_nn_weight(HSQUIRRELVM v)
{
    SQ_FUNC_VARS(v);
    SQ_GET_NN_INSTANCE(v, 1);
    SQ_GET_INTEGER(v, 2, layer);
    SQ_GET_INTEGER(v, 3, i);
    SQ_GET_INTEGER(v, 4, j);
    if(layer < 0 || layer >= self->nn->layers) return sq_throwerror(v, _SC("layer out of range"));
    //if(i < 0 || i >= self->layer[layer]) return sq_throwerror(v, _SC("unit out of range"));
    float *weight = &ANN_WEIGHT(self->nn, layer, i, j);
    if(_top_ == 4)
    {
        sq_pushfloat(v, *weight);
        return 1;
    }
    SQ_GET_FLOAT(v, 5, new_weight);
    *weight = new_weight;
    return 0;
}
static SQRESULT sq_nn_Ann2Tcl(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    Ann2Tcl(self->nn);
    return 0;
}

static SQRESULT sq_nn_Ann2Js(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    Ann2Js(self->nn);
    return 0;
}

static SQRESULT sq_nn_AnnPrint(HSQUIRRELVM v)
{
    SQ_FUNC_VARS_NO_TOP(v);
    SQ_GET_NN_INSTANCE(v, 1);
    AnnPrint(self->nn);
    return 0;
}
#define _DECL_FUNC(name,nparams,tycheck) {_SC(#name),sq_nn_##name,nparams,tycheck}
static SQRegFunction sq_nn_methods[] =
{
    _DECL_FUNC(constructor, -5,_SC("xiiaiii")),
    _DECL_FUNC(clone, 1,_SC("x")),
    _DECL_FUNC(Ann2Tcl, 1,_SC("x")),
    _DECL_FUNC(Ann2Js, 1,_SC("x")),
    _DECL_FUNC(AnnPrint, 1,_SC("x")),
    _DECL_FUNC(flags, -1,_SC("xi")),
    _DECL_FUNC(learn_rate, -1,_SC("xf")),
    _DECL_FUNC(rprop_nminus, -1,_SC("xf")),
    _DECL_FUNC(rprop_nplus, -1,_SC("xf")),
    _DECL_FUNC(rprop_maxupdate, -1,_SC("xf")),
    _DECL_FUNC(rprop_minupdate, -1,_SC("xf")),
    _DECL_FUNC(weights, 1,_SC("x")),
    _DECL_FUNC(weight, -4,_SC("xiiif")),
    _DECL_FUNC(observe, -3,_SC("xaai")),
    _DECL_FUNC(train, -1,_SC("xiii")),
    _DECL_FUNC(run, 2,_SC("xa")),
    _DECL_FUNC(classify, 2,_SC("xa")),
    _DECL_FUNC(info, 1,_SC("x")),
    {0,0}
};
#undef _DECL_FUNC
typedef struct {
    const SQChar *Str;
    SQInteger Val;
} KeyIntType, * KeyIntPtrType;

static KeyIntType sq_nn_constants[] = {
#define MK_CONST(c) {_SC(#c), NR_##c}
#define MK_CONST_FLAG(c) {_SC(#c), NR_FLAG_##c}
    MK_CONST_FLAG(NONE),
    MK_CONST_FLAG(TRAINING),
    MK_CONST_FLAG(REGRESSOR),
    MK_CONST_FLAG(CLASSIFIER),
    MK_CONST_FLAG(NORMALIZE),
    MK_CONST_FLAG(AUTO_STOP),
    MK_CONST_FLAG(OF_DETECTED),
    MK_CONST_FLAG(BACKTRACK),
    MK_CONST_FLAG(TO_PERSIST),
    MK_CONST_FLAG(TO_TRANSFER),
    MK_CONST(MAX_LAYERS),
    MK_CONST(RDB_ENC_VER),
    MK_CONST(INSERT_TRAIN),
    MK_CONST(INSERT_TEST),
    {0,0}
};
#ifdef __cplusplus
extern "C" {
#endif

/* This defines the function that opens up the library. */
SQRESULT sqext_register_nn (HSQUIRRELVM v) {
    sq_pushstring(v,sq_nn_TAG,-1);
    sq_newclass(v,SQFalse);
    sq_settypetag(v,-1,(void*)sq_nn_TAG);
    sq_insert_reg_funcs(v, sq_nn_methods);

    //add constants
    KeyIntPtrType KeyIntPtr;
    for (KeyIntPtr = sq_nn_constants; KeyIntPtr->Str; KeyIntPtr++) {
        sq_pushstring(v, KeyIntPtr->Str, -1); //first the key
        sq_pushinteger(v, KeyIntPtr->Val);    //then the value
        sq_newslot(v, -3, SQFalse);           //store them
    }

    sq_newslot(v,-3,SQTrue);
    return SQ_OK;
}

#ifdef __cplusplus
}
#endif
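/* Registration sketch (illustration only, assuming 'v' is an open HSQUIRRELVM
 * and the class should be stored in the root table, which must be on the
 * stack top when sqext_register_nn() is called):
 *
 *   sq_pushroottable(v);
 *   sqext_register_nn(v);
 *   sq_poptop(v);
 */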