Activations.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Algorithms/Activations.h>
  9. #include <AzCore/Math/SimdMath.h>
  10. #include <AzCore/Math/MatrixMxN.h>
  11. namespace MachineLearning
  12. {
  13. AZStd::vector<AZ::Edit::EnumConstant<ActivationFunctions>> GetActivationEnumValues()
  14. {
  15. AZStd::vector<AZ::Edit::EnumConstant<ActivationFunctions>> values;
  16. values.emplace_back(ActivationFunctions::ReLU, "ReLU");
  17. values.emplace_back(ActivationFunctions::Sigmoid, "Sigmoid");
  18. values.emplace_back(ActivationFunctions::Softmax, "Softmax");
  19. values.emplace_back(ActivationFunctions::Linear, "Linear");
  20. return values;
  21. }
  22. void OneHotEncode(AZStd::size_t value, AZStd::size_t maxValue, AZ::VectorN& output)
  23. {
  24. AZ_Assert(value <= maxValue, "Requested one-hot encode of an out of range value");
  25. output.Resize(maxValue);
  26. output.SetZero();
  27. output.SetElement(value, 1.0f);
  28. }
  29. AZStd::size_t ArgMaxDecode(const AZ::VectorN& vector)
  30. {
  31. const AZStd::size_t numElements = vector.GetDimensionality();
  32. float maxValue = 0.0f;
  33. AZStd::size_t maxIndex = 0;
  34. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  35. {
  36. if (vector.GetElement(iter) > maxValue)
  37. {
  38. maxValue = vector.GetElement(iter);
  39. maxIndex = iter;
  40. }
  41. }
  42. return maxIndex;
  43. }
  44. void Activate(ActivationFunctions activationFunction, const AZ::VectorN& sourceVector, AZ::VectorN& output)
  45. {
  46. output.Resize(sourceVector.GetDimensionality());
  47. switch (activationFunction)
  48. {
  49. case ActivationFunctions::ReLU:
  50. ReLU(sourceVector, output);
  51. break;
  52. case ActivationFunctions::Sigmoid:
  53. Sigmoid(sourceVector, output);
  54. break;
  55. case ActivationFunctions::Softmax:
  56. Softmax(sourceVector, output);
  57. break;
  58. case ActivationFunctions::Linear:
  59. Linear(sourceVector, output);
  60. break;
  61. }
  62. }
  63. void ReLU(const AZ::VectorN& sourceVector, AZ::VectorN& output)
  64. {
  65. const AZStd::size_t numElements = sourceVector.GetVectorValues().size();
  66. const AZ::Simd::Vec4::FloatType zero = AZ::Simd::Vec4::ZeroFloat();
  67. output.Resize(sourceVector.GetDimensionality());
  68. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  69. {
  70. const AZ::Vector4& sourceElement = sourceVector.GetVectorValues()[iter];
  71. const AZ::Simd::Vec4::FloatType mask = AZ::Simd::Vec4::CmpGtEq(sourceElement.GetSimdValue(), zero); // 1's if >= 0, 0's otherwise
  72. AZ::Vector4& outputElement = output.GetVectorValues()[iter];
  73. outputElement.SetSimdValue(AZ::Simd::Vec4::And(sourceElement.GetSimdValue(), mask)); // Zeros out negative elements
  74. }
  75. output.FixLastVectorElement();
  76. }
  77. void Sigmoid(const AZ::VectorN& sourceVector, AZ::VectorN& output)
  78. {
  79. const AZ::Vector4 vecZero = AZ::Vector4::CreateZero();
  80. const AZ::Vector4 vecOne = AZ::Vector4::CreateOne();
  81. const AZ::Vector4 epsilon = AZ::Vector4(AZ::Constants::Tolerance);
  82. const AZStd::size_t numElements = sourceVector.GetVectorValues().size();
  83. output.Resize(sourceVector.GetDimensionality());
  84. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  85. {
  86. const AZ::Vector4& sourceElement = sourceVector.GetVectorValues()[iter];
  87. AZ::Vector4& outputElement = output.GetVectorValues()[iter];
  88. const AZ::Vector4 divisor = (vecOne + (-sourceElement).GetExpEstimate()).GetMax(epsilon);
  89. outputElement = vecOne / divisor;
  90. outputElement = outputElement.GetClamp(vecZero, vecOne);
  91. }
  92. output.FixLastVectorElement();
  93. }
  94. void Softmax(const AZ::VectorN& sourceVector, AZ::VectorN& output)
  95. {
  96. const AZ::Vector4 vecZero = AZ::Vector4::CreateZero();
  97. const AZ::Vector4 vecOne = AZ::Vector4::CreateOne();
  98. // Naive softmax is simply softmax(source) = exp(source) / sum(exp(source))
  99. // Here we apply the exp-normalization trick to avoid exp overflow
  100. // x = max(source)
  101. // y = exp(source - x)
  102. // softmax(source) = y / sum(y)
  103. const AZStd::size_t numElements = sourceVector.GetVectorValues().size();
  104. output.Resize(sourceVector.GetDimensionality());
  105. AZ::Vector4 max = sourceVector.GetVectorValues()[0];
  106. for (AZStd::size_t iter = 1; iter < numElements; ++iter)
  107. {
  108. max.GetMax(sourceVector.GetVectorValues()[iter]);
  109. }
  110. AZ::Vector4 partialSum = vecZero;
  111. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  112. {
  113. const AZ::Vector4& sourceElement = sourceVector.GetVectorValues()[iter];
  114. AZ::Vector4& outputElement = output.GetVectorValues()[iter];
  115. outputElement = (sourceElement - max).GetExpEstimate();
  116. outputElement = outputElement.GetClamp(vecZero, vecOne);
  117. partialSum += outputElement;
  118. }
  119. const float divisor = AZ::GetMax(1.0f / partialSum.Dot(vecOne), AZ::Constants::Tolerance);
  120. for (AZ::Vector4& element : output.GetVectorValues())
  121. {
  122. element = element * divisor;
  123. }
  124. output.FixLastVectorElement();
  125. }
  126. void Linear(const AZ::VectorN& sourceVector, AZ::VectorN& output)
  127. {
  128. if (&output != &sourceVector)
  129. {
  130. output = sourceVector;
  131. }
  132. }
  133. void Activate_Derivative(ActivationFunctions activationFunction, const AZ::VectorN& activationOutput, const AZ::VectorN& backGradients, AZ::VectorN& output)
  134. {
  135. output.Resize(activationOutput.GetDimensionality());
  136. switch (activationFunction)
  137. {
  138. case ActivationFunctions::ReLU:
  139. ReLU_Derivative(activationOutput, backGradients, output);
  140. break;
  141. case ActivationFunctions::Sigmoid:
  142. Sigmoid_Derivative(activationOutput, backGradients, output);
  143. break;
  144. case ActivationFunctions::Softmax:
  145. Softmax_Derivative(activationOutput, backGradients, output);
  146. break;
  147. case ActivationFunctions::Linear:
  148. Linear_Derivative(activationOutput, backGradients, output);
  149. break;
  150. }
  151. }
  152. void ReLU_Derivative(const AZ::VectorN& activationOutput, const AZ::VectorN& backGradients, AZ::VectorN& output)
  153. {
  154. const AZStd::size_t numElements = activationOutput.GetVectorValues().size();
  155. const AZ::Simd::Vec4::FloatType zero = AZ::Simd::Vec4::ZeroFloat();
  156. output.Resize(activationOutput.GetDimensionality());
  157. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  158. {
  159. const AZ::Vector4& activationElement = activationOutput.GetVectorValues()[iter];
  160. const AZ::Vector4& backGradientElement = backGradients.GetVectorValues()[iter];
  161. // 1's if > 0, 0's otherwise
  162. // Strictly greater than is required as any negative inputs in the original source vector will have been clamped to zero by activation
  163. const AZ::Simd::Vec4::FloatType mask = AZ::Simd::Vec4::CmpGt(activationElement.GetSimdValue(), zero);
  164. AZ::Vector4& outputElement = output.GetVectorValues()[iter];
  165. outputElement.SetSimdValue(AZ::Simd::Vec4::And(backGradientElement.GetSimdValue(), mask)); // Returns the backpropagated gradient if mask is non-zero, returns zero otherwise
  166. }
  167. output.FixLastVectorElement();
  168. }
  169. void Sigmoid_Derivative(const AZ::VectorN& activationOutput, const AZ::VectorN& backGradients, AZ::VectorN& output)
  170. {
  171. const AZStd::size_t numElements = activationOutput.GetVectorValues().size();
  172. const AZ::Vector4 vecOne = AZ::Vector4::CreateOne();
  173. output.Resize(activationOutput.GetDimensionality());
  174. for (AZStd::size_t iter = 0; iter < numElements; ++iter)
  175. {
  176. const AZ::Vector4& activationElement = activationOutput.GetVectorValues()[iter];
  177. const AZ::Vector4& backGradientElement = backGradients.GetVectorValues()[iter];
  178. AZ::Vector4& outputElement = output.GetVectorValues()[iter];
  179. outputElement = backGradientElement * activationElement * (vecOne - activationElement);
  180. }
  181. output.FixLastVectorElement();
  182. }
  183. void Softmax_Derivative(const AZ::VectorN& activationOutput, const AZ::VectorN& backGradients, AZ::VectorN& output)
  184. {
  185. // Note that this is completely unvectorized
  186. output.Resize(activationOutput.GetDimensionality());
  187. for (AZStd::size_t i = 0; i < activationOutput.GetDimensionality(); ++i)
  188. {
  189. float gradient = 0.0f;
  190. for (AZStd::size_t j = 0; j < activationOutput.GetDimensionality(); ++j)
  191. {
  192. const float ithElement = activationOutput.GetElement(i);
  193. const float jthElement = activationOutput.GetElement(j) * backGradients.GetElement(j);
  194. gradient += (i == j) ? (1.0f - ithElement) * jthElement : -ithElement * jthElement;
  195. }
  196. output.SetElement(i, gradient);
  197. }
  198. }
  199. void Linear_Derivative([[maybe_unused]] const AZ::VectorN& activationOutput, const AZ::VectorN& backGradients, AZ::VectorN& output)
  200. {
  201. output = backGradients;
  202. }
  203. }