SparseWeight8Tests.cs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using NUnit.Framework;
  6. namespace SharpGLTF.Transforms
  7. {
  8. [Category("Core.Transforms")]
  9. public class SparseWeight8Tests
  10. {
  11. [TestCase(0)]
  12. [TestCase(1)]
  13. [TestCase(0,0.0001f)]
  14. [TestCase(2, -2, 2, -2)]
  15. [TestCase(0.2f, 0.15f, 0.25f, 0.10f, 0.30f)]
  16. [TestCase(0, 0, 1, 0, 2, 0, 3, 4, 5, 0, 6, 0, 7, 0, 6, 0, 9, 0, 11)]
  17. [TestCase(9, -9, 8, -8, 7, -7, 6, -6, 5, -5, 4, -4, 3, -3, 2, -2, 1, -1)]
  18. [TestCase(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)]
  19. [TestCase(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)]
  20. public void TestSparseCreation(params float[] array1)
  21. {
  22. var array2 = CreateSparseCompatibleArray(array1);
  23. var array3 = array1
  24. .Select((val, idx) => (idx, val))
  25. .Where(item => item.val != 0)
  26. .Reverse()
  27. .ToArray();
  28. // creation mode 1
  29. var sparse = SparseWeight8.Create(array1);
  30. Assert.AreEqual(array2.Sum(), sparse.WeightSum);
  31. CollectionAssert.AreEqual(array2, sparse.Expand(array2.Length));
  32. // creation mode 2
  33. var indexedSparse = SparseWeight8.Create(array3);
  34. Assert.AreEqual(array2.Sum(), indexedSparse.WeightSum, 0.000001f);
  35. CollectionAssert.AreEqual(array2, indexedSparse.Expand(array2.Length));
  36. Assert.IsTrue(SparseWeight8.AreEqual(sparse, indexedSparse));
  37. // sort by weights
  38. var sByWeights = SparseWeight8.OrderedByWeight(sparse);
  39. Assert.AreEqual(array2.Sum(), sByWeights.WeightSum);
  40. CollectionAssert.AreEqual(array2, sByWeights.Expand(array2.Length));
  41. CheckWeightOrdered(sByWeights);
  42. // sort by indices
  43. var sByIndices = SparseWeight8.OrderedByIndex(sByWeights);
  44. Assert.AreEqual(array2.Sum(), sByIndices.WeightSum);
  45. CollectionAssert.AreEqual(array2, sByIndices.Expand(array2.Length));
  46. CheckIndexOrdered(sByWeights);
  47. // equality
  48. Assert.IsTrue(SparseWeight8.AreEqual(sByIndices, sByWeights));
  49. Assert.AreEqual(sByIndices.GetHashCode(), sByWeights.GetHashCode());
  50. // sum
  51. var sum = SparseWeight8.Add(sByIndices, sByWeights);
  52. Assert.AreEqual(array2.Sum() * 2, sum.WeightSum);
  53. // complement normalization
  54. if (!array2.Any(item => item<0))
  55. {
  56. Assert.GreaterOrEqual(sparse.GetNormalizedWithComplement(int.MaxValue).WeightSum, 1);
  57. }
  58. }
  59. [Test]
  60. public void TestSparseCreation()
  61. {
  62. var sparse = SparseWeight8.Create
  63. (
  64. (9, 9),
  65. (8, 2),
  66. (5, 1), // we set these weights separately
  67. (5, 1), // to check that 5 will pass 8
  68. (5, 1), // in the sorted set.
  69. (7, 1)
  70. );
  71. Assert.AreEqual(3, sparse[5]);
  72. Assert.AreEqual(1, sparse[7]);
  73. Assert.AreEqual(2, sparse[8]);
  74. Assert.AreEqual(9, sparse[9]);
  75. }
  76. [Test]
  77. public void TestCreateSparseFromVectors()
  78. {
  79. CollectionAssert.AreEqual
  80. (
  81. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(1, 1, 1, 1)).Expand(4),
  82. SparseWeight8.Create(1, 1, 1, 1).Expand(4)
  83. );
  84. CollectionAssert.AreEqual
  85. (
  86. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(1, 2, 3, 4)).Expand(4),
  87. SparseWeight8.Create(1, 2, 3, 4).Expand(4)
  88. );
  89. CollectionAssert.AreEqual
  90. (
  91. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(4, 3, 2, 1)).Expand(4),
  92. SparseWeight8.Create(4, 3, 2, 1).Expand(4)
  93. );
  94. CollectionAssert.AreEqual
  95. (
  96. SparseWeight8.Create(new System.Numerics.Vector4(0, 2, 2, 3), new System.Numerics.Vector4(4, 3, 2, 1)).Expand(4),
  97. SparseWeight8.Create(4, 0, 5, 1).Expand(4)
  98. );
  99. CollectionAssert.AreEqual
  100. (
  101. SparseWeight8.Create(new System.Numerics.Vector4(1, 1, 1, 1), new System.Numerics.Vector4(1, 1, 1, 1)).Expand(4),
  102. SparseWeight8.Create(0, 4, 0, 0).Expand(4)
  103. );
  104. }
  105. /// <summary>
  106. /// Creates a new array with only the 8 most relevant weights.
  107. /// </summary>
  108. /// <param name="array"></param>
  109. /// <returns></returns>
  110. static float[] CreateSparseCompatibleArray(params float[] array)
  111. {
  112. const int MAXWEIGHTS = 8;
  113. if (array == null) return null;
  114. var threshold =array
  115. .Select(item => Math.Abs(item))
  116. .OrderByDescending(item => item)
  117. .Take(MAXWEIGHTS)
  118. .Min();
  119. var array2 = new float[array.Length];
  120. var c = 0;
  121. for(int i=0; i < array2.Length; ++i)
  122. {
  123. var v = array[i];
  124. if (v == 0) continue;
  125. if (Math.Abs(v) >= threshold)
  126. {
  127. array2[i] = v;
  128. ++c;
  129. if (c >= MAXWEIGHTS) return array2;
  130. }
  131. }
  132. return array2;
  133. }
  134. static void CheckWeightOrdered(SparseWeight8 sparse)
  135. {
  136. Assert.GreaterOrEqual(Math.Abs(sparse.Weight0), Math.Abs(sparse.Weight1));
  137. Assert.GreaterOrEqual(Math.Abs(sparse.Weight1), Math.Abs(sparse.Weight2));
  138. Assert.GreaterOrEqual(Math.Abs(sparse.Weight2), Math.Abs(sparse.Weight3));
  139. Assert.GreaterOrEqual(Math.Abs(sparse.Weight3), Math.Abs(sparse.Weight4));
  140. Assert.GreaterOrEqual(Math.Abs(sparse.Weight4), Math.Abs(sparse.Weight5));
  141. Assert.GreaterOrEqual(Math.Abs(sparse.Weight5), Math.Abs(sparse.Weight6));
  142. Assert.GreaterOrEqual(Math.Abs(sparse.Weight6), Math.Abs(sparse.Weight7));
  143. }
  144. static void CheckIndexOrdered(SparseWeight8 sparse)
  145. {
  146. Assert.LessOrEqual(sparse.Index0, sparse.Index0);
  147. Assert.LessOrEqual(sparse.Index1, sparse.Index1);
  148. Assert.LessOrEqual(sparse.Index2, sparse.Index2);
  149. Assert.LessOrEqual(sparse.Index3, sparse.Index3);
  150. Assert.LessOrEqual(sparse.Index4, sparse.Index4);
  151. Assert.LessOrEqual(sparse.Index5, sparse.Index5);
  152. Assert.LessOrEqual(sparse.Index6, sparse.Index6);
  153. }
  154. [Test]
  155. public void TestSparseNormalization()
  156. {
  157. var sparse1 = SparseWeight8
  158. .Create(0, 0, 0, 0, 0, 0.1f, 0.7f, 0, 0, 0, 0.1f)
  159. .GetNormalizedWithComplement(int.MaxValue);
  160. Assert.AreEqual(0.1f, sparse1[5]);
  161. Assert.AreEqual(0.7f, sparse1[6]);
  162. Assert.AreEqual(0.1f, sparse1[10]);
  163. Assert.AreEqual(0.1f, sparse1[int.MaxValue], 0.0000001f);
  164. Assert.AreEqual(1, sparse1.WeightSum);
  165. }
  166. [Test]
  167. public void TestSparseEquality()
  168. {
  169. Assert.IsTrue(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(0, 1)));
  170. Assert.IsFalse(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(0, 1, 0.25f)));
  171. Assert.IsFalse(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(1, 0)));
  172. // check if two "half weights" are equal to one "full weight"
  173. //Assert.IsTrue(SparseWeight8.AreWeightsEqual(SparseWeight8.Create((3, 5), (3, 5)), SparseWeight8.Create((3, 10))));
  174. }
  175. [Test]
  176. public void TestSparseWeightsLinearInterpolation1()
  177. {
  178. var x = SparseWeight8.Create(0,0,1,2); CollectionAssert.AreEqual(new[] { 0f, 0f, 1f, 2f }, x.Expand(4));
  179. var y = SparseWeight8.Create(1,2,0,0); CollectionAssert.AreEqual(new[] { 1f, 2f, 0f, 0f }, y.Expand(4));
  180. var z = SparseWeight8.InterpolateLinear(x, y, 0.5f);
  181. Assert.AreEqual(0.5f, z[0]);
  182. Assert.AreEqual(1, z[1]);
  183. Assert.AreEqual(0.5f, z[2]);
  184. Assert.AreEqual(1, z[3]);
  185. }
  186. [Test]
  187. public void TestSparseWeightsLinearInterpolation2()
  188. {
  189. var ax = new float[] { 0, 0, 0, 0, 0, 0.1f, 0.7f, 0, 0, 0, 0.1f };
  190. var ay = new float[] { 0, 0, 0.2f, 0, 0.1f, 0, 0, 0, 0, 0, 0, 0, 0.2f };
  191. var cc = Math.Min(ax.Length, ay.Length);
  192. var x = SparseWeight8.Create(ax); CollectionAssert.AreEqual(ax, x.Expand(ax.Length));
  193. var y = SparseWeight8.Create(ay); CollectionAssert.AreEqual(ay, y.Expand(ay.Length));
  194. var z = SparseWeight8.InterpolateLinear(x, y, 0.5f);
  195. for (int i=0; i < cc; ++i)
  196. {
  197. var w = (ax[i] + ay[i]) / 2;
  198. Assert.AreEqual(w, z[i]);
  199. }
  200. }
  201. [Test]
  202. public void TestSparseWeightsCubicInterpolation()
  203. {
  204. var a = SparseWeight8.Create(0, 0, 0.2f, 0, 0, 0, 1);
  205. var b = SparseWeight8.Create(1, 1, 0.4f, 0, 0, 1, 0);
  206. var t = SparseWeight8.Subtract(b, a);
  207. Assert.AreEqual(1, t[0]);
  208. Assert.AreEqual(1, t[1]);
  209. Assert.AreEqual(0.2f, t[2]);
  210. Assert.AreEqual(0, t[3]);
  211. Assert.AreEqual(0, t[4]);
  212. Assert.AreEqual(1, t[5]);
  213. Assert.AreEqual(-1, t[6]);
  214. var lr = SparseWeight8.InterpolateLinear(a, b, 0.4f);
  215. var cr = SparseWeight8.InterpolateCubic(a, t, b, t, 0.4f);
  216. Assert.AreEqual(lr[0], cr[0], 0.000001f);
  217. Assert.AreEqual(lr[1], cr[1], 0.000001f);
  218. Assert.AreEqual(lr[2], cr[2], 0.000001f);
  219. Assert.AreEqual(lr[3], cr[3], 0.000001f);
  220. Assert.AreEqual(lr[4], cr[4], 0.000001f);
  221. Assert.AreEqual(lr[5], cr[5], 0.000001f);
  222. Assert.AreEqual(lr[6], cr[6], 0.000001f);
  223. Assert.AreEqual(lr[7], cr[7], 0.000001f);
  224. }
  225. [Test]
  226. public void TestSparseWeightReduction()
  227. {
  228. var a = SparseWeight8.Create(5, 3, 2, 4, 0, 4, 2, 6, 3, 6, 1);
  229. var b = a.GetTrimmed(4);
  230. Assert.AreEqual(4, b.GetNonZeroWeights().Count());
  231. Assert.AreEqual(a[0], b[0]);
  232. Assert.AreEqual(a[3], b[3]);
  233. Assert.AreEqual(a[7], b[7]);
  234. Assert.AreEqual(a[9], b[9]);
  235. Assert.AreEqual(0, b.Weight4);
  236. Assert.AreEqual(0, b.Weight5);
  237. Assert.AreEqual(0, b.Weight6);
  238. Assert.AreEqual(0, b.Weight7);
  239. }
  240. }
  241. }