/* nn.h — RPROP neural network public interface. */
/* RPROP Neural Networks implementation
 * See: http://deeplearning.cs.cmu.edu/pdfs/Rprop.pdf
 *
 * Copyright (c) 2003-2016, Salvatore Sanfilippo <antirez at gmail dot com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * * Neither the name of Disque nor the names of its contributors may be used
 * to endorse or promote products derived from this software without
 * specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
  31. #ifndef __SQNN_H
  32. #define __SQNN_H
  33. //#include <assert.h>
  34. typedef float ann_float_t;
  35. typedef ann_float_t (*AnnDerivativeFunc)(ann_float_t v);
  36. /* Data structures.
  37. * Nets are not so 'dynamic', but enough to support
  38. * an arbitrary number of layers, with arbitrary units for layer.
  39. * Only fully connected feed-forward networks are supported. */
  40. typedef struct {
  41. ann_float_t *output; /* output[i], output of i-th unit */
  42. ann_float_t *error; /* error[i], output error of i-th unit*/
  43. ann_float_t *weight; /* weight[(i*units)+j] */
  44. /* weight between unit i-th and next j-th */
  45. ann_float_t *gradient; /* gradient[(i*units)+j] gradient */
  46. ann_float_t *sgradient; /* gradient for the full training set */
  47. /* only used for RPROP */
  48. ann_float_t *pgradient; /* pastgradient[(i*units)+j] t-1 gradient */
  49. /* (t-1 sgradient for resilient BP) */
  50. ann_float_t *delta; /* delta[(i*units)+j] cumulative update */
  51. /* (per-weight delta for RPROP) */
  52. int units; /*moved to last position for alignment purposes*/
  53. int units_aligned; /*units rounded up for alignment*/
  54. } AnnLayer;
  55. /* Feed forward network structure */
  56. typedef struct {
  57. AnnLayer *layer;
  58. int flags;
  59. int layers;
  60. AnnDerivativeFunc node_transf_func;
  61. AnnDerivativeFunc derivative_func;
  62. ann_float_t rprop_nminus;
  63. ann_float_t rprop_nplus;
  64. ann_float_t rprop_maxupdate;
  65. ann_float_t rprop_minupdate;
  66. ann_float_t learn_rate; /* Used for GD training. */
  67. } AnnRprop;
  68. typedef ann_float_t (*AnnTrainAlgoFunc)(AnnRprop *net, ann_float_t *input, ann_float_t *desired, int setlen);
  69. /* Raw interface to data structures */
  70. #define ANN_LAYERS(net) (net)->layers
  71. #define ANN_LAYER(net, l) (net)->layer[/*assert(l >= 0),*/l]
  72. #define ANN_OUTPUT(net,l,i) ANN_LAYER(net, l).output[i]
  73. #define ANN_ERROR(net,l,i) ANN_LAYER(net, l).error[i]
  74. #define ANN_LAYER_IDX(net,l,i,j) (((j)*ANN_LAYER(net, l).units_aligned)+(i))
  75. #define ANN_WEIGHT(net,l,i,j) ANN_LAYER(net, l).weight[ANN_LAYER_IDX(net,l,i,j)]
  76. #define ANN_GRADIENT(net,l,i,j) ANN_LAYER(net, l).gradient[ANN_LAYER_IDX(net,l,i,j)]
  77. #define ANN_SGRADIENT(net,l,i,j) ANN_LAYER(net, l).sgradient[ANN_LAYER_IDX(net,l,i,j)]
  78. #define ANN_PGRADIENT(net,l,i,j) ANN_LAYER(net, l).pgradient[ANN_LAYER_IDX(net,l,i,j)]
  79. #define ANN_DELTA(net,l,i,j) ANN_LAYER(net, l).delta[ANN_LAYER_IDX(net,l,i,j)]
  80. #define ANN_UNITS(net,l) ANN_LAYER(net, l).units
  81. #define ANN_UNITS_ALLOCATED(net,l) ANN_LAYER(net, l).units_aligned
  82. #define ANN_WEIGHTS(net,l) (ANN_UNITS(net,l)*ANN_UNITS(net,l-1))
  83. #define ANN_OUTPUT_NODE(net,i) ANN_OUTPUT(net,0,i)
  84. #define ANN_INPUT_NODE(net,i) ANN_OUTPUT(net,(ANN_LAYERS(net))-1,i)
  85. #define ANN_OUTPUT_UNITS(net) ANN_UNITS(net,0)
  86. #define ANN_INPUT_UNITS(net) (ANN_UNITS(net,(ANN_LAYERS(net))-1)-1)
  87. #define ANN_RPROP_NMINUS(net) (net)->rprop_nminus
  88. #define ANN_RPROP_NPLUS(net) (net)->rprop_nplus
  89. #define ANN_RPROP_MAXUPDATE(net) (net)->rprop_maxupdate
  90. #define ANN_RPROP_MINUPDATE(net) (net)->rprop_minupdate
  91. #define ANN_LEARN_RATE(net) (net)->learn_rate
  92. /* Constants */
  93. #define ANN_DEFAULT_RPROP_NMINUS 0.5
  94. #define ANN_DEFAULT_RPROP_NPLUS 1.2
  95. #define ANN_DEFAULT_RPROP_MAXUPDATE 50
  96. #define ANN_DEFAULT_RPROP_MINUPDATE 0.000001
  97. #define ANN_RPROP_INITIAL_DELTA 0.1
  98. #define ANN_DEFAULT_LEARN_RATE 0.1
  99. #define ANN_ALGO_BPROP 0
  100. #define ANN_ALGO_GD 1
  101. /* Misc */
  102. #define ANN_MAX(a,b) (((a)>(b))?(a):(b))
  103. #define ANN_MIN(a,b) (((a)<(b))?(a):(b))
  104. /* Prototypes */
  105. ann_float_t AnnTransferFunctionSigmoid(ann_float_t x);
  106. ann_float_t AnnTransferFunctionRelu(ann_float_t x);
  107. ann_float_t AnnTransferFunctionTanh(ann_float_t x);
  108. //ann_float_t AnnDerivativeIdentity(ann_float_t x);
  109. ann_float_t AnnDerivativeSigmoid(ann_float_t x);
  110. ann_float_t AnnDerivativeTanh(ann_float_t x);
  111. ann_float_t AnnDerivativeRelu(ann_float_t x);
  112. void AnnResetLayer(AnnLayer *layer);
  113. AnnRprop *AnnAlloc(int layers);
  114. void AnnFreeLayer(AnnLayer *layer);
  115. void AnnFree(AnnRprop *net);
  116. int AnnInitLayer(AnnRprop *net, int i, int units, int bias);
  117. AnnRprop *AnnCreateNet(int layers, int *units);
  118. AnnRprop *AnnCreateNet2(int iunits, int ounits);
  119. AnnRprop *AnnCreateNet3(int iunits, int hunits, int ounits);
  120. AnnRprop *AnnCreateNet4(int iunits, int hunits, int hunits2, int ounits);
  121. AnnRprop *AnnClone(const AnnRprop* net);
  122. size_t AnnCountWeights(AnnRprop *net);
  123. void AnnSimulate(AnnRprop *net);
  124. void Ann2Tcl(const AnnRprop *net);
  125. void Ann2Js(const AnnRprop *net);
  126. void AnnPrint(const AnnRprop *net);
  127. ann_float_t AnnGlobalError(AnnRprop *net, ann_float_t *desidered);
  128. void AnnSetInput(AnnRprop *net, ann_float_t *input);
  129. ann_float_t AnnSimulateError(AnnRprop *net, ann_float_t *input, ann_float_t *desidered);
  130. void AnnCalculateGradientsTrivial(AnnRprop *net, ann_float_t *desidered);
  131. void AnnCalculateGradients(AnnRprop *net, ann_float_t *desidered);
  132. void AnnSetDeltas(AnnRprop *net, ann_float_t val);
  133. void AnnResetDeltas(AnnRprop *net);
  134. void AnnResetSgradient(AnnRprop *net);
  135. void AnnSetRandomWeights(AnnRprop *net);
  136. void AnnScaleWeights(AnnRprop *net, ann_float_t factor);
  137. void AnnUpdateDeltasGD(AnnRprop *net);
  138. void AnnUpdateDeltasGDM(AnnRprop *net);
  139. void AnnUpdateSgradient(AnnRprop *net);
  140. void AnnAdjustWeights(AnnRprop *net, int setlen);
  141. ann_float_t AnnBatchGDEpoch(AnnRprop *net, ann_float_t *input, ann_float_t *desidered, int setlen);
  142. ann_float_t AnnBatchGDMEpoch(AnnRprop *net, ann_float_t *input, ann_float_t *desidered, int setlen);
  143. void AnnAdjustWeightsResilientBP(AnnRprop *net);
  144. ann_float_t AnnResilientBPEpoch(AnnRprop *net, ann_float_t *input, ann_float_t *desidered, int setlen);
  145. ann_float_t AnnTrainWithAlgoFunc(AnnRprop *net, ann_float_t *input, ann_float_t *desidered, ann_float_t maxerr, int maxepochs, int setlen, AnnTrainAlgoFunc algo_func);
  146. ann_float_t AnnTrain(AnnRprop *net, ann_float_t *input, ann_float_t *desidered, ann_float_t maxerr, int maxepochs, int setlen, int algo);
  147. void AnnTestError(AnnRprop *net, ann_float_t *input, ann_float_t *desired, int setlen, ann_float_t *avgerr, ann_float_t *classerr);
  148. #endif /* __SQNN_H */