basisu_ssim.cpp 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. // basisu_ssim.cpp
  2. // Copyright (C) 2019-2024 Binomial LLC. All Rights Reserved.
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. #include "basisu_ssim.h"
  16. #ifndef M_PI
  17. #define M_PI 3.14159265358979323846
  18. #endif
  19. namespace basisu
  20. {
  // Evaluates a Gaussian at the integer offset (x, y) from its center.
  // sigma_sqr is the variance (sigma squared). The result is
  // exp(-(x^2 + y^2) / (2 * sigma_sqr)) scaled by 1 / sqrt(2 * pi * sigma_sqr).
  float gauss(int x, int y, float sigma_sqr)
  {
    const int dist_sqr = x * x + y * y;

    // Exponential falloff with squared distance from the center.
    const float falloff = expf(-(static_cast<float>(dist_sqr) / (2.0f * sigma_sqr)));

    // Amplitude term; the product with M_PI is evaluated in double and then narrowed to float.
    const float amplitude = 1.0f / sqrtf(static_cast<float>(2.0f * M_PI * sigma_sqr));

    return amplitude * falloff;
  }
  27. // size_x/y should be odd
  28. void compute_gaussian_kernel(float *pDst, int size_x, int size_y, float sigma_sqr, uint32_t flags)
  29. {
  30. assert(size_x & size_y & 1);
  31. if (!(size_x | size_y))
  32. return;
  33. int mid_x = size_x / 2;
  34. int mid_y = size_y / 2;
  35. double sum = 0;
  36. for (int x = 0; x < size_x; x++)
  37. {
  38. for (int y = 0; y < size_y; y++)
  39. {
  40. float g;
  41. if ((x > mid_x) && (y < mid_y))
  42. g = pDst[(size_x - x - 1) + y * size_x];
  43. else if ((x < mid_x) && (y > mid_y))
  44. g = pDst[x + (size_y - y - 1) * size_x];
  45. else if ((x > mid_x) && (y > mid_y))
  46. g = pDst[(size_x - x - 1) + (size_y - y - 1) * size_x];
  47. else
  48. g = gauss(x - mid_x, y - mid_y, sigma_sqr);
  49. pDst[x + y * size_x] = g;
  50. sum += g;
  51. }
  52. }
  53. if (flags & cComputeGaussianFlagNormalizeCenterToOne)
  54. {
  55. sum = pDst[mid_x + mid_y * size_x];
  56. }
  57. if (flags & (cComputeGaussianFlagNormalizeCenterToOne | cComputeGaussianFlagNormalize))
  58. {
  59. double one_over_sum = 1.0f / sum;
  60. for (int i = 0; i < size_x * size_y; i++)
  61. pDst[i] = static_cast<float>(pDst[i] * one_over_sum);
  62. if (flags & cComputeGaussianFlagNormalizeCenterToOne)
  63. pDst[mid_x + mid_y * size_x] = 1.0f;
  64. }
  65. if (flags & cComputeGaussianFlagPrint)
  66. {
  67. printf("{\n");
  68. for (int y = 0; y < size_y; y++)
  69. {
  70. printf(" ");
  71. for (int x = 0; x < size_x; x++)
  72. {
  73. printf("%f, ", pDst[x + y * size_x]);
  74. }
  75. printf("\n");
  76. }
  77. printf("}");
  78. }
  79. }
  80. void gaussian_filter(imagef &dst, const imagef &orig_img, uint32_t odd_filter_width, float sigma_sqr, bool wrapping, uint32_t width_divisor, uint32_t height_divisor)
  81. {
  82. assert(&dst != &orig_img);
  83. assert(odd_filter_width && (odd_filter_width & 1));
  84. odd_filter_width |= 1;
  85. vector2D<float> kernel(odd_filter_width, odd_filter_width);
  86. compute_gaussian_kernel(kernel.get_ptr(), odd_filter_width, odd_filter_width, sigma_sqr, cComputeGaussianFlagNormalize);
  87. const int dst_width = orig_img.get_width() / width_divisor;
  88. const int dst_height = orig_img.get_height() / height_divisor;
  89. const int H = odd_filter_width / 2;
  90. const int L = -H;
  91. dst.crop(dst_width, dst_height);
  92. //#pragma omp parallel for
  93. for (int oy = 0; oy < dst_height; oy++)
  94. {
  95. for (int ox = 0; ox < dst_width; ox++)
  96. {
  97. vec4F c(0.0f);
  98. for (int yd = L; yd <= H; yd++)
  99. {
  100. int y = oy * height_divisor + (height_divisor >> 1) + yd;
  101. for (int xd = L; xd <= H; xd++)
  102. {
  103. int x = ox * width_divisor + (width_divisor >> 1) + xd;
  104. const vec4F &p = orig_img.get_clamped_or_wrapped(x, y, wrapping, wrapping);
  105. float w = kernel(xd + H, yd + H);
  106. c[0] += p[0] * w;
  107. c[1] += p[1] * w;
  108. c[2] += p[2] * w;
  109. c[3] += p[3] * w;
  110. }
  111. }
  112. dst(ox, oy).set(c[0], c[1], c[2], c[3]);
  113. }
  114. }
  115. }
  116. void pow_image(const imagef &src, imagef &dst, const vec4F &power)
  117. {
  118. dst.resize(src);
  119. //#pragma omp parallel for
  120. for (int y = 0; y < (int)dst.get_height(); y++)
  121. {
  122. for (uint32_t x = 0; x < dst.get_width(); x++)
  123. {
  124. const vec4F &p = src(x, y);
  125. if ((power[0] == 2.0f) && (power[1] == 2.0f) && (power[2] == 2.0f) && (power[3] == 2.0f))
  126. dst(x, y).set(p[0] * p[0], p[1] * p[1], p[2] * p[2], p[3] * p[3]);
  127. else
  128. dst(x, y).set(powf(p[0], power[0]), powf(p[1], power[1]), powf(p[2], power[2]), powf(p[3], power[3]));
  129. }
  130. }
  131. }
  132. void mul_image(const imagef &src, imagef &dst, const vec4F &mul)
  133. {
  134. dst.resize(src);
  135. //#pragma omp parallel for
  136. for (int y = 0; y < (int)dst.get_height(); y++)
  137. {
  138. for (uint32_t x = 0; x < dst.get_width(); x++)
  139. {
  140. const vec4F &p = src(x, y);
  141. dst(x, y).set(p[0] * mul[0], p[1] * mul[1], p[2] * mul[2], p[3] * mul[3]);
  142. }
  143. }
  144. }
  145. void scale_image(const imagef &src, imagef &dst, const vec4F &scale, const vec4F &shift)
  146. {
  147. dst.resize(src);
  148. //#pragma omp parallel for
  149. for (int y = 0; y < (int)dst.get_height(); y++)
  150. {
  151. for (uint32_t x = 0; x < dst.get_width(); x++)
  152. {
  153. const vec4F &p = src(x, y);
  154. vec4F d;
  155. for (uint32_t c = 0; c < 4; c++)
  156. d[c] = scale[c] * p[c] + shift[c];
  157. dst(x, y).set(d[0], d[1], d[2], d[3]);
  158. }
  159. }
  160. }
  161. void add_weighted_image(const imagef &src1, const vec4F &alpha, const imagef &src2, const vec4F &beta, const vec4F &gamma, imagef &dst)
  162. {
  163. dst.resize(src1);
  164. //#pragma omp parallel for
  165. for (int y = 0; y < (int)dst.get_height(); y++)
  166. {
  167. for (uint32_t x = 0; x < dst.get_width(); x++)
  168. {
  169. const vec4F &s1 = src1(x, y);
  170. const vec4F &s2 = src2(x, y);
  171. dst(x, y).set(
  172. s1[0] * alpha[0] + s2[0] * beta[0] + gamma[0],
  173. s1[1] * alpha[1] + s2[1] * beta[1] + gamma[1],
  174. s1[2] * alpha[2] + s2[2] * beta[2] + gamma[2],
  175. s1[3] * alpha[3] + s2[3] * beta[3] + gamma[3]);
  176. }
  177. }
  178. }
  179. void add_image(const imagef &src1, const imagef &src2, imagef &dst)
  180. {
  181. dst.resize(src1);
  182. //#pragma omp parallel for
  183. for (int y = 0; y < (int)dst.get_height(); y++)
  184. {
  185. for (uint32_t x = 0; x < dst.get_width(); x++)
  186. {
  187. const vec4F &s1 = src1(x, y);
  188. const vec4F &s2 = src2(x, y);
  189. dst(x, y).set(s1[0] + s2[0], s1[1] + s2[1], s1[2] + s2[2], s1[3] + s2[3]);
  190. }
  191. }
  192. }
  193. void adds_image(const imagef &src, const vec4F &value, imagef &dst)
  194. {
  195. dst.resize(src);
  196. //#pragma omp parallel for
  197. for (int y = 0; y < (int)dst.get_height(); y++)
  198. {
  199. for (uint32_t x = 0; x < dst.get_width(); x++)
  200. {
  201. const vec4F &p = src(x, y);
  202. dst(x, y).set(p[0] + value[0], p[1] + value[1], p[2] + value[2], p[3] + value[3]);
  203. }
  204. }
  205. }
  206. void mul_image(const imagef &src1, const imagef &src2, imagef &dst, const vec4F &scale)
  207. {
  208. dst.resize(src1);
  209. //#pragma omp parallel for
  210. for (int y = 0; y < (int)dst.get_height(); y++)
  211. {
  212. for (uint32_t x = 0; x < dst.get_width(); x++)
  213. {
  214. const vec4F &s1 = src1(x, y);
  215. const vec4F &s2 = src2(x, y);
  216. vec4F d;
  217. for (uint32_t c = 0; c < 4; c++)
  218. {
  219. float v1 = s1[c];
  220. float v2 = s2[c];
  221. d[c] = v1 * v2 * scale[c];
  222. }
  223. dst(x, y) = d;
  224. }
  225. }
  226. }
  227. void div_image(const imagef &src1, const imagef &src2, imagef &dst, const vec4F &scale)
  228. {
  229. dst.resize(src1);
  230. //#pragma omp parallel for
  231. for (int y = 0; y < (int)dst.get_height(); y++)
  232. {
  233. for (uint32_t x = 0; x < dst.get_width(); x++)
  234. {
  235. const vec4F &s1 = src1(x, y);
  236. const vec4F &s2 = src2(x, y);
  237. vec4F d;
  238. for (uint32_t c = 0; c < 4; c++)
  239. {
  240. float v = s2[c];
  241. if (v == 0.0f)
  242. d[c] = 0.0f;
  243. else
  244. d[c] = (s1[c] * scale[c]) / v;
  245. }
  246. dst(x, y) = d;
  247. }
  248. }
  249. }
  250. vec4F avg_image(const imagef &src)
  251. {
  252. vec4F avg(0.0f);
  253. for (uint32_t y = 0; y < src.get_height(); y++)
  254. {
  255. for (uint32_t x = 0; x < src.get_width(); x++)
  256. {
  257. const vec4F &s = src(x, y);
  258. avg += vec4F(s[0], s[1], s[2], s[3]);
  259. }
  260. }
  261. avg /= static_cast<float>(src.get_total_pixels());
  262. return avg;
  263. }
  264. // Reference: https://ece.uwaterloo.ca/~z70wang/research/ssim/index.html
  265. vec4F compute_ssim(const imagef &a, const imagef &b)
  266. {
  267. imagef axb, a_sq, b_sq, mu1, mu2, mu1_sq, mu2_sq, mu1_mu2, s1_sq, s2_sq, s12, smap, t1, t2, t3;
  268. const float C1 = 6.50250f, C2 = 58.52250f;
  269. pow_image(a, a_sq, vec4F(2));
  270. pow_image(b, b_sq, vec4F(2));
  271. mul_image(a, b, axb, vec4F(1.0f));
  272. gaussian_filter(mu1, a, 11, 1.5f * 1.5f);
  273. gaussian_filter(mu2, b, 11, 1.5f * 1.5f);
  274. pow_image(mu1, mu1_sq, vec4F(2));
  275. pow_image(mu2, mu2_sq, vec4F(2));
  276. mul_image(mu1, mu2, mu1_mu2, vec4F(1.0f));
  277. gaussian_filter(s1_sq, a_sq, 11, 1.5f * 1.5f);
  278. add_weighted_image(s1_sq, vec4F(1), mu1_sq, vec4F(-1), vec4F(0), s1_sq);
  279. gaussian_filter(s2_sq, b_sq, 11, 1.5f * 1.5f);
  280. add_weighted_image(s2_sq, vec4F(1), mu2_sq, vec4F(-1), vec4F(0), s2_sq);
  281. gaussian_filter(s12, axb, 11, 1.5f * 1.5f);
  282. add_weighted_image(s12, vec4F(1), mu1_mu2, vec4F(-1), vec4F(0), s12);
  283. scale_image(mu1_mu2, t1, vec4F(2), vec4F(0));
  284. adds_image(t1, vec4F(C1), t1);
  285. scale_image(s12, t2, vec4F(2), vec4F(0));
  286. adds_image(t2, vec4F(C2), t2);
  287. mul_image(t1, t2, t3, vec4F(1));
  288. add_image(mu1_sq, mu2_sq, t1);
  289. adds_image(t1, vec4F(C1), t1);
  290. add_image(s1_sq, s2_sq, t2);
  291. adds_image(t2, vec4F(C2), t2);
  292. mul_image(t1, t2, t1, vec4F(1));
  293. div_image(t3, t1, smap, vec4F(1));
  294. return avg_image(smap);
  295. }
  296. vec4F compute_ssim(const image &a, const image &b, bool luma, bool luma_601)
  297. {
  298. image ta(a), tb(b);
  299. if ((ta.get_width() != tb.get_width()) || (ta.get_height() != tb.get_height()))
  300. {
  301. debug_printf("compute_ssim: Cropping input images to equal dimensions\n");
  302. const uint32_t w = minimum(a.get_width(), b.get_width());
  303. const uint32_t h = minimum(a.get_height(), b.get_height());
  304. ta.crop(w, h);
  305. tb.crop(w, h);
  306. }
  307. if (!ta.get_width() || !ta.get_height())
  308. {
  309. assert(0);
  310. return vec4F(0);
  311. }
  312. if (luma)
  313. {
  314. for (uint32_t y = 0; y < ta.get_height(); y++)
  315. {
  316. for (uint32_t x = 0; x < ta.get_width(); x++)
  317. {
  318. ta(x, y).set(ta(x, y).get_luma(luma_601), ta(x, y).a);
  319. tb(x, y).set(tb(x, y).get_luma(luma_601), tb(x, y).a);
  320. }
  321. }
  322. }
  323. imagef fta, ftb;
  324. fta.set(ta);
  325. ftb.set(tb);
  326. return compute_ssim(fta, ftb);
  327. }
  328. } // namespace basisu