vf_train.c 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #define stop() __debugbreak()
  5. #include <windows.h>
  6. #define int64 __int64
  7. #pragma warning(disable:4127)
  8. #define STBIR__WEIGHT_TABLES
  9. #define STBIR_PROFILE
  10. #define STB_IMAGE_RESIZE_IMPLEMENTATION
  11. #include "stb_image_resize2.h"
  /* Reads the whole of `filename` into a newly malloc'd int buffer.
     Element [0] holds the file size in bytes; the file's bytes start at element [1].
     The data area is rounded up to a whole number of ints and zero-filled, so a
     caller walking the contents int-by-int never reads past the allocation or
     into uninitialized memory.
     Returns 0 if the file can't be opened, sized, allocated or fully read, or if
     its size doesn't fit in the int size slot.  The caller owns the returned
     buffer (release it with free()). */
  static int * file_read( char const * filename )
  {
    long len;
    size_t size, padded;
    int * m = 0;
    FILE * f = fopen( filename, "rb" );
    if ( f == 0 ) return 0;

    if ( fseek( f, 0, SEEK_END ) != 0 ) goto done;
    len = ftell( f );
    // the size is stored in an int, so reject anything it can't represent
    if ( ( len < 0 ) || ( len > 0x7fffffffL ) ) goto done;
    if ( fseek( f, 0, SEEK_SET ) != 0 ) goto done;

    size = (size_t)len;
    padded = ( size + sizeof(int) - 1 ) / sizeof(int) * sizeof(int);
    m = calloc( 1, sizeof(int) + padded );
    if ( m == 0 ) goto done;
    m[0] = (int)size;
    if ( fread( m + 1, 1, size, f ) != size )
    {
      free( m );
      m = 0;
    }

  done:
    fclose( f );
    return m;
  }
  27. typedef struct fileinfo
  28. {
  29. int * timings;
  30. int timing_count;
  31. int dimensionx, dimensiony;
  32. int numtypes;
  33. int * types;
  34. int * effective;
  35. int cpu;
  36. int simd;
  37. int numinputrects;
  38. int * inputrects;
  39. int outputscalex, outputscaley;
  40. int milliseconds;
  41. int64 cycles;
  42. double scale_time;
  43. int bitmapx, bitmapy;
  44. char const * filename;
  45. } fileinfo;
  46. int numfileinfo;
  47. fileinfo fi[256];
  48. unsigned char * bitmap;
  49. int bitmapw, bitmaph, bitmapp;
  50. static int use_timing_file( char const * filename, int index )
  51. {
  52. int * base = file_read( filename );
  53. int * file = base;
  54. if ( base == 0 ) return 0;
  55. ++file; // skip file image size;
  56. if ( *file++ != 'VFT1' ) return 0;
  57. fi[index].cpu = *file++;
  58. fi[index].simd = *file++;
  59. fi[index].dimensionx = *file++;
  60. fi[index].dimensiony = *file++;
  61. fi[index].numtypes = *file++;
  62. fi[index].types = file; file += fi[index].numtypes;
  63. fi[index].effective = file; file += fi[index].numtypes;
  64. fi[index].numinputrects = *file++;
  65. fi[index].inputrects = file; file += fi[index].numinputrects * 2;
  66. fi[index].outputscalex = *file++;
  67. fi[index].outputscaley = *file++;
  68. fi[index].milliseconds = *file++;
  69. fi[index].cycles = ((int64*)file)[0]; file += 2;
  70. fi[index].filename = filename;
  71. fi[index].timings = file;
  72. fi[index].timing_count = (int) ( ( base[0] - ( ((char*)file - (char*)base - sizeof(int) ) ) ) / (sizeof(int)*2) );
  73. fi[index].scale_time = (double)fi[index].milliseconds / (double)fi[index].cycles;
  74. return 1;
  75. }
  76. static int vert_first( float weights_table[STBIR_RESIZE_CLASSIFICATIONS][4], int ox, int oy, int ix, int iy, int filter, STBIR__V_FIRST_INFO * v_info )
  77. {
  78. float h_scale=(float)ox/(float)(ix);
  79. float v_scale=(float)oy/(float)(iy);
  80. stbir__support_callback * support = stbir__builtin_supports[filter];
  81. int vertical_filter_width = stbir__get_filter_pixel_width(support,v_scale,0);
  82. int vertical_gather = ( v_scale >= ( 1.0f - stbir__small_float ) ) || ( vertical_filter_width <= STBIR_FORCE_GATHER_FILTER_SCANLINES_AMOUNT );
  83. return stbir__should_do_vertical_first( weights_table, stbir__get_filter_pixel_width(support,h_scale,0), h_scale, ox, vertical_filter_width, v_scale, oy, vertical_gather, v_info );
  84. }
  85. #define STB_IMAGE_WRITE_IMPLEMENTATION
  86. #include "stb_image_write.h"
  87. static void alloc_bitmap()
  88. {
  89. int findex;
  90. int x = 0, y = 0;
  91. int w = 0, h = 0;
  92. for( findex = 0 ; findex < numfileinfo ; findex++ )
  93. {
  94. int nx, ny;
  95. int thisw, thish;
  96. thisw = ( fi[findex].dimensionx * fi[findex].numtypes ) + ( fi[findex].numtypes - 1 );
  97. thish = ( fi[findex].dimensiony * fi[findex].numinputrects ) + ( fi[findex].numinputrects - 1 );
  98. for(;;)
  99. {
  100. nx = x + ((x)?4:0) + thisw;
  101. ny = y + ((y)?4:0) + thish;
  102. if ( ( nx <= 3600 ) || ( x == 0 ) )
  103. {
  104. fi[findex].bitmapx = x + ((x)?4:0);
  105. fi[findex].bitmapy = y + ((y)?4:0);
  106. x = nx;
  107. if ( x > w ) w = x;
  108. if ( ny > h ) h = ny;
  109. break;
  110. }
  111. else
  112. {
  113. x = 0;
  114. y = h;
  115. }
  116. }
  117. }
  118. w = (w+3) & ~3;
  119. bitmapw = w;
  120. bitmaph = h;
  121. bitmapp = w * 3; // RGB
  122. bitmap = malloc( bitmapp * bitmaph );
  123. memset( bitmap, 0, bitmapp * bitmaph );
  124. }
  125. static void build_bitmap( float weights[STBIR_RESIZE_CLASSIFICATIONS][4], int do_channel_count_index, int findex )
  126. {
  127. static int colors[STBIR_RESIZE_CLASSIFICATIONS];
  128. STBIR__V_FIRST_INFO v_info = {0};
  129. int * ts;
  130. int ir;
  131. unsigned char * bitm = bitmap + ( fi[findex].bitmapx*3 ) + ( fi[findex].bitmapy*bitmapp) ;
  132. for( ir = 0; ir < STBIR_RESIZE_CLASSIFICATIONS ; ir++ ) colors[ ir ] = 127*ir/STBIR_RESIZE_CLASSIFICATIONS+128;
  133. ts = fi[findex].timings;
  134. for( ir = 0 ; ir < fi[findex].numinputrects ; ir++ )
  135. {
  136. int ix, iy, chanind;
  137. ix = fi[findex].inputrects[ir*2];
  138. iy = fi[findex].inputrects[ir*2+1];
  139. for( chanind = 0 ; chanind < fi[findex].numtypes ; chanind++ )
  140. {
  141. int ofs, h, hh;
  142. // just do the type that we're on
  143. if ( chanind != do_channel_count_index )
  144. {
  145. ts += 2 * fi[findex].dimensionx * fi[findex].dimensiony;
  146. continue;
  147. }
  148. // bitmap offset
  149. ofs=chanind*(fi[findex].dimensionx+1)*3+ir*(fi[findex].dimensiony+1)*bitmapp;
  150. h = 1;
  151. for( hh = 0 ; hh < fi[findex].dimensiony; hh++ )
  152. {
  153. int ww, w = 1;
  154. for( ww = 0 ; ww < fi[findex].dimensionx; ww++ )
  155. {
  156. int good, v_first, VF, HF;
  157. VF = ts[0];
  158. HF = ts[1];
  159. v_first = vert_first( weights, w, h, ix, iy, STBIR_FILTER_MITCHELL, &v_info );
  160. good = ( ((HF<=VF) && (!v_first)) || ((VF<=HF) && (v_first)));
  161. if ( good )
  162. {
  163. bitm[ofs+2] = 0;
  164. bitm[ofs+1] = (unsigned char)colors[v_info.v_resize_classification];
  165. }
  166. else
  167. {
  168. double r;
  169. if ( HF < VF )
  170. r = (double)(VF-HF)/(double)HF;
  171. else
  172. r = (double)(HF-VF)/(double)VF;
  173. if ( r > 0.4f) r = 0.4;
  174. r *= 1.0f/0.4f;
  175. bitm[ofs+2] = (char)(255.0f*r);
  176. bitm[ofs+1] = (char)(((float)colors[v_info.v_resize_classification])*(1.0f-r));
  177. }
  178. bitm[ofs] = 0;
  179. ofs += 3;
  180. ts += 2;
  181. w += fi[findex].outputscalex;
  182. }
  183. ofs += bitmapp - fi[findex].dimensionx*3;
  184. h += fi[findex].outputscaley;
  185. }
  186. }
  187. }
  188. }
  189. static void build_comp_bitmap( float weights[STBIR_RESIZE_CLASSIFICATIONS][4], int do_channel_count_index )
  190. {
  191. int * ts0;
  192. int * ts1;
  193. int ir;
  194. unsigned char * bitm = bitmap + ( fi[0].bitmapx*3 ) + ( fi[0].bitmapy*bitmapp) ;
  195. ts0 = fi[0].timings;
  196. ts1 = fi[1].timings;
  197. for( ir = 0 ; ir < fi[0].numinputrects ; ir++ )
  198. {
  199. int ix, iy, chanind;
  200. ix = fi[0].inputrects[ir*2];
  201. iy = fi[0].inputrects[ir*2+1];
  202. for( chanind = 0 ; chanind < fi[0].numtypes ; chanind++ )
  203. {
  204. int ofs, h, hh;
  205. // just do the type that we're on
  206. if ( chanind != do_channel_count_index )
  207. {
  208. ts0 += 2 * fi[0].dimensionx * fi[0].dimensiony;
  209. ts1 += 2 * fi[0].dimensionx * fi[0].dimensiony;
  210. continue;
  211. }
  212. // bitmap offset
  213. ofs=chanind*(fi[0].dimensionx+1)*3+ir*(fi[0].dimensiony+1)*bitmapp;
  214. h = 1;
  215. for( hh = 0 ; hh < fi[0].dimensiony; hh++ )
  216. {
  217. int ww, w = 1;
  218. for( ww = 0 ; ww < fi[0].dimensionx; ww++ )
  219. {
  220. int v_first, time0, time1;
  221. v_first = vert_first( weights, w, h, ix, iy, STBIR_FILTER_MITCHELL, 0 );
  222. time0 = ( v_first ) ? ts0[0] : ts0[1];
  223. time1 = ( v_first ) ? ts1[0] : ts1[1];
  224. if ( time0 < time1 )
  225. {
  226. double r = (double)(time1-time0)/(double)time0;
  227. if ( r > 0.4f) r = 0.4;
  228. r *= 1.0f/0.4f;
  229. bitm[ofs+2] = 0;
  230. bitm[ofs+1] = (char)(255.0f*r);
  231. bitm[ofs] = (char)(64.0f*(1.0f-r));
  232. }
  233. else
  234. {
  235. double r = (double)(time0-time1)/(double)time1;
  236. if ( r > 0.4f) r = 0.4;
  237. r *= 1.0f/0.4f;
  238. bitm[ofs+2] = (char)(255.0f*r);
  239. bitm[ofs+1] = 0;
  240. bitm[ofs] = (char)(64.0f*(1.0f-r));
  241. }
  242. ofs += 3;
  243. ts0 += 2;
  244. ts1 += 2;
  245. w += fi[0].outputscalex;
  246. }
  247. ofs += bitmapp - fi[0].dimensionx*3;
  248. h += fi[0].outputscaley;
  249. }
  250. }
  251. }
  252. }
  253. static void write_bitmap()
  254. {
  255. stbi_write_png( "results.png", bitmapp / 3, bitmaph, 3|STB_IMAGE_BGR, bitmap, bitmapp );
  256. }
  257. static void calc_errors( float weights_table[STBIR_RESIZE_CLASSIFICATIONS][4], int * curtot, double * curerr, int do_channel_count_index )
  258. {
  259. int th, findex;
  260. STBIR__V_FIRST_INFO v_info = {0};
  261. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  262. {
  263. curerr[th]=0;
  264. curtot[th]=0;
  265. }
  266. for( findex = 0 ; findex < numfileinfo ; findex++ )
  267. {
  268. int * ts;
  269. int ir;
  270. ts = fi[findex].timings;
  271. for( ir = 0 ; ir < fi[findex].numinputrects ; ir++ )
  272. {
  273. int ix, iy, chanind;
  274. ix = fi[findex].inputrects[ir*2];
  275. iy = fi[findex].inputrects[ir*2+1];
  276. for( chanind = 0 ; chanind < fi[findex].numtypes ; chanind++ )
  277. {
  278. int h, hh;
  279. // just do the type that we're on
  280. if ( chanind != do_channel_count_index )
  281. {
  282. ts += 2 * fi[findex].dimensionx * fi[findex].dimensiony;
  283. continue;
  284. }
  285. h = 1;
  286. for( hh = 0 ; hh < fi[findex].dimensiony; hh++ )
  287. {
  288. int ww, w = 1;
  289. for( ww = 0 ; ww < fi[findex].dimensionx; ww++ )
  290. {
  291. int good, v_first, VF, HF;
  292. VF = ts[0];
  293. HF = ts[1];
  294. v_first = vert_first( weights_table, w, h, ix, iy, STBIR_FILTER_MITCHELL, &v_info );
  295. good = ( ((HF<=VF) && (!v_first)) || ((VF<=HF) && (v_first)));
  296. if ( !good )
  297. {
  298. double diff;
  299. if ( VF < HF )
  300. diff = ((double)HF-(double)VF) * fi[findex].scale_time;
  301. else
  302. diff = ((double)VF-(double)HF) * fi[findex].scale_time;
  303. curtot[v_info.v_resize_classification] += 1;
  304. curerr[v_info.v_resize_classification] += diff;
  305. }
  306. ts += 2;
  307. w += fi[findex].outputscalex;
  308. }
  309. h += fi[findex].outputscaley;
  310. }
  311. }
  312. }
  313. }
  314. }
  /* The weight search tries TRIESPERWEIGHT+1 evenly spaced values
     (0/TRIESPERWEIGHT .. TRIESPERWEIGHT/TRIESPERWEIGHT) for each of a
     classification's 4 weights.  A candidate is packed into one int as 4
     base-(TRIESPERWEIGHT+1) digits, weight 0 in the lowest digit.  MAXRANGE is
     the largest packed candidate. */
  #define TRIESPERWEIGHT 32
  #define MAXRANGE ((TRIESPERWEIGHT+1) * (TRIESPERWEIGHT+1) * (TRIESPERWEIGHT+1) * (TRIESPERWEIGHT+1) - 1)

  /* Returns digit `which` (0 = lowest) of packed candidate `range`: the
     numerator, out of TRIESPERWEIGHT, of weight number `which`. */
  static int packed_weight_digit( int range, int which )
  {
    for ( ; which > 0 ; which-- )
      range /= ( TRIESPERWEIGHT + 1 );
    return range % ( TRIESPERWEIGHT + 1 );
  }

  /* Unpacks candidate `range` into 4 floats in [0,1], stored in weights[0..3]. */
  static void expand_to_floats( float * weights, int range )
  {
    int k;
    for ( k = 0 ; k < 4 ; k++ )
      weights[ k ] = (float)packed_weight_digit( range, k ) / (float)TRIESPERWEIGHT;
  }

  /* Formats candidate `range` as "[ a/N b/N c/N d/N ]" (N = TRIESPERWEIGHT) for
     display.  Returns a static buffer that the next call overwrites.  The output
     is at most about 30 characters, well inside the buffer. */
  static char const * expand_to_string( int range )
  {
    static char str[128];
    sprintf( str, "[ %2d/%d %2d/%d %2d/%d %2d/%d ]",
             packed_weight_digit( range, 0 ), TRIESPERWEIGHT,
             packed_weight_digit( range, 1 ), TRIESPERWEIGHT,
             packed_weight_digit( range, 2 ), TRIESPERWEIGHT,
             packed_weight_digit( range, 3 ), TRIESPERWEIGHT );
    return str;
  }
  335. static void print_weights( float weights[STBIR_RESIZE_CLASSIFICATIONS][4], int channel_count_index, int * tots, double * errs )
  336. {
  337. int th;
  338. printf("ChInd: %d Weights:\n",channel_count_index);
  339. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  340. {
  341. float * w = weights[th];
  342. printf(" %d: [%1.5f %1.5f %1.5f %1.5f] (%d %.4f)\n",th, w[0], w[1], w[2], w[3], tots[th], errs[th] );
  343. }
  344. printf("\n");
  345. }
  346. static int windowranges[ 16 ];
  347. static int windowstatus = 0;
  348. static DWORD trainstart = 0;
  349. static void opt_channel( float best_output_weights[STBIR_RESIZE_CLASSIFICATIONS][4], int channel_count_index )
  350. {
  351. int newbest = 0;
  352. float weights[STBIR_RESIZE_CLASSIFICATIONS][4] = {0};
  353. double besterr[STBIR_RESIZE_CLASSIFICATIONS];
  354. int besttot[STBIR_RESIZE_CLASSIFICATIONS];
  355. int best[STBIR_RESIZE_CLASSIFICATIONS]={0};
  356. double curerr[STBIR_RESIZE_CLASSIFICATIONS];
  357. int curtot[STBIR_RESIZE_CLASSIFICATIONS];
  358. int th, range;
  359. DWORD lasttick = 0;
  360. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  361. {
  362. besterr[th]=1000000000000.0;
  363. besttot[th]=0x7fffffff;
  364. }
  365. newbest = 0;
  366. // try the whole range
  367. range = MAXRANGE;
  368. do
  369. {
  370. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  371. expand_to_floats( weights[th], range );
  372. calc_errors( weights, curtot, curerr, channel_count_index );
  373. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  374. {
  375. if ( curerr[th] < besterr[th] )
  376. {
  377. besterr[th] = curerr[th];
  378. besttot[th] = curtot[th];
  379. best[th] = range;
  380. expand_to_floats( best_output_weights[th], best[th] );
  381. newbest = 1;
  382. }
  383. }
  384. {
  385. DWORD t = GetTickCount();
  386. if ( range == 0 )
  387. goto do_bitmap;
  388. if ( newbest )
  389. {
  390. if ( ( GetTickCount() - lasttick ) > 200 )
  391. {
  392. int findex;
  393. do_bitmap:
  394. lasttick = t;
  395. newbest = 0;
  396. for( findex = 0 ; findex < numfileinfo ; findex++ )
  397. build_bitmap( best_output_weights, channel_count_index, findex );
  398. lasttick = GetTickCount();
  399. }
  400. }
  401. }
  402. windowranges[ channel_count_index ] = range;
  403. // advance all the weights and loop
  404. --range;
  405. } while( ( range >= 0 ) && ( !windowstatus ) );
  406. // if we hit here, then we tried all weights for this opt, so save them
  407. }
  408. static void print_struct( float weight[5][STBIR_RESIZE_CLASSIFICATIONS][4], char const * name )
  409. {
  410. printf("\n\nstatic float %s[5][STBIR_RESIZE_CLASSIFICATIONS][4]=\n{", name );
  411. {
  412. int i;
  413. for(i=0;i<5;i++)
  414. {
  415. int th;
  416. for(th=0;th<STBIR_RESIZE_CLASSIFICATIONS;th++)
  417. {
  418. int j;
  419. printf("\n ");
  420. for(j=0;j<4;j++)
  421. printf("%1.5ff, ", weight[i][th][j] );
  422. }
  423. printf("\n");
  424. }
  425. printf("\n};\n");
  426. }
  427. }
  428. static float retrain_weights[5][STBIR_RESIZE_CLASSIFICATIONS][4];
  429. static DWORD __stdcall retrain_shim( LPVOID p )
  430. {
  431. int chanind = (int) (size_t)p;
  432. opt_channel( retrain_weights[chanind], chanind );
  433. return 0;
  434. }
  /* Formats a duration in milliseconds as "Ns", "Mm Ns" (from one minute
     up) or "Hh Mm Ns" (from one hour up).  Negative durations show as "0s".
     Returns a static buffer that the next call overwrites. */
  static char const * gettime( int ms )
  {
    static char text[32];
    int secs, mins, hours;
    if ( ms < 0 ) ms = 0;
    secs = ( ms / 1000 ) % 60;
    mins = ( ms / 60000 ) % 60;
    hours = ms / 3600000;
    if ( hours > 0 )
      sprintf( text, "%dh %dm %ds", hours, mins, secs );
    else if ( ms >= 60000 ) // exactly one minute is "1m 0s", not "60s"
      sprintf( text, "%dm %ds", mins, secs );
    else
      sprintf( text, "%ds", secs );
    return text;
  }
  444. static BITMAPINFOHEADER bmiHeader;
  445. static DWORD extrawindoww, extrawindowh;
  446. static HINSTANCE instance;
  447. static int curzoom = 1;
  448. static LRESULT WINAPI WindowProc( HWND window,
  449. UINT message,
  450. WPARAM wparam,
  451. LPARAM lparam )
  452. {
  453. switch( message )
  454. {
  455. case WM_CHAR:
  456. if ( wparam != 27 )
  457. break;
  458. // falls through
  459. case WM_CLOSE:
  460. {
  461. int i;
  462. int max = 0;
  463. for( i = 0 ; i < fi[0].numtypes ; i++ )
  464. if( windowranges[i] > max ) max = windowranges[i];
  465. if ( ( max == 0 ) || ( MessageBox( window, "Cancel before training is finished?", "Vertical First Training", MB_OKCANCEL|MB_ICONSTOP ) == IDOK ) )
  466. {
  467. for( i = 0 ; i < fi[0].numtypes ; i++ )
  468. if( windowranges[i] > max ) max = windowranges[i];
  469. if ( max )
  470. windowstatus = 1;
  471. DestroyWindow( window );
  472. }
  473. }
  474. return 0;
  475. case WM_PAINT:
  476. {
  477. PAINTSTRUCT ps;
  478. HDC dc;
  479. dc = BeginPaint( window, &ps );
  480. StretchDIBits( dc,
  481. 0, 0, bitmapw*curzoom, bitmaph*curzoom,
  482. 0, 0, bitmapw, bitmaph,
  483. bitmap, (BITMAPINFO*)&bmiHeader, DIB_RGB_COLORS, SRCCOPY );
  484. PatBlt( dc, bitmapw*curzoom, 0, 4096, 4096, WHITENESS );
  485. PatBlt( dc, 0, bitmaph*curzoom, 4096, 4096, WHITENESS );
  486. SetTextColor( dc, RGB(0,0,0) );
  487. SetBkColor( dc, RGB(255,255,255) );
  488. SetBkMode( dc, OPAQUE );
  489. {
  490. int i, l = 0, max = 0;
  491. char buf[1024];
  492. RECT rc;
  493. POINT p;
  494. for( i = 0 ; i < fi[0].numtypes ; i++ )
  495. {
  496. l += sprintf( buf + l, "channels: %d %s\n", fi[0].effective[i], windowranges[i] ? expand_to_string( windowranges[i] ) : "Done." );
  497. if ( windowranges[i] > max ) max = windowranges[i];
  498. }
  499. rc.left = 32; rc.top = bitmaph*curzoom+10;
  500. rc.right = 512; rc.bottom = rc.top + 512;
  501. DrawText( dc, buf, -1, &rc, DT_TOP );
  502. l = 0;
  503. if ( max == 0 )
  504. {
  505. static DWORD traindone = 0;
  506. if ( traindone == 0 ) traindone = GetTickCount();
  507. l = sprintf( buf, "Finished in %s.", gettime( traindone - trainstart ) );
  508. }
  509. else if ( max != MAXRANGE )
  510. l = sprintf( buf, "Done in %s...", gettime( (int) ( ( ( (int64)max * ( (int64)GetTickCount() - (int64)trainstart ) ) ) / (int64) ( MAXRANGE - max ) ) ) );
  511. GetCursorPos( &p );
  512. ScreenToClient( window, &p );
  513. if ( ( p.x >= 0 ) && ( p.y >= 0 ) && ( p.x < (bitmapw*curzoom) ) && ( p.y < (bitmaph*curzoom) ) )
  514. {
  515. int findex;
  516. int x, y, w, h, sx, sy, ix, iy, ox, oy;
  517. int ir, chanind;
  518. int * ts;
  519. char badstr[64];
  520. STBIR__V_FIRST_INFO v_info={0};
  521. p.x /= curzoom;
  522. p.y /= curzoom;
  523. for( findex = 0 ; findex < numfileinfo ; findex++ )
  524. {
  525. x = fi[findex].bitmapx;
  526. y = fi[findex].bitmapy;
  527. w = x + ( fi[findex].dimensionx + 1 ) * fi[findex].numtypes;
  528. h = y + ( fi[findex].dimensiony + 1 ) * fi[findex].numinputrects;
  529. if ( ( p.x >= x ) && ( p.y >= y ) && ( p.x < w ) && ( p.y < h ) )
  530. goto found;
  531. }
  532. goto nope;
  533. found:
  534. ir = ( p.y - y ) / ( fi[findex].dimensiony + 1 );
  535. sy = ( p.y - y ) % ( fi[findex].dimensiony + 1 );
  536. if ( sy >= fi[findex].dimensiony ) goto nope;
  537. chanind = ( p.x - x ) / ( fi[findex].dimensionx + 1 );
  538. sx = ( p.x - x ) % ( fi[findex].dimensionx + 1 );
  539. if ( sx >= fi[findex].dimensionx ) goto nope;
  540. ix = fi[findex].inputrects[ir*2];
  541. iy = fi[findex].inputrects[ir*2+1];
  542. ts = fi[findex].timings + ( ( fi[findex].dimensionx * fi[findex].dimensiony * fi[findex].numtypes * ir ) + ( fi[findex].dimensionx * fi[findex].dimensiony * chanind ) + ( fi[findex].dimensionx * sy ) + sx ) * 2;
  543. ox = 1+fi[findex].outputscalex*sx;
  544. oy = 1+fi[findex].outputscaley*sy;
  545. if ( windowstatus != 2 )
  546. {
  547. int VF, HF, v_first, good;
  548. VF = ts[0];
  549. HF = ts[1];
  550. v_first = vert_first( retrain_weights[chanind], ox, oy, ix, iy, STBIR_FILTER_MITCHELL, &v_info );
  551. good = ( ((HF<=VF) && (!v_first)) || ((VF<=HF) && (v_first)));
  552. if ( good )
  553. badstr[0] = 0;
  554. else
  555. {
  556. double r;
  557. if ( HF < VF )
  558. r = (double)(VF-HF)/(double)HF;
  559. else
  560. r = (double)(HF-VF)/(double)VF;
  561. sprintf( badstr, " %.1f%% off", r*100 );
  562. }
  563. sprintf( buf + l, "\n\n%s\nCh: %d Resize: %dx%d to %dx%d\nV: %d H: %d Order: %c (%s%s)\nClass: %d Scale: %.2f %s", fi[findex].filename,fi[findex].effective[chanind], ix,iy,ox,oy, VF, HF, v_first?'V':'H', good?"Good":"Wrong", badstr, v_info.v_resize_classification, (double)oy/(double)iy, v_info.is_gather ? "Gather" : "Scatter" );
  564. }
  565. else
  566. {
  567. int v_first, time0, time1;
  568. float (* weights)[4] = stbir__compute_weights[chanind];
  569. int * ts1;
  570. char b0[32], b1[32];
  571. ts1 = fi[1].timings + ( ts - fi[0].timings );
  572. v_first = vert_first( weights, ox, oy, ix, iy, STBIR_FILTER_MITCHELL, &v_info );
  573. time0 = ( v_first ) ? ts[0] : ts[1];
  574. time1 = ( v_first ) ? ts1[0] : ts1[1];
  575. b0[0] = b1[0] = 0;
  576. if ( time0 < time1 )
  577. sprintf( b0," (%.f%% better)", ((double)time1-(double)time0)*100.0f/(double)time0);
  578. else
  579. sprintf( b1," (%.f%% better)", ((double)time0-(double)time1)*100.0f/(double)time1);
  580. sprintf( buf + l, "\n\n0: %s\n1: %s\nCh: %d Resize: %dx%d to %dx%d\nClass: %d Scale: %.2f %s\nTime0: %d%s\nTime1: %d%s", fi[0].filename, fi[1].filename, fi[0].effective[chanind], ix,iy,ox,oy, v_info.v_resize_classification, (double)oy/(double)iy, v_info.is_gather ? "Gather" : "Scatter", time0, b0, time1, b1 );
  581. }
  582. }
  583. nope:
  584. rc.left = 32+320; rc.right = 512+320;
  585. SetTextColor( dc, RGB(0,0,128) );
  586. DrawText( dc, buf, -1, &rc, DT_TOP );
  587. }
  588. EndPaint( window, &ps );
  589. return 0;
  590. }
  591. case WM_TIMER:
  592. InvalidateRect( window, 0, 0 );
  593. return 0;
  594. case WM_DESTROY:
  595. PostQuitMessage( 0 );
  596. return 0;
  597. }
  598. return DefWindowProc( window, message, wparam, lparam );
  599. }
  600. static void SetHighDPI(void)
  601. {
  602. typedef HRESULT WINAPI setdpitype(int v);
  603. HMODULE h=LoadLibrary("Shcore.dll");
  604. if (h)
  605. {
  606. setdpitype * sd = (setdpitype*)GetProcAddress(h,"SetProcessDpiAwareness");
  607. if (sd )
  608. sd(1);
  609. }
  610. }
  611. static void draw_window()
  612. {
  613. WNDCLASS wc;
  614. HWND w;
  615. MSG msg;
  616. instance = GetModuleHandle(NULL);
  617. wc.style = 0;
  618. wc.lpfnWndProc = WindowProc;
  619. wc.cbClsExtra = 0;
  620. wc.cbWndExtra = 0;
  621. wc.hInstance = instance;
  622. wc.hIcon = 0;
  623. wc.hCursor = LoadCursor(NULL, IDC_ARROW);
  624. wc.hbrBackground = 0;
  625. wc.lpszMenuName = 0;
  626. wc.lpszClassName = "WHTrain";
  627. if ( !RegisterClass( &wc ) )
  628. exit(1);
  629. SetHighDPI();
  630. bmiHeader.biSize = sizeof(BITMAPINFOHEADER);
  631. bmiHeader.biWidth = bitmapp/3;
  632. bmiHeader.biHeight = -bitmaph;
  633. bmiHeader.biPlanes = 1;
  634. bmiHeader.biBitCount = 24;
  635. bmiHeader.biCompression = BI_RGB;
  636. w = CreateWindow( "WHTrain",
  637. "Vertical First Training",
  638. WS_CAPTION | WS_POPUP| WS_CLIPCHILDREN |
  639. WS_SYSMENU | WS_MINIMIZEBOX | WS_SIZEBOX,
  640. CW_USEDEFAULT,CW_USEDEFAULT,
  641. CW_USEDEFAULT,CW_USEDEFAULT,
  642. 0, 0, instance, 0 );
  643. {
  644. RECT r, c;
  645. GetWindowRect( w, &r );
  646. GetClientRect( w, &c );
  647. extrawindoww = ( r.right - r.left ) - ( c.right - c.left );
  648. extrawindowh = ( r.bottom - r.top ) - ( c.bottom - c.top );
  649. SetWindowPos( w, 0, 0, 0, bitmapw * curzoom + extrawindoww, bitmaph * curzoom + extrawindowh + 164, SWP_NOMOVE );
  650. }
  651. ShowWindow( w, SW_SHOWNORMAL );
  652. SetTimer( w, 1, 250, 0 );
  653. {
  654. BOOL ret;
  655. while( ( ret = GetMessage( &msg, w, 0, 0 ) ) != 0 )
  656. {
  657. if ( ret == -1 )
  658. break;
  659. TranslateMessage( &msg );
  660. DispatchMessage( &msg );
  661. }
  662. }
  663. }
  664. static void retrain()
  665. {
  666. HANDLE threads[ 16 ];
  667. int chanind;
  668. trainstart = GetTickCount();
  669. for( chanind = 0 ; chanind < fi[0].numtypes ; chanind++ )
  670. threads[ chanind ] = CreateThread( 0, 2048*1024, retrain_shim, (LPVOID)(size_t)chanind, 0, 0 );
  671. draw_window();
  672. for( chanind = 0 ; chanind < fi[0].numtypes ; chanind++ )
  673. {
  674. WaitForSingleObject( threads[ chanind ], INFINITE );
  675. CloseHandle( threads[ chanind ] );
  676. }
  677. write_bitmap();
  678. print_struct( retrain_weights, "retained_weights" );
  679. if ( windowstatus ) printf( "CANCELLED!\n" );
  680. }
  681. static void info()
  682. {
  683. int findex;
  684. // display info about each input file
  685. for( findex = 0 ; findex < numfileinfo ; findex++ )
  686. {
  687. int i, h,m,s;
  688. if ( findex ) printf( "\n" );
  689. printf( "Timing file: %s\n", fi[findex].filename );
  690. printf( "CPU type: %d %s\n", fi[findex].cpu, fi[findex].simd?(fi[findex].simd==2?"SIMD8":"SIMD4"):"Scalar" );
  691. h = fi[findex].milliseconds/3600000;
  692. m = (fi[findex].milliseconds-h*3600000)/60000;
  693. s = (fi[findex].milliseconds-h*3600000-m*60000)/1000;
  694. printf( "Total time in test: %dh %dm %ds Cycles/sec: %.f\n", h,m,s, 1000.0/fi[findex].scale_time );
  695. printf( "Each tile of samples is %dx%d, and is scaled by %dx%d.\n", fi[findex].dimensionx,fi[findex].dimensiony, fi[findex].outputscalex,fi[findex].outputscaley );
  696. printf( "So the x coords are: " );
  697. for( i=0; i < fi[findex].dimensionx ; i++ ) printf( "%d ",1+i*fi[findex].outputscalex );
  698. printf( "\n" );
  699. printf( "And the y coords are: " );
  700. for( i=0; i < fi[findex].dimensiony ; i++ ) printf( "%d ",1+i*fi[findex].outputscaley );
  701. printf( "\n" );
  702. printf( "There are %d channel counts and they are: ", fi[findex].numtypes );
  703. for( i=0; i < fi[findex].numtypes ; i++ ) printf( "%d ",fi[findex].effective[i] );
  704. printf( "\n" );
  705. printf( "There are %d input rect sizes and they are: ", fi[findex].numinputrects );
  706. for( i=0; i < fi[findex].numtypes ; i++ ) printf( "%dx%d ",fi[findex].inputrects[i*2],fi[findex].inputrects[i*2+1] );
  707. printf( "\n" );
  708. }
  709. }
  710. static void current( int do_win, int do_bitmap )
  711. {
  712. int i, findex;
  713. trainstart = GetTickCount();
  714. // clear progress
  715. memset( windowranges, 0, sizeof( windowranges ) );
  716. // copy in appropriate weights
  717. memcpy( retrain_weights, stbir__compute_weights, sizeof( retrain_weights ) );
  718. // build and print current errors and build current bitmap
  719. for( i = 0 ; i < fi[0].numtypes ; i++ )
  720. {
  721. double curerr[STBIR_RESIZE_CLASSIFICATIONS];
  722. int curtot[STBIR_RESIZE_CLASSIFICATIONS];
  723. float (* weights)[4] = retrain_weights[i];
  724. calc_errors( weights, curtot, curerr, i );
  725. if ( !do_bitmap )
  726. print_weights( weights, i, curtot, curerr );
  727. for( findex = 0 ; findex < numfileinfo ; findex++ )
  728. build_bitmap( weights, i, findex );
  729. }
  730. if ( do_win )
  731. draw_window();
  732. if ( do_bitmap )
  733. write_bitmap();
  734. }
  735. static void compare()
  736. {
  737. int i;
  738. trainstart = GetTickCount();
  739. windowstatus = 2; // comp mode
  740. // clear progress
  741. memset( windowranges, 0, sizeof( windowranges ) );
  742. if ( ( fi[0].numtypes != fi[1].numtypes ) || ( fi[0].numinputrects != fi[1].numinputrects ) ||
  743. ( fi[0].dimensionx != fi[1].dimensionx ) || ( fi[0].dimensiony != fi[1].dimensiony ) ||
  744. ( fi[0].outputscalex != fi[1].outputscalex ) || ( fi[0].outputscaley != fi[1].outputscaley ) )
  745. {
  746. err:
  747. printf( "Timing files don't match.\n" );
  748. exit(5);
  749. }
  750. for( i=0; i < fi[0].numtypes ; i++ )
  751. {
  752. if ( fi[0].effective[i] != fi[1].effective[i] ) goto err;
  753. if ( fi[0].inputrects[i*2] != fi[1].inputrects[i*2] ) goto err;
  754. if ( fi[0].inputrects[i*2+1] != fi[1].inputrects[i*2+1] ) goto err;
  755. }
  756. alloc_bitmap( 1 );
  757. for( i = 0 ; i < fi[0].numtypes ; i++ )
  758. {
  759. float (* weights)[4] = stbir__compute_weights[i];
  760. build_comp_bitmap( weights, i );
  761. }
  762. draw_window();
  763. }
  764. static void load_files( char ** args, int count )
  765. {
  766. int i;
  767. if ( count == 0 )
  768. {
  769. printf( "No timing files listed!" );
  770. exit(3);
  771. }
  772. for ( i = 0 ; i < count ; i++ )
  773. {
  774. if ( !use_timing_file( args[i], i ) )
  775. {
  776. printf( "Bad timing file %s\n", args[i] );
  777. exit(2);
  778. }
  779. }
  780. numfileinfo = count;
  781. }
  782. int main( int argc, char ** argv )
  783. {
  784. int check;
  785. if ( argc < 3 )
  786. {
  787. err:
  788. printf( "vf_train retrain [timing_filenames....] - recalcs weights for all the files on the command line.\n");
  789. printf( "vf_train info [timing_filenames....] - shows info about each timing file.\n");
  790. printf( "vf_train check [timing_filenames...] - show results for the current weights for all files listed.\n");
  791. printf( "vf_train compare <timing file1> <timing file2> - compare two timing files (must only be two files and same resolution).\n");
  792. printf( "vf_train bitmap [timing_filenames...] - write out results.png, comparing against the current weights for all files listed.\n");
  793. exit(1);
  794. }
  795. check = ( strcmp( argv[1], "check" ) == 0 );
  796. if ( ( check ) || ( strcmp( argv[1], "bitmap" ) == 0 ) )
  797. {
  798. load_files( argv + 2, argc - 2 );
  799. alloc_bitmap( numfileinfo );
  800. current( check, !check );
  801. }
  802. else if ( strcmp( argv[1], "info" ) == 0 )
  803. {
  804. load_files( argv + 2, argc - 2 );
  805. info();
  806. }
  807. else if ( strcmp( argv[1], "compare" ) == 0 )
  808. {
  809. if ( argc != 4 )
  810. {
  811. printf( "You must specify two files to compare.\n" );
  812. exit(4);
  813. }
  814. load_files( argv + 2, argc - 2 );
  815. compare();
  816. }
  817. else if ( strcmp( argv[1], "retrain" ) == 0 )
  818. {
  819. load_files( argv + 2, argc - 2 );
  820. alloc_bitmap( numfileinfo );
  821. retrain();
  822. }
  823. else
  824. {
  825. goto err;
  826. }
  827. return 0;
  828. }