SparseWeight8Tests.cs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using NUnit.Framework;
  6. namespace SharpGLTF.Transforms
  7. {
  8. [Category("Core.Transforms")]
  9. public class SparseWeight8Tests
  10. {
  11. [TestCase(0)]
  12. [TestCase(1)]
  13. [TestCase(0,0.0001f)]
  14. [TestCase(2, -2, 2, -2)]
  15. [TestCase(0.2f, 0.15f, 0.25f, 0.10f, 0.30f)]
  16. [TestCase(0, 0, 1, 0, 2, 0, 3, 4, 5, 0, 6, 0, 7, 0, 6, 0, 9, 0, 11)]
  17. [TestCase(9, -9, 8, -8, 7, -7, 6, -6, 5, -5, 4, -4, 3, -3, 2, -2, 1, -1)]
  18. [TestCase(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)]
  19. [TestCase(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)]
  20. public void TestSparseCreation(params float[] array1)
  21. {
  22. var array2 = CreateSparseCompatibleArray(array1);
  23. var array3 = array1
  24. .Select((val, idx) => (idx, val))
  25. .Where(item => item.val != 0)
  26. .Reverse()
  27. .ToArray();
  28. // creation mode 1
  29. var sparse = SparseWeight8.Create(array1);
  30. Assert.That(sparse.WeightSum, Is.EqualTo(array2.Sum()));
  31. Assert.That(sparse.Expand(array2.Length), Is.EqualTo(array2));
  32. // creation mode 2
  33. var indexedSparse = SparseWeight8.Create(array3);
  34. Assert.That(indexedSparse.WeightSum, Is.EqualTo(array2.Sum()).Within(0.000001f));
  35. Assert.That(indexedSparse.Expand(array2.Length), Is.EqualTo(array2));
  36. Assert.That(SparseWeight8.AreEqual(sparse, indexedSparse));
  37. // sort by weights
  38. var sByWeights = SparseWeight8.OrderedByWeight(sparse);
  39. Assert.That(sByWeights.WeightSum, Is.EqualTo(array2.Sum()));
  40. Assert.That(sByWeights.Expand(array2.Length), Is.EqualTo(array2));
  41. CheckWeightOrdered(sByWeights);
  42. // sort by indices
  43. var sByIndices = SparseWeight8.OrderedByIndex(sByWeights);
  44. CheckIndexOrdered(sByIndices);
  45. Assert.That(sByIndices.WeightSum, Is.EqualTo(array2.Sum()));
  46. Assert.That(sByIndices.Expand(array2.Length), Is.EqualTo(array2));
  47. // equality
  48. Assert.That(SparseWeight8.AreEqual(sByIndices, sByWeights), Is.True);
  49. Assert.That(sByWeights.GetHashCode(), Is.EqualTo(sByIndices.GetHashCode()));
  50. // sum
  51. var sum = SparseWeight8.Add(sByIndices, sByWeights);
  52. Assert.That(sum.WeightSum, Is.EqualTo(array2.Sum() * 2));
  53. // complement normalization
  54. if (!array2.Any(item => item<0))
  55. {
  56. Assert.That(sparse.GetNormalizedWithComplement(int.MaxValue).WeightSum, Is.GreaterThanOrEqualTo(1));
  57. }
  58. }
  59. [Test]
  60. public void TestSparseCreation()
  61. {
  62. var sparse = SparseWeight8.Create
  63. (
  64. (9, 9),
  65. (8, 2),
  66. (5, 1), // we set these weights separately
  67. (5, 1), // to check that 5 will pass 8
  68. (5, 1), // in the sorted set.
  69. (7, 1)
  70. );
  71. Assert.That(sparse[5], Is.EqualTo(3));
  72. Assert.That(sparse[7], Is.EqualTo(1));
  73. Assert.That(sparse[8], Is.EqualTo(2));
  74. Assert.That(sparse[9], Is.EqualTo(9));
  75. }
  76. [Test]
  77. public void TestCreateSparseFromVectors()
  78. {
  79. Assert.That
  80. (
  81. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(1, 1, 1, 1)).Expand(4),
  82. Is.EqualTo(SparseWeight8.Create(1, 1, 1, 1).Expand(4))
  83. );
  84. Assert.That
  85. (
  86. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(1, 2, 3, 4)).Expand(4),
  87. Is.EqualTo(SparseWeight8.Create(1, 2, 3, 4).Expand(4))
  88. );
  89. Assert.That
  90. (
  91. SparseWeight8.Create(new System.Numerics.Vector4(0, 1, 2, 3), new System.Numerics.Vector4(4, 3, 2, 1)).Expand(4),
  92. Is.EqualTo(SparseWeight8.Create(4, 3, 2, 1).Expand(4))
  93. );
  94. Assert.That
  95. (
  96. SparseWeight8.Create(new System.Numerics.Vector4(0, 2, 2, 3), new System.Numerics.Vector4(4, 3, 2, 1)).Expand(4),
  97. Is.EqualTo(SparseWeight8.Create(4, 0, 5, 1).Expand(4))
  98. );
  99. Assert.That
  100. (
  101. SparseWeight8.Create(new System.Numerics.Vector4(1, 1, 1, 1), new System.Numerics.Vector4(1, 1, 1, 1)).Expand(4),
  102. Is.EqualTo(SparseWeight8.Create(0, 4, 0, 0).Expand(4))
  103. );
  104. }
  105. /// <summary>
  106. /// Creates a new array with only the 8 most relevant weights.
  107. /// </summary>
  108. /// <param name="array"></param>
  109. /// <returns></returns>
  110. static float[] CreateSparseCompatibleArray(params float[] array)
  111. {
  112. const int MAXWEIGHTS = 8;
  113. if (array == null) return null;
  114. var threshold =array
  115. .Select(item => Math.Abs(item))
  116. .OrderByDescending(item => item)
  117. .Take(MAXWEIGHTS)
  118. .Min();
  119. var array2 = new float[array.Length];
  120. var c = 0;
  121. for(int i=0; i < array2.Length; ++i)
  122. {
  123. var v = array[i];
  124. if (v == 0) continue;
  125. if (Math.Abs(v) >= threshold)
  126. {
  127. array2[i] = v;
  128. ++c;
  129. if (c >= MAXWEIGHTS) return array2;
  130. }
  131. }
  132. return array2;
  133. }
  134. static void CheckWeightOrdered(SparseWeight8 sparse)
  135. {
  136. Assert.Multiple(() =>
  137. {
  138. Assert.That(Math.Abs(sparse.Weight0), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight1)));
  139. Assert.That(Math.Abs(sparse.Weight1), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight2)));
  140. Assert.That(Math.Abs(sparse.Weight2), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight3)));
  141. Assert.That(Math.Abs(sparse.Weight3), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight4)));
  142. Assert.That(Math.Abs(sparse.Weight4), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight5)));
  143. Assert.That(Math.Abs(sparse.Weight5), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight6)));
  144. Assert.That(Math.Abs(sparse.Weight6), Is.GreaterThanOrEqualTo(Math.Abs(sparse.Weight7)));
  145. });
  146. }
  147. static void CheckIndexOrdered(SparseWeight8 sparse)
  148. {
  149. var pairs = sparse.GetIndexedWeights();
  150. bool zeroFound = false;
  151. long lastIndex = long.MinValue;
  152. foreach(var (index,weight) in pairs)
  153. {
  154. if (weight == 0) zeroFound = true;
  155. if (zeroFound)
  156. {
  157. Assert.That(index, Is.EqualTo(0));
  158. Assert.That(weight, Is.EqualTo(0));
  159. continue;
  160. }
  161. Assert.That(index, Is.GreaterThan(lastIndex));
  162. lastIndex = index;
  163. }
  164. }
  165. [Test]
  166. public void TestSparseNormalization()
  167. {
  168. var sparse1 = SparseWeight8
  169. .Create(0, 0, 0, 0, 0, 0.1f, 0.7f, 0, 0, 0, 0.1f)
  170. .GetNormalizedWithComplement(int.MaxValue);
  171. Assert.That(sparse1[5], Is.EqualTo(0.1f));
  172. Assert.That(sparse1[6], Is.EqualTo(0.7f));
  173. Assert.That(sparse1[10], Is.EqualTo(0.1f));
  174. Assert.That(sparse1[int.MaxValue], Is.EqualTo(0.1f).Within(0.0000001f));
  175. Assert.That(sparse1.WeightSum, Is.EqualTo(1));
  176. }
  177. [Test]
  178. public void TestSparseEquality()
  179. {
  180. Assert.That(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(0, 1)), Is.True);
  181. Assert.That(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(0, 1, 0.25f)), Is.False);
  182. Assert.That(SparseWeight8.AreEqual(SparseWeight8.Create(0, 1), SparseWeight8.Create(1, 0)), Is.False);
  183. // check if two "half weights" are equal to one "full weight"
  184. //Assert.IsTrue(SparseWeight8.AreWeightsEqual(SparseWeight8.Create((3, 5), (3, 5)), SparseWeight8.Create((3, 10))));
  185. }
  186. [Test]
  187. public void TestSparseWeightsLinearInterpolation1()
  188. {
  189. var x = SparseWeight8.Create(0,0,1,2); Assert.That(x.Expand(4), Is.EqualTo(new[] { 0f, 0f, 1f, 2f }));
  190. var y = SparseWeight8.Create(1,2,0,0); Assert.That(y.Expand(4), Is.EqualTo(new[] { 1f, 2f, 0f, 0f }));
  191. var z = SparseWeight8.InterpolateLinear(x, y, 0.5f);
  192. Assert.That(z[0], Is.EqualTo(0.5f));
  193. Assert.That(z[1], Is.EqualTo(1));
  194. Assert.That(z[2], Is.EqualTo(0.5f));
  195. Assert.That(z[3], Is.EqualTo(1));
  196. }
  197. [Test]
  198. public void TestSparseWeightsLinearInterpolation2()
  199. {
  200. var ax = new float[] { 0, 0, 0, 0, 0, 0.1f, 0.7f, 0, 0, 0, 0.1f };
  201. var ay = new float[] { 0, 0, 0.2f, 0, 0.1f, 0, 0, 0, 0, 0, 0, 0, 0.2f };
  202. var cc = Math.Min(ax.Length, ay.Length);
  203. var x = SparseWeight8.Create(ax); Assert.That(x.Expand(ax.Length), Is.EqualTo(ax));
  204. var y = SparseWeight8.Create(ay); Assert.That(y.Expand(ay.Length), Is.EqualTo(ay));
  205. var z = SparseWeight8.InterpolateLinear(x, y, 0.5f);
  206. for (int i=0; i < cc; ++i)
  207. {
  208. var w = (ax[i] + ay[i]) / 2;
  209. Assert.That(z[i], Is.EqualTo(w));
  210. }
  211. }
  212. [Test]
  213. public void TestSparseWeightsCubicInterpolation()
  214. {
  215. var a = SparseWeight8.Create(0, 0, 0.2f, 0, 0, 0, 1);
  216. var b = SparseWeight8.Create(1, 1, 0.4f, 0, 0, 1, 0);
  217. var t = SparseWeight8.Subtract(b, a);
  218. Assert.That(t[0], Is.EqualTo(1));
  219. Assert.That(t[1], Is.EqualTo(1));
  220. Assert.That(t[2], Is.EqualTo(0.2f));
  221. Assert.That(t[3], Is.EqualTo(0));
  222. Assert.That(t[4], Is.EqualTo(0));
  223. Assert.That(t[5], Is.EqualTo(1));
  224. Assert.That(t[6], Is.EqualTo(-1));
  225. var lr = SparseWeight8.InterpolateLinear(a, b, 0.4f);
  226. var cr = SparseWeight8.InterpolateCubic(a, t, b, t, 0.4f);
  227. Assert.That(cr[0], Is.EqualTo(lr[0]).Within(0.000001f));
  228. Assert.That(cr[1], Is.EqualTo(lr[1]).Within(0.000001f));
  229. Assert.That(cr[2], Is.EqualTo(lr[2]).Within(0.000001f));
  230. Assert.That(cr[3], Is.EqualTo(lr[3]).Within(0.000001f));
  231. Assert.That(cr[4], Is.EqualTo(lr[4]).Within(0.000001f));
  232. Assert.That(cr[5], Is.EqualTo(lr[5]).Within(0.000001f));
  233. Assert.That(cr[6], Is.EqualTo(lr[6]).Within(0.000001f));
  234. Assert.That(cr[7], Is.EqualTo(lr[7]).Within(0.000001f));
  235. }
  236. [Test]
  237. public void TestSparseWeightReduction()
  238. {
  239. var a = SparseWeight8.Create(5, 3, 2, 4, 0, 4, 2, 6, 3, 6, 1);
  240. var b = a.GetTrimmed(4);
  241. Assert.That(b.GetNonZeroWeights().Count(), Is.EqualTo(4));
  242. Assert.That(b[0], Is.EqualTo(a[0]));
  243. Assert.That(b[3], Is.EqualTo(a[3]));
  244. Assert.That(b[7], Is.EqualTo(a[7]));
  245. Assert.That(b[9], Is.EqualTo(a[9]));
  246. Assert.That(b.Weight4, Is.EqualTo(0));
  247. Assert.That(b.Weight5, Is.EqualTo(0));
  248. Assert.That(b.Weight6, Is.EqualTo(0));
  249. Assert.That(b.Weight7, Is.EqualTo(0));
  250. }
  251. }
  252. }