hctdb_test.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705
  1. # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details
  2. ###############################################################################
  3. # This file contains driver test information for DXIL operations #
  4. ###############################################################################
  5. from hctdb import *
  6. import xml.etree.ElementTree as ET
  7. import argparse
  8. parser = argparse.ArgumentParser(description="contains information about dxil op test cases.")
  9. parser.add_argument('mode', help="'gen-xml' or 'info'")
  10. g_db_dxil = None
  11. def get_db_dxil():
  12. global g_db_dxil
  13. if g_db_dxil is None:
  14. g_db_dxil = db_dxil()
  15. return g_db_dxil
  16. """
  17. This class represents a test case for instructions for driver testings
  18. DXIL instructions and test cases are two disjoint sets where each instruction can have multiple test cases,
  19. and each test case can cover different DXIL instructions. So these two sets form a bipartite graph.
  20. test_name: Test case identifier. Must be unique for each test case.
  21. insts: dxil instructions
  22. validation_type: validation type for test
  23. epsilon: absolute difference check
  24. ulp: units in last place check
  25. relative: relative error check
  26. validation_tolerance: tolerance value for a given test
  27. inputs: testing inputs
  28. outputs: expected outputs for each input
  29. shader_target: target for testing
  30. shader_text: hlsl file that is used for testing dxil op
  31. """
  32. class test_case(object):
  33. def __init__(self, test_name, insts, validation_type, validation_tolerance,
  34. input_lists, output_lists, shader_target, shader_text, **kwargs):
  35. self.test_name = test_name
  36. self.validation_type = validation_type
  37. self.validation_tolerance = validation_tolerance
  38. self.input_lists = input_lists
  39. self.output_lists = output_lists
  40. self.shader_target = shader_target
  41. self.shader_text = shader_text
  42. self.insts = insts # list of instructions each test case cover
  43. self.warp_version = -1 # known warp version that works
  44. self.shader_arguments = ""
  45. for k,v in kwargs.items():
  46. setattr(self, k, v)
  47. # Wrapper for each DXIL instruction
  48. class inst_node(object):
  49. def __init__(self, inst):
  50. self.inst = inst
  51. self.test_cases = [] # list of test_case
  52. def add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  53. input_lists, output_lists, shader_target, shader_text, **kwargs):
  54. insts = []
  55. for inst_name in inst_names:
  56. assert (inst_name in g_instruction_nodes)
  57. insts += [g_instruction_nodes[inst_name].inst]
  58. case = test_case(test_name, insts, validation_type,
  59. validation_tolerance, input_lists, output_lists,
  60. shader_target, shader_text, **kwargs)
  61. g_test_cases[test_name] = case
  62. # update instruction nodes
  63. for inst_name in inst_names:
  64. g_instruction_nodes[inst_name].test_cases += [case]
  65. def add_test_case_int(test_name, inst_names, validation_type, validation_tolerance,
  66. input_lists, output_lists, shader_key, shader_op_name, **kwargs):
  67. add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  68. input_lists, output_lists, "cs_6_0", get_shader_text(shader_key, shader_op_name), **kwargs)
  69. input_lists_16, output_lists_16 = input_lists, output_lists
  70. if "input_16" in kwargs:
  71. input_lists_16 = kwargs["input_16"]
  72. if "output_16" in kwargs:
  73. output_lists_16 = kwargs["output_16"]
  74. add_test_case(test_name + "Bit16", inst_names, validation_type, validation_tolerance,
  75. input_lists_16, output_lists_16, "cs_6_2", get_shader_text(shader_key.replace("int","int16_t"), shader_op_name),
  76. shader_arguments="-enable-16bit-types", **kwargs)
  77. def add_test_case_float_half(test_name, inst_names, validation_type, validation_tolerance,
  78. float_input_lists, float_output_lists, shader_key, shader_op_name, **kwargs):
  79. add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  80. float_input_lists, float_output_lists, "cs_6_0", get_shader_text(shader_key, shader_op_name), **kwargs)
  81. # if half test cases are different from float input lists, use those lists instead for half testings
  82. half_input_lists, half_output_lists = float_input_lists, float_output_lists
  83. if "half_inputs" in kwargs:
  84. half_input_lists = kwargs["half_inputs"]
  85. if "half_outputs" in kwargs:
  86. half_output_lists = kwargs["half_outputs"]
  87. # skip relative error test check for half for now
  88. if validation_type != "Relative":
  89. add_test_case(test_name + "Half", inst_names, validation_type, validation_tolerance,
  90. half_input_lists, half_output_lists, "cs_6_2",
  91. get_shader_text(shader_key.replace("float","half"), shader_op_name), shader_arguments="-enable-16bit-types", **kwargs)
  92. def add_test_case_denorm(test_name, inst_names, validation_type, validation_tolerance, input_lists,
  93. output_lists_ftz, output_lists_preserve, shader_target, shader_text, **kwargs):
  94. add_test_case(test_name + "FTZ", inst_names, validation_type, validation_tolerance, input_lists,
  95. output_lists_ftz, shader_target, shader_text, shader_arguments="-denorm ftz")
  96. add_test_case(test_name + "Preserve", inst_names, validation_type, validation_tolerance, input_lists,
  97. output_lists_preserve, shader_target, shader_text, shader_arguments="-denorm preserve")
  98. # we can expect the same output for "any" and "preserve" mode. We should make sure that for validation zero are accepted outputs for denormal outputs.
  99. add_test_case(test_name + "Any", inst_names, validation_type, validation_tolerance, input_lists,
  100. output_lists_preserve, shader_target, shader_text, shader_arguments="-denorm any")
  # HLSL compute-shader templates, keyed by "<arity> <type>[ call|bool|count]".
  # Each template contains "%s" placeholder(s) that get_shader_text() fills with
  # the operator or intrinsic under test. The 16-bit helpers look templates up by
  # rewriting the key ("int" -> "int16_t", "float" -> "half"), so those keys must
  # exist alongside the 32-bit ones.
  101. g_shader_texts = {
  # --- Unary templates: output = op(input); "bool" variants store 1 or 0. ---
  102. "unary int": ''' struct SUnaryIntOp {
  103. int input;
  104. int output;
  105. };
  106. RWStructuredBuffer<SUnaryIntOp> g_buf : register(u0);
  107. [numthreads(8,8,1)]
  108. void main(uint GI : SV_GroupIndex) {
  109. SUnaryIntOp l = g_buf[GI];
  110. l.output = %s(l.input);
  111. g_buf[GI] = l;
  112. };''',
  113. "unary int16_t": ''' struct SUnaryInt16Op {
  114. int16_t input;
  115. int16_t output;
  116. };
  117. RWStructuredBuffer<SUnaryInt16Op> g_buf : register(u0);
  118. [numthreads(8,8,1)]
  119. void main(uint GI : SV_GroupIndex) {
  120. SUnaryInt16Op l = g_buf[GI];
  121. l.output = %s(l.input);
  122. g_buf[GI] = l;
  123. };''',
  124. "unary uint": ''' struct SUnaryUintOp {
  125. uint input;
  126. uint output;
  127. };
  128. RWStructuredBuffer<SUnaryUintOp> g_buf : register(u0);
  129. [numthreads(8,8,1)]
  130. void main(uint GI : SV_GroupIndex) {
  131. SUnaryUintOp l = g_buf[GI];
  132. l.output = %s(l.input);
  133. g_buf[GI] = l;
  134. };''',
  135. "unary uint16_t": ''' struct SUnaryUint16Op {
  136. uint16_t input;
  137. uint16_t output;
  138. };
  139. RWStructuredBuffer<SUnaryUint16Op> g_buf : register(u0);
  140. [numthreads(8,8,1)]
  141. void main(uint GI : SV_GroupIndex) {
  142. SUnaryUint16Op l = g_buf[GI];
  143. l.output = %s(l.input);
  144. g_buf[GI] = l;
  145. };''',
  146. "unary float": ''' struct SUnaryFPOp {
  147. float input;
  148. float output;
  149. };
  150. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  151. [numthreads(8,8,1)]
  152. void main(uint GI : SV_GroupIndex) {
  153. SUnaryFPOp l = g_buf[GI];
  154. l.output = %s(l.input);
  155. g_buf[GI] = l;
  156. };''',
  157. "unary float bool": ''' struct SUnaryFPOp {
  158. float input;
  159. float output;
  160. };
  161. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  162. [numthreads(8,8,1)]
  163. void main(uint GI : SV_GroupIndex) {
  164. SUnaryFPOp l = g_buf[GI];
  165. if (%s(l.input))
  166. l.output = 1;
  167. else
  168. l.output = 0;
  169. g_buf[GI] = l;
  170. };''',
  171. "unary half": ''' struct SUnaryFPOp {
  172. float16_t input;
  173. float16_t output;
  174. };
  175. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  176. [numthreads(8,8,1)]
  177. void main(uint GI : SV_GroupIndex) {
  178. SUnaryFPOp l = g_buf[GI];
  179. l.output = %s(l.input);
  180. g_buf[GI] = l;
  181. };''',
  182. "unary half bool": ''' struct SUnaryFPOp {
  183. float16_t input;
  184. float16_t output;
  185. };
  186. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  187. [numthreads(8,8,1)]
  188. void main(uint GI : SV_GroupIndex) {
  189. SUnaryFPOp l = g_buf[GI];
  190. if (%s(l.input))
  191. l.output = 1;
  192. else
  193. l.output = 0;
  194. g_buf[GI] = l;
  195. };''',
  # --- Binary templates: plain keys use "%s" as an infix operator, "call" keys
  # --- as a two-argument function. Only output1 is written by these templates.
  196. "binary int": ''' struct SBinaryIntOp {
  197. int input1;
  198. int input2;
  199. int output1;
  200. int output2;
  201. };
  202. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  203. [numthreads(8,8,1)]
  204. void main(uint GI : SV_GroupIndex) {
  205. SBinaryIntOp l = g_buf[GI];
  206. l.output1 = l.input1 %s l.input2;
  207. g_buf[GI] = l;
  208. };''',
  209. "binary int16_t": ''' struct SBinaryInt16Op {
  210. int16_t input1;
  211. int16_t input2;
  212. int16_t output1;
  213. int16_t output2;
  214. };
  215. RWStructuredBuffer<SBinaryInt16Op> g_buf : register(u0);
  216. [numthreads(8,8,1)]
  217. void main(uint GI : SV_GroupIndex) {
  218. SBinaryInt16Op l = g_buf[GI];
  219. l.output1 = l.input1 %s l.input2;
  220. g_buf[GI] = l;
  221. };''',
  222. "binary int call": ''' struct SBinaryIntOp {
  223. int input1;
  224. int input2;
  225. int output1;
  226. int output2;
  227. };
  228. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  229. [numthreads(8,8,1)]
  230. void main(uint GI : SV_GroupIndex) {
  231. SBinaryIntOp l = g_buf[GI];
  232. l.output1 = %s(l.input1,l.input2);
  233. g_buf[GI] = l;
  234. };''',
  235. "binary int16_t call": ''' struct SBinaryInt16Op {
  236. int16_t input1;
  237. int16_t input2;
  238. int16_t output1;
  239. int16_t output2;
  240. };
  241. RWStructuredBuffer<SBinaryInt16Op> g_buf : register(u0);
  242. [numthreads(8,8,1)]
  243. void main(uint GI : SV_GroupIndex) {
  244. SBinaryInt16Op l = g_buf[GI];
  245. l.output1 = %s(l.input1,l.input2);
  246. g_buf[GI] = l;
  247. };''',
  248. "binary uint": ''' struct SBinaryUintOp {
  249. uint input1;
  250. uint input2;
  251. uint output1;
  252. uint output2;
  253. };
  254. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  255. [numthreads(8,8,1)]
  256. void main(uint GI : SV_GroupIndex) {
  257. SBinaryUintOp l = g_buf[GI];
  258. l.output1 = l.input1 %s l.input2;
  259. g_buf[GI] = l;
  260. };''',
  261. "binary uint16_t": ''' struct SBinaryUint16Op {
  262. uint16_t input1;
  263. uint16_t input2;
  264. uint16_t output1;
  265. uint16_t output2;
  266. };
  267. RWStructuredBuffer<SBinaryUint16Op> g_buf : register(u0);
  268. [numthreads(8,8,1)]
  269. void main(uint GI : SV_GroupIndex) {
  270. SBinaryUint16Op l = g_buf[GI];
  271. l.output1 = l.input1 %s l.input2;
  272. g_buf[GI] = l;
  273. };''',
  274. "binary uint call": ''' struct SBinaryUintOp {
  275. uint input1;
  276. uint input2;
  277. uint output1;
  278. uint output2;
  279. };
  280. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  281. [numthreads(8,8,1)]
  282. void main(uint GI : SV_GroupIndex) {
  283. SBinaryUintOp l = g_buf[GI];
  284. l.output1 = %s(l.input1,l.input2);
  285. g_buf[GI] = l;
  286. };''',
  287. "binary uint16_t call": ''' struct SBinaryUint16Op {
  288. uint16_t input1;
  289. uint16_t input2;
  290. uint16_t output1;
  291. uint16_t output2;
  292. };
  293. RWStructuredBuffer<SBinaryUint16Op> g_buf : register(u0);
  294. [numthreads(8,8,1)]
  295. void main(uint GI : SV_GroupIndex) {
  296. SBinaryUint16Op l = g_buf[GI];
  297. l.output1 = %s(l.input1,l.input2);
  298. g_buf[GI] = l;
  299. };''',
  300. "binary float": ''' struct SBinaryFPOp {
  301. float input1;
  302. float input2;
  303. float output1;
  304. float output2;
  305. };
  306. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  307. [numthreads(8,8,1)]
  308. void main(uint GI : SV_GroupIndex) {
  309. SBinaryFPOp l = g_buf[GI];
  310. l.output1 = l.input1 %s l.input2;
  311. g_buf[GI] = l;
  312. };''',
  313. "binary float call": ''' struct SBinaryFPOp {
  314. float input1;
  315. float input2;
  316. float output1;
  317. float output2;
  318. };
  319. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  320. [numthreads(8,8,1)]
  321. void main(uint GI : SV_GroupIndex) {
  322. SBinaryFPOp l = g_buf[GI];
  323. l.output1 = %s(l.input1,l.input2);
  324. g_buf[GI] = l;
  325. };''',
  326. "binary half": ''' struct SBinaryFPOp {
  327. half input1;
  328. half input2;
  329. half output1;
  330. half output2;
  331. };
  332. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  333. [numthreads(8,8,1)]
  334. void main(uint GI : SV_GroupIndex) {
  335. SBinaryFPOp l = g_buf[GI];
  336. l.output1 = l.input1 %s l.input2;
  337. g_buf[GI] = l;
  338. };''',
  339. "binary half call": ''' struct SBinaryFPOp {
  340. half input1;
  341. half input2;
  342. half output1;
  343. half output2;
  344. };
  345. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  346. [numthreads(8,8,1)]
  347. void main(uint GI : SV_GroupIndex) {
  348. SBinaryFPOp l = g_buf[GI];
  349. l.output1 = %s(l.input1,l.input2);
  350. g_buf[GI] = l;
  351. };''',
  # --- Tertiary templates: output = op(input1, input2, input3). ---
  352. "tertiary int": ''' struct STertiaryIntOp {
  353. int input1;
  354. int input2;
  355. int input3;
  356. int output;
  357. };
  358. RWStructuredBuffer<STertiaryIntOp> g_buf : register(u0);
  359. [numthreads(8,8,1)]
  360. void main(uint GI : SV_GroupIndex) {
  361. STertiaryIntOp l = g_buf[GI];
  362. l.output = %s(l.input1, l.input2, l.input3);
  363. g_buf[GI] = l;
  364. };''',
  365. "tertiary int16_t": ''' struct STertiaryInt16Op {
  366. int16_t input1;
  367. int16_t input2;
  368. int16_t input3;
  369. int16_t output;
  370. };
  371. RWStructuredBuffer<STertiaryInt16Op> g_buf : register(u0);
  372. [numthreads(8,8,1)]
  373. void main(uint GI : SV_GroupIndex) {
  374. STertiaryInt16Op l = g_buf[GI];
  375. l.output = %s(l.input1, l.input2, l.input3);
  376. g_buf[GI] = l;
  377. };''',
  378. "tertiary uint": ''' struct STertiaryUintOp {
  379. uint input1;
  380. uint input2;
  381. uint input3;
  382. uint output;
  383. };
  384. RWStructuredBuffer<STertiaryUintOp> g_buf : register(u0);
  385. [numthreads(8,8,1)]
  386. void main(uint GI : SV_GroupIndex) {
  387. STertiaryUintOp l = g_buf[GI];
  388. l.output = %s(l.input1, l.input2, l.input3);
  389. g_buf[GI] = l;
  390. };''',
  391. "tertiary uint16_t": ''' struct STertiaryUint16Op {
  392. uint16_t input1;
  393. uint16_t input2;
  394. uint16_t input3;
  395. uint16_t output;
  396. };
  397. RWStructuredBuffer<STertiaryUint16Op> g_buf : register(u0);
  398. [numthreads(8,8,1)]
  399. void main(uint GI : SV_GroupIndex) {
  400. STertiaryUint16Op l = g_buf[GI];
  401. l.output = %s(l.input1, l.input2, l.input3);
  402. g_buf[GI] = l;
  403. };''',
  404. "tertiary float": ''' struct STertiaryFloatOp {
  405. float input1;
  406. float input2;
  407. float input3;
  408. float output;
  409. };
  410. RWStructuredBuffer<STertiaryFloatOp> g_buf : register(u0);
  411. [numthreads(8,8,1)]
  412. void main(uint GI : SV_GroupIndex) {
  413. STertiaryFloatOp l = g_buf[GI];
  414. l.output = %s(l.input1, l.input2, l.input3);
  415. g_buf[GI] = l;
  416. };''',
  417. 'tertiary half': ''' struct STertiaryHalfOp {
  418. half input1;
  419. half input2;
  420. half input3;
  421. half output;
  422. };
  423. RWStructuredBuffer<STertiaryHalfOp> g_buf : register(u0);
  424. [numthreads(8,8,1)]
  425. void main(uint GI : SV_GroupIndex) {
  426. STertiaryHalfOp l = g_buf[GI];
  427. l.output = %s(l.input1, l.input2, l.input3);
  428. g_buf[GI] = l;
  429. };''',
  # --- Wave-op templates: these contain TWO "%s" placeholders (one in each
  # --- branch of the mask test); get_shader_text() fills both with the same op.
  430. "wave op int" :''' struct PerThreadData {
  431. uint firstLaneId;
  432. uint laneIndex;
  433. int mask;
  434. int input;
  435. int output;
  436. };
  437. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  438. [numthreads(8,12,1)]
  439. void main(uint GI : SV_GroupIndex) {
  440. PerThreadData pts = g_sb[GI];
  441. pts.firstLaneId = WaveReadLaneFirst(GI);
  442. pts.laneIndex = WaveGetLaneIndex();
  443. if (pts.mask != 0) {
  444. pts.output = %s(pts.input);
  445. }
  446. else {
  447. pts.output = %s(pts.input);
  448. }
  449. g_sb[GI] = pts;
  450. };''',
  451. "wave op uint" :''' struct PerThreadData {
  452. uint firstLaneId;
  453. uint laneIndex;
  454. int mask;
  455. uint input;
  456. uint output;
  457. };
  458. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  459. [numthreads(8,12,1)]
  460. void main(uint GI : SV_GroupIndex) {
  461. PerThreadData pts = g_sb[GI];
  462. pts.firstLaneId = WaveReadLaneFirst(GI);
  463. pts.laneIndex = WaveGetLaneIndex();
  464. if (pts.mask != 0) {
  465. pts.output = %s(pts.input);
  466. }
  467. else {
  468. pts.output = %s(pts.input);
  469. }
  470. g_sb[GI] = pts;
  471. };''',
  472. "wave op int count": ''' struct PerThreadData {
  473. uint firstLaneId;
  474. uint laneIndex;
  475. int mask;
  476. int input;
  477. int output;
  478. };
  479. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  480. [numthreads(8,12,1)]
  481. void main(uint GI : SV_GroupIndex) {
  482. PerThreadData pts = g_sb[GI];
  483. pts.firstLaneId = WaveReadLaneFirst(GI);
  484. pts.laneIndex = WaveGetLaneIndex();
  485. if (pts.mask != 0) {
  486. pts.output = %s(pts.input > 3);
  487. }
  488. else {
  489. pts.output = %s(pts.input > 3);
  490. }
  491. g_sb[GI] = pts;
  492. };'''
  493. }
  494. def get_shader_text(op_type, op_call):
  495. assert(op_type in g_shader_texts)
  496. if op_type.startswith("wave op"):
  497. return g_shader_texts[op_type] % (op_call, op_call)
  498. return g_shader_texts[op_type] % (op_call)
  499. g_denorm_tests = ["FAddDenormAny", "FAddDenormFTZ", "FAddDenormPreserve",
  500. "FSubDenormAny", "FSubDenormFTZ", "FSubDenormPreserve",
  501. "FMulDenormAny", "FMulDenormFTZ", "FMulDenormPreserve",
  502. "FDivDenormAny", "FDivDenormFTZ", "FDivDenormPreserve",
  503. "FMadDenormAny", "FMadDenormFTZ", "FMadDenormPreserve",
  504. "FAbsDenormAny", "FAbsDenormFTZ", "FAbsDenormPreserve",
  505. "FMinDenormAny", "FMinDenormFTZ", "FMinDenormPreserve",
  506. "FMaxDenormAny", "FMaxDenormFTZ", "FMaxDenormPreserve"]
# This is the collection of driver test cases, grouped per instruction.
# Warning: when a test case needs a signed integer, write negative numbers
# as decimal values rather than in hexadecimal representation.
# For some reason, TAEF does not handle the hexadecimal form properly.
  511. def add_test_cases():
  512. nan = float('nan')
  513. p_inf = float('inf')
  514. n_inf = float('-inf')
  515. p_denorm = float('1e-38')
  516. n_denorm = float('-1e-38')
  517. # Unary Float
  518. add_test_case_float_half('Sin', ['Sin'], 'Epsilon', 0.0008, [[
  519. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  520. '314.16'
  521. ]], [[
  522. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN', '-0.0007346401',
  523. '0.0007346401'
  524. ]], "unary float", "sin", half_inputs=[[
  525. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314',
  526. '314'
  527. ]], half_outputs=[[
  528. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN', '-0.1585929',
  529. '0.1585929'
  530. ]])
  531. add_test_case_float_half('Cos', ['Cos'], 'Epsilon', 0.0008, [[
  532. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  533. '314.16'
  534. ]], [[
  535. 'NaN', 'NaN', '1.0', '1.0', '1.0', '1.0', 'NaN', '0.99999973015',
  536. '0.99999973015'
  537. ]], "unary float", "cos", half_inputs=[[
  538. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314',
  539. '314'
  540. ]], half_outputs=[[
  541. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN', '0.987344',
  542. '0.987344'
  543. ]])
  544. add_test_case_float_half('Tan', ['Tan'], 'Epsilon', 0.0008, [[
  545. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  546. '314.16'
  547. ]], [[
  548. 'NaN', 'NaN', '-0.0', '-0.0', '0.0', '0.0', 'NaN', '-0.000735',
  549. '0.000735'
  550. ]], "unary float", "tan", half_inputs=[[
  551. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314',
  552. '314'
  553. ]], half_outputs=[[
  554. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN', '0.1606257',
  555. '-0.1606257'
  556. ]])
  557. add_test_case_float_half('Hcos', ['Hcos'], 'Epsilon', 0.0008,
  558. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  559. 'NaN', 'Inf', '1.0', '1.0', '1.0', '1.0', 'Inf', '1.543081',
  560. '1.543081'
  561. ]], "unary float", "cosh")
  562. add_test_case_float_half('Hsin', ['Hsin'], 'Epsilon', 0.0008,
  563. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  564. 'NaN', '-Inf', '0.0', '0.0', '0.0', '0.0', 'Inf', '1.175201',
  565. '-1.175201'
  566. ]], "unary float", "sinh")
  567. add_test_case_float_half('Htan', ['Htan'], 'Epsilon', 0.0008,
  568. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  569. 'NaN', '-1', '-0.0', '-0.0', '0.0', '0.0', '1', '0.761594',
  570. '-0.761594'
  571. ]], "unary float", "tanh", warp_version=16202)
  572. add_test_case_float_half('Acos', ['Acos'], 'Epsilon', 0.0008, [[
  573. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1', '1.5',
  574. '-1.5'
  575. ]], [[
  576. 'NaN', 'NaN', '1.570796', '1.570796', '1.570796', '1.570796', 'NaN',
  577. '0', '3.1415926', 'NaN', 'NaN'
  578. ]], "unary float", "acos")
  579. add_test_case_float_half('Asin', ['Asin'], 'Epsilon', 0.0008, [[
  580. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1', '1.5',
  581. '-1.5'
  582. ]], [[
  583. 'NaN', 'NaN', '0.0', '0.0', '0.0', '0.0', 'NaN', '1.570796',
  584. '-1.570796', 'NaN', 'NaN'
  585. ]], "unary float", "asin")
  586. add_test_case_float_half('Atan', ['Atan'], 'Epsilon', 0.0008,
  587. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  588. 'NaN', '-1.570796', '0.0', '0.0', '0.0', '0.0', '1.570796',
  589. '0.785398163', '-0.785398163'
  590. ]], "unary float", "atan", warp_version=16202)
  591. add_test_case_float_half('Exp', ['Exp'], 'Relative', 21,
  592. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '10']],
  593. [['NaN', '0', '1', '1', '1', '1', 'Inf', '0.367879441', '22026.46579']
  594. ], "unary float", "exp")
  595. add_test_case_float_half('Frc', ['Frc'], 'Epsilon', 0.0008, [[
  596. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '2.718280',
  597. '1000.599976', '-7.389'
  598. ]], [[
  599. 'NaN', 'NaN', '0', '0', '0', '0', 'NaN', '0', '0.718280', '0.599976',
  600. '0.611'
  601. ]], "unary float", "frac",
  602. half_inputs=[['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '2.719',
  603. '1000.5', '-7.39']],
  604. half_outputs=[[
  605. 'NaN', 'NaN', '0', '0', '0', '0', 'NaN', '0', '0.719', '0.5',
  606. '0.61']])
  607. add_test_case_float_half('Log', ['Log'], 'Relative', 21, [[
  608. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1',
  609. '2.718281828', '7.389056', '100'
  610. ]], [[
  611. 'NaN', 'NaN', '-Inf', '-Inf', '-Inf', '-Inf', 'Inf', 'NaN', '1.0',
  612. '1.99999998', '4.6051701'
  613. ]],"unary float", "log", half_inputs=[[
  614. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1',
  615. '2.719', '7.39', '100'
  616. ]], half_outputs=[[
  617. 'NaN', 'NaN', '-Inf', '-Inf', '-Inf', '-Inf', 'Inf', 'NaN', '1.0',
  618. '2', '4.605'
  619. ]])
  620. add_test_case_float_half('Sqrt', ['Sqrt'], 'ulp', 1, [[
  621. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '2',
  622. '16.0', '256.0'
  623. ]], [[
  624. 'NaN', 'NaN', '-0', '-0', '0', '0', 'Inf', 'NaN', '1.41421356237',
  625. '4.0', '16.0'
  626. ]], "unary float", "sqrt")
  627. add_test_case_float_half('Rsqrt', ['Rsqrt'], 'ulp', 1, [[
  628. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '16.0',
  629. '256.0', '65536.0'
  630. ]], [[
  631. 'NaN', 'NaN', '-Inf', '-Inf', 'Inf', 'Inf', '0', 'NaN', '0.25',
  632. '0.0625', '0.00390625'
  633. ]], "unary float", "rsqrt", half_inputs=[[
  634. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '16.0',
  635. '256.0', '65500'
  636. ]], half_outputs=[[
  637. 'NaN', 'NaN', '-Inf', '-Inf', 'Inf', 'Inf', '0', 'NaN', '0.25',
  638. '0.0625', '0.00001526'
  639. ]])
  640. add_test_case_float_half('Round_ne', ['Round_ne'], 'Epsilon', 0, [[
  641. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  642. '10.5', '10.6', '11.5', '-10.0', '-10.4', '-10.5', '-10.6'
  643. ]], [[
  644. 'NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  645. '11.0', '12.0', '-10.0', '-10.0', '-10.0', '-11.0'
  646. ]], "unary float", "round")
  647. add_test_case_float_half('Round_ni', ['Round_ni'], 'Epsilon', 0, [[
  648. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  649. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6'
  650. ]], [[
  651. 'NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  652. '10.0', '-10.0', '-11.0', '-11.0', '-11.0'
  653. ]], "unary float", "floor")
  654. add_test_case_float_half('Round_pi', ['Round_pi'], 'Epsilon', 0,
  655. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  656. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6']],
  657. [['NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '11.0', '11.0',
  658. '11.0', '-10.0', '-10.0', '-10.0', '-10.0']], "unary float", "ceil")
  659. add_test_case_float_half('Round_z', ['Round_z'], 'Epsilon', 0,
  660. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  661. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6']],
  662. [['NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  663. '10.0', '-10.0', '-10.0', '-10.0', '-10.0']], "unary float", "trunc")
  664. add_test_case_float_half('IsNaN', ['IsNaN'], 'Epsilon', 0,
  665. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  666. ], [['1', '0', '0', '0', '0', '0', '0', '0', '0']], "unary float bool", "isnan")
  667. add_test_case_float_half('IsInf', ['IsInf'], 'Epsilon', 0,
  668. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  669. ], [['0', '1', '0', '0', '0', '0', '1', '0', '0']], "unary float bool", "isinf")
  670. add_test_case_float_half('IsFinite', ['IsFinite'], 'Epsilon', 0,
  671. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  672. ], [['0', '0', '1', '1', '1', '1', '0', '1', '1']], "unary float bool", "isfinite", warp_version=16202)
  673. add_test_case_float_half('FAbs', ['FAbs'], 'Epsilon', 0,
  674. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  675. ], [['NaN', 'Inf', 'denorm', '0', '0', 'denorm', 'Inf', '1', '1']], "unary float", "abs")
  676. # Binary Float
  677. add_test_case('FMin', ['FMin','FMax'], 'epsilon', 0, [[
  678. '-inf', '-inf', '-inf', '-inf', 'inf', 'inf', 'inf', 'inf', 'NaN',
  679. 'NaN', 'NaN', 'NaN', '1.0', '1.0', '-1.0', '-1.0', '1.0'
  680. ], [
  681. '-inf', 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-inf',
  682. 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-1.0'
  683. ]], [[
  684. '-inf', '-inf', '-inf', '-inf', '-inf', 'inf', '1.0', 'inf', '-inf',
  685. 'inf', '1.0', 'NaN', '-inf', '1.0', '-1.0', '-1.0', '-1.0'
  686. ], [
  687. '-inf', 'inf', '1.0', '-inf', 'inf', 'inf', 'inf', 'inf', '-inf',
  688. 'inf', '1.0', 'NaN', '1.0', 'inf', '1.0', '-1.0', '1.0'
  689. ]], 'cs_6_0', ''' struct SBinaryFPOp {
  690. float input1;
  691. float input2;
  692. float output1;
  693. float output2;
  694. };
  695. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  696. [numthreads(8,8,1)]
  697. void main(uint GI : SV_GroupIndex) {
  698. SBinaryFPOp l = g_buf[GI];
  699. l.output1 = min(l.input1, l.input2);
  700. l.output2 = max(l.input1, l.input2);
  701. g_buf[GI] = l;
  702. };''')
  703. add_test_case('FMinHalf', ['FMin','FMax'], 'epsilon', 0, [[
  704. '-inf', '-inf', '-inf', '-inf', 'inf', 'inf', 'inf', 'inf', 'NaN',
  705. 'NaN', 'NaN', 'NaN', '1.0', '1.0', '-1.0', '-1.0', '1.0'
  706. ], [
  707. '-inf', 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-inf',
  708. 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-1.0'
  709. ]], [[
  710. '-inf', '-inf', '-inf', '-inf', '-inf', 'inf', '1.0', 'inf', '-inf',
  711. 'inf', '1.0', 'NaN', '-inf', '1.0', '-1.0', '-1.0', '-1.0'
  712. ], [
  713. '-inf', 'inf', '1.0', '-inf', 'inf', 'inf', 'inf', 'inf', '-inf',
  714. 'inf', '1.0', 'NaN', '1.0', 'inf', '1.0', '-1.0', '1.0'
  715. ]], 'cs_6_0', ''' struct SBinaryHalfOp {
  716. half input1;
  717. half input2;
  718. half output1;
  719. half output2;
  720. };
  721. RWStructuredBuffer<SBinaryHalfOp> g_buf : register(u0);
  722. [numthreads(8,8,1)]
  723. void main(uint GI : SV_GroupIndex) {
  724. SBinaryHalfOp l = g_buf[GI];
  725. l.output1 = min(l.input1, l.input2);
  726. l.output2 = max(l.input1, l.input2);
  727. g_buf[GI] = l;
  728. };''')
  729. add_test_case_float_half('FAdd', ['FAdd'], 'ulp', 1, [['-1.0', '1.0', '32.5', '1.0000001000'],['4', '5.5', '334.7', '0.5000001000']], [['3.0', '6.5', '367.2', '1.5000002000']],
  730. "binary float", "+")
  731. add_test_case_float_half('FSub', ['FSub'], 'ulp', 1, [['-1.0', '5.5', '32.5', '1.0000001000'],['4', '1.25', '334.7', '0.5000001000']], [['-5', '4.25', '-302.2', '0.5000']],
  732. "binary float", "-")
  733. add_test_case_float_half('FMul', ['FMul'], 'ulp', 1, [['-1.0', '5.5', '1.0000001'],['4', '1.25', '2.0']], [['-4.0', '6.875', '2.0000002']],
  734. "binary float", "*")
  735. add_test_case_float_half('FDiv', ['FDiv'], 'ulp', 1, [['-1.0', '5.5', '1.0000001'],['4', '1.25', '2.0']], [['-0.25', '4.4', '0.50000006']],
  736. "binary float", "/")
  737. # Denorm Binary Float
  738. add_test_case_denorm('FAddDenorm', ['FAdd'], 'ulp', 1,
  739. [['0x007E0000', '0x00200000', '0x007E0000', '0x007E0000'],['0x007E0000','0x00200000', '0x807E0000', '0x800E0000']],
  740. [['0x00FC0000','0', '0', '0']],
  741. [['0x00FC0000','0', '0', '0x00700000']],
  742. 'cs_6_2', get_shader_text("binary float", "+"))
  743. add_test_case_denorm('FSubDenorm', ['FSub'], 'ulp', 1,
  744. [['0x007E0000', '0x007F0000', '0x00FF0000', '0x007A0000'],['0x007E0000', '0x807F0000', '0x00800000', '0']],
  745. [['0x0', '0x00FE0000', '0', '0']],
  746. [['0x0', '0x00FE0000', '0x007F0000', '0x007A0000']],
  747. 'cs_6_2', get_shader_text("binary float", "-"))
  748. add_test_case_denorm('FDivDenorm', ['FDiv'], 'ulp', 1,
  749. [['0x007F0000', '0x007F0000', '0x40000000', '0x00800000'],['1', '0x007F0000', '0x7F7F0000', '0x40000000']],
  750. [['0', '1', '0', '0']],
  751. [['0x007F0000', '1', '0x00404040', '0x00400000']],
  752. 'cs_6_2', get_shader_text("binary float", "/"))
  753. add_test_case_denorm('FMulDenorm', ['FMul'], 'ulp', 1,
  754. [['0x00000300', '0x007F0000', '0x007F0000', '0x001E0000', '0x00000300'],['128', '1', '0x007F0000', '20', '0x78000000']],
  755. [['0', '0', '0', '0x01960000', '0x32400000']],
  756. [['0x00018000','0x007F0000', '0', '0x01960000', '0x32400000']],
  757. 'cs_6_2', get_shader_text("binary float", "*"))
  758. # Tertiary Float
  759. add_test_case_float_half('FMad', ['FMad'], 'ulp', 1, [[
  760. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  761. '0', '1', '1.5'
  762. ], [
  763. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  764. '0', '1', '10'
  765. ], [
  766. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  767. '1', '0', '-5.5'
  768. ]], [['NaN', 'NaN', '0', '0', '0', '0', 'Inf', '2', '0', '1', '1', '9.5']],
  769. "tertiary float", "mad")
  770. # Denorm Tertiary Float
  771. add_test_case_denorm('FMadDenorm', ['FMad'], 'ulp', 1,
  772. [['0x80780000', '0x80780000', '0x00780000'],
  773. ['1', '2', '2'],
  774. ['0x80780000', '0x00800000', '0x00800000']],
  775. [['0', '0', '0x01380000']],
  776. [['0x80780000', '0x80700000', '0x01380000']],
  777. 'cs_6_2', get_shader_text("tertiary float", "mad"))
  778. # Unary Int
  779. int8_min, int8_max = '-128', '127'
  780. int16_min, int16_max = '-32768', '32767'
  781. int32_min, int32_max = '-2147483648', '2147483647'
  782. uint16_max = '65535'
  783. uint32_max = '4294967295'
  784. add_test_case_int('Bfrev', ['Bfrev'], 'Epsilon', 0, [[
  785. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  786. int32_max
  787. ]], [[
  788. '1', '65535', '536870911', '-1', '0', int32_min, '268435456',
  789. '32768', '-2'
  790. ]], "unary int", "reversebits",
  791. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  792. output_16=[['1', '255', '8191', '-1', '0', int16_min, '4096', '128', '-2']])
  793. # firstbit_shi (s for signed) returns the
  794. # first 0 from the MSB if the number is negative,
  795. # else the first 1 from the MSB.
  796. # all the variants of the instruction return ~0 if no match was found
  797. add_test_case_int('FirstbitSHi', ['FirstbitSHi'], 'Epsilon', 0, [[
  798. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  799. int32_max
  800. ]], [['30', '15', '2', '-1', '-1', '0', '3', '16', '30']],
  801. "unary int", "firstbithigh",
  802. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  803. output_16=[['14', '7', '2', '-1', '-1', '0', '3', '8', '14']])
  804. add_test_case_int('FirstbitLo', ['FirstbitLo'], 'Epsilon', 0, [[
  805. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  806. int32_max
  807. ]], [['31', '16', '3', '0', '-1', '0', '3', '16', '0']],
  808. "unary int", "firstbitlow",
  809. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  810. output_16=[['15', '8', '3', '0', '-1', '0', '3', '8', '0']])
  811. # TODO: there is a known bug in countbits when passing in immediate values.
  812. # Fix this later
  813. add_test_case('Countbits', ['Countbits'], 'Epsilon', 0, [[
  814. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  815. int32_max
  816. ]], [['1', '16', '29', '32', '0', '1', '1', '1', '31']],
  817. "cs_6_0", get_shader_text("unary int", "countbits"))
  818. # Unary uint
  819. add_test_case_int('FirstbitHi', ['FirstbitHi'], 'Epsilon', 0,
  820. [['0', '1', '8', '65536', int32_max, uint32_max]],
  821. [['-1', '0', '3', '16', '30', '31']],
  822. "unary uint", "firstbithigh",
  823. input_16=[['0', '1', '8', uint16_max]],
  824. output_16=[['-1', '0', '3', '15']])
  825. # Binary Int
  826. add_test_case_int('IAdd', ['Add'], 'Epsilon', 0,
  827. [[int32_min, '-10', '0', '0', '10', int32_max, '486'],
  828. ['0', '10', '-10', '10', '10', '0', '54238']],
  829. [[int32_min, '0', '-10', '10', '20', int32_max, '54724']],
  830. "binary int", "+",
  831. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  832. ['0', '10', '-3114', '272', '15', '0']],
  833. output_16=[[int16_min, '0', '-3114', '272', '25', int16_max]])
  834. add_test_case_int('ISub', ['Sub'], 'Epsilon', 0,
  835. [[int32_min, '-10', '0', '0', '10', int32_max, '486'],
  836. ['0', '10', '-10', '10', '10', '0', '54238']],
  837. [[int32_min, '-20', '10', '-10', '0', int32_max, '-53752']],
  838. "binary int", "-",
  839. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  840. ['0', '10', '-3114', '272', '15', '0']],
  841. output_16=[[int16_min, '-20', '-3114', '-272', '-5', int16_max]])
  842. add_test_case_int('IMax', ['IMax'], 'Epsilon', 0,
  843. [[int32_min, '-10', '0', '0', '10', int32_max],
  844. ['0', '10', '-10', '10', '10', '0']],
  845. [['0', '10', '0', '10', '10', int32_max]],
  846. "binary int call", "max",
  847. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  848. ['0', '10', '-3114', '272', '15', '0']],
  849. output_16=[['0', '10', '0', '272', '15', int16_max]])
  850. add_test_case_int('IMin', ['IMin'], 'Epsilon', 0,
  851. [[int32_min, '-10', '0', '0', '10', int32_max],
  852. ['0', '10', '-10', '10', '10', '0']],
  853. [[int32_min, '-10', '-10', '0', '10', '0']],
  854. "binary int call", "min",
  855. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  856. ['0', '10', '-3114', '272', '15', '0']],
  857. output_16=[[int16_min, '-10', '-3114', '0', '10', '0']])
  858. add_test_case_int('IMul', ['Mul'], 'Epsilon', 0, [
  859. [ int32_min, '-10', '-1', '0', '1', '10', '10000', int32_max, int32_max ],
  860. ['-10', '-10', '10', '0', '256', '4', '10001', '0', int32_max]],
  861. [['0', '100', '-10', '0', '256', '40', '100010000', '0', '1']],
  862. "binary int", "*",
  863. input_16=[[ int16_min, '-10', '-1', '0', '1', '10', int16_max],
  864. ['-10', '-10', '10', '0', '256', '4', '0']],
  865. output_16=[['0', '100', '-10', '0', '256', '40', '0']])
  866. add_test_case('IDiv', ['SDiv', 'SRem'], 'Epsilon', 0,
  867. [['1', '1', '10', '10000', int32_max, int32_max, '-1'],
  868. ['1', '256', '4', '10001', '2', int32_max, '1']],
  869. [['1', '0', '2', '0', '1073741823', '1', '-1'],
  870. ['0', '1', '2', '10000', '1', '0', '0']], "cs_6_0",
  871. ''' struct SBinaryIntOp {
  872. int input1;
  873. int input2;
  874. int output1;
  875. int output2;
  876. };
  877. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  878. [numthreads(8,8,1)]
  879. void main(uint GI : SV_GroupIndex) {
  880. SBinaryIntOp l = g_buf[GI];
  881. l.output1 = l.input1 / l.input2;
  882. l.output2 = l.input1 % l.input2;
  883. g_buf[GI] = l;
  884. };''')
  885. add_test_case_int('Shl', ['Shl'], 'Epsilon', 0,
  886. [['1', '1', '0x1010', '0xa', '-1', '0x12341234', '-1'],
  887. ['0', '259', '4', '2', '0', '15', '3']],
  888. [['0x1', '0x8', '0x10100', '0x28', '-1','0x091a0000', '-8']],
  889. "binary int", "<<",
  890. input_16=[['1', '1', '0x0101', '0xa', '-1', '0x1234', '-1'],
  891. ['0', '259', '4', '2', '0', '13', '3']],
  892. output_16=[['0x1', '0x8', '0x1010', '0x28', '-1','0x8000', '-8']])
  893. add_test_case_int("LShr", ['LShr'], 'Epsilon', 0,
  894. [['1', '1', '0xffff', '0x7fffffff', '0x70001234', '0x12340ab3', '0x7fffffff'],
  895. ['0', '1', '4', '30', '15', '16', '1']],
  896. [['1', '0', '0xfff', '1', '0xe000', '0x1234', '0x3fffffff']],
  897. "binary int", ">>",
  898. input_16=[['1', '1', '0x7fff', '0x7fff'],
  899. ['0', '1', '4', '14']],
  900. output_16=[['1', '0', '0x07ff', '1']]
  901. )
  902. add_test_case_int("And", ['And'], 'Epsilon', 0,
  903. [['0x1', '0x01', '0x7fff0000', '0x33333333', '0x137f', '0x12345678', '0xa341', '-1'],
  904. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '-1', '0x3471', '-1']],
  905. [['0x1', '0x00', '0x0', '0x22222222', '0x0', '0x12345678', '0x2041', '-1']],
  906. "binary int", "&",
  907. input_16=[['0x1', '0x01', '0x7fff', '0x3333', '0x137f', '0x1234', '0xa341', '-1'],
  908. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '-1', '0x3471', '-1']],
  909. output_16=[['0x1', '0x00', '0x0', '0x2222', '0x0', '0x1234', '0x2041', '-1']],
  910. )
  911. add_test_case_int("Or", ['Or'], 'Epsilon', 0,
  912. [['0x1', '0x01', '0x7fff0000', '0x11111111', '0x137f', '0x0', '0x12345678', '0xa341', '-1'],
  913. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '0x0', '0x00000000', '0x3471', '-1']],
  914. [['0x1', '0xf1', '0x7fffffff', '0x33333333', '0xffff', '0x0', '0x12345678', '0xb771', '-1']],
  915. "binary int", "|",
  916. input_16=[['0x1', '0x01', '0x7fff', '0x3333', '0x137f', '0x1234', '0xa341', '-1'],
  917. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '0xffff', '0x3471', '-1']],
  918. output_16=[['0x1', '0xf1', '0x7fff', '0x3333', '0xffff', '0xffff', '0xb771', '-1']],
  919. )
  920. add_test_case_int("Xor", ['Xor'], 'Epsilon', 0,
  921. [['0x1', '0x01', '0x7fff0000', '0x11111111', '0x137f', '0x0', '0x12345678', '0xa341', '-1'],
  922. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '0x0', '0x00000000', '0x3471', '-1']],
  923. [['0x0', '0xf1', '0x7fffffff', '0x33333333', '0xffff', '0x0', '0x12345678', '0x9730', '0x00000000']],
  924. "binary int", "^",
  925. input_16=[['0x1', '0x01', '0x7fff', '0x1111', '0x137f', '0x0', '0x1234', '0xa341', '-1'],
  926. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '0x0', '0x0000', '0x3471', '-1']],
  927. output_16=[['0x0', '0xf1', '0x7fff', '0x3333', '0xffff', '0x0', '0x1234', '0x9730', '0x0000']],
  928. )
  929. # Binary Uint
  930. add_test_case_int('UAdd', ['Add'], 'Epsilon', 0,
  931. [['2147483648', '4294967285', '0', '0', '10', int32_max, '486'],
  932. ['0', '10', '0', '10', '10', '0', '54238']],
  933. [['2147483648', uint32_max, '0', '10', '20', int32_max, '54724']],
  934. "binary uint", "+",
  935. input_16=[['323', '0xfff5', '0', '0', '10', uint16_max, '486'],
  936. ['0', '10', '0', '10', '10', '0', '334']],
  937. output_16=[['323', uint16_max, '0', '10', '20', uint16_max, '820']])
  938. add_test_case_int('USub', ['Sub'], 'Epsilon', 0,
  939. [['2147483648', uint32_max, '0', '0', '30', int32_max, '54724'],
  940. ['0', '10', '0', '10', '10', '0', '54238']],
  941. [['2147483648', '4294967285', '0', '4294967286', '20', int32_max, '486']],
  942. "binary uint", "-",
  943. input_16=[['323', uint16_max, '0', '0', '10', uint16_max, '486'],
  944. ['0', '10', '0', '10', '10', '0', '334']],
  945. output_16=[['323', '0xfff5', '0', '-10', '0', uint16_max, '152']])
  946. add_test_case_int('UMax', ['UMax'], 'Epsilon', 0,
  947. [['0', '0', '10', '10000', int32_max, uint32_max],
  948. ['0', '256', '4', '10001', '0', uint32_max]],
  949. [['0', '256', '10', '10001', int32_max, uint32_max]],
  950. "binary uint call", "max",
  951. input_16=[['0', '0', '10', '10000', int16_max, uint16_max],
  952. ['0', '256', '4', '10001', '0', uint16_max]],
  953. output_16=[['0', '256', '10', '10001', int16_max, uint16_max]])
  954. add_test_case_int('UMin', ['UMin'], 'Epsilon', 0,
  955. [['0', '0', '10', '10000', int32_max, uint32_max],
  956. ['0', '256', '4', '10001', '0', uint32_max]],
  957. [['0', '0', '4', '10000', '0', uint32_max]],
  958. "binary uint call", "min",
  959. input_16=[['0', '0', '10', '10000', int16_max, uint16_max],
  960. ['0', '256', '4', '10001', '0', uint16_max]],
  961. output_16=[['0', '0', '4', '10000', '0', uint16_max]])
  962. add_test_case_int('UMul', ['Mul'], 'Epsilon', 0,
  963. [['0', '1', '10', '10000', int32_max],
  964. ['0', '256', '4', '10001', '0']],
  965. [['0', '256', '40', '100010000', '0']],
  966. "binary uint", "*",
  967. input_16=[['0', '0', '10', '100', int16_max],
  968. ['0', '256', '4', '101', '0']],
  969. output_16=[['0', '0', '40', '10001', '0']])
  970. add_test_case('UDiv', ['UDiv', 'URem'], 'Epsilon', 0,
  971. [['1', '1', '10', '10000', int32_max, int32_max, '0xffffffff'],
  972. ['0', '256', '4', '10001', '0', int32_max, '1']],
  973. [['0xffffffff', '0', '2', '0', '0xffffffff', '1', '0xffffffff'],
  974. ['0xffffffff', '1', '2', '10000', '0xffffffff', '0', '0']], 'cs_6_0',
  975. ''' struct SBinaryUintOp {
  976. uint input1;
  977. uint input2;
  978. uint output1;
  979. uint output2;
  980. };
  981. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  982. [numthreads(8,8,1)]
  983. void main(uint GI : SV_GroupIndex) {
  984. SBinaryUintOp l = g_buf[GI];
  985. l.output1 = l.input1 / l.input2;
  986. l.output2 = l.input1 % l.input2;
  987. g_buf[GI] = l;
  988. };''')
  989. add_test_case('UAddc', ['UAddc'], 'Epsilon', 0,
  990. [['1', '1', '10000', '0x80000000', '0x7fffffff', '0xffffffff'],
  991. ['0', '256', '10001', '1', '0x7fffffff', '0x7fffffff']],
  992. [['2', '2', '20000', '0', '0xfffffffe', '0xfffffffe'],
  993. ['0', '512', '20002', '3', '0xfffffffe', '0xffffffff']], 'cs_6_0',
  994. ''' struct SBinaryUintOp {
  995. uint input1;
  996. uint input2;
  997. uint output1;
  998. uint output2;
  999. };
  1000. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  1001. [numthreads(8,8,1)]
  1002. void main(uint GI : SV_GroupIndex) {
  1003. SBinaryUintOp l = g_buf[GI];
  1004. uint2 x = uint2(l.input1, l.input2);
  1005. uint2 y = AddUint64(x, x);
  1006. l.output1 = y.x;
  1007. l.output2 = y.y;
  1008. g_buf[GI] = l;
  1009. };''')
  1010. # Tertiary Int
  1011. add_test_case_int('IMad', ['IMad'], 'epsilon', 0, [[
  1012. '-2147483647', '-256', '-1', '0', '1', '2', '16', int32_max, '1',
  1013. '-1', '1', '10'
  1014. ], ['1', '-256', '-1', '0', '1', '3', '16', '0', '1', '-1', '10', '100'], [
  1015. '0', '0', '0', '0', '1', '3', '1', '255', '2147483646', '-2147483647',
  1016. '-10', '-2000'
  1017. ]], [[
  1018. '-2147483647', '65536', '1', '0', '2', '9', '257', '255', int32_max,
  1019. '-2147483646', '0', '-1000'
  1020. ]], "tertiary int", "mad",
  1021. input_16=[[int16_min, '-256', '-1', '0', '1', '2', '16', int16_max],
  1022. ['1','8','-1', '0', '1', '3', '16','1'],
  1023. ['0', '0', '1', '3', '250', '-30', int16_min, '-50']],
  1024. output_16=[[int16_min, '-2048', '2', '3', '251', '-24', '-32512', '32717']]
  1025. )
  1026. add_test_case_int('UMad', ['UMad'], 'epsilon', 0,
  1027. [['0', '1', '2', '16', int32_max, '0', '10'], [
  1028. '0', '1', '2', '16', '1', '0', '10'
  1029. ], ['0', '0', '1', '15', '0', '10', '10']],
  1030. [['0', '1', '5', '271', int32_max, '10', '110']],
  1031. "tertiary uint", "mad",
  1032. input_16=[['0', '1', '2', '16', int16_max, '0', '10'], [
  1033. '0', '1', '2', '16', '1', '0', '10'
  1034. ], ['0', '0', '1', '15', '0', '10', '10']],
  1035. output_16=[['0', '1', '5', '271', int16_max, '10', '110']],
  1036. )
  1037. # Dot
  1038. add_test_case('Dot', ['Dot2', 'Dot3', 'Dot4'], 'epsilon', 0.008, [[
  1039. 'NaN,NaN,NaN,NaN', '-Inf,-Inf,-Inf,-Inf',
  1040. '-denorm,-denorm,-denorm,-denorm', '-0,-0,-0,-0', '0,0,0,0',
  1041. 'denorm,denorm,denorm,denorm', 'Inf,Inf,Inf,Inf', '1,1,1,1',
  1042. '-10,0,0,10', 'Inf,Inf,Inf,-Inf'
  1043. ], [
  1044. 'NaN,NaN,NaN,NaN', '-Inf,-Inf,-Inf,-Inf',
  1045. '-denorm,-denorm,-denorm,-denorm', '-0,-0,-0,-0', '0,0,0,0',
  1046. 'denorm,denorm,denorm,denorm', 'Inf,Inf,Inf,Inf', '1,1,1,1',
  1047. '10,0,0,10', 'Inf,Inf,Inf,Inf'
  1048. ]], [
  1049. [nan, p_inf, 0, 0, 0, 0, p_inf, 2, -100, p_inf],
  1050. [nan, p_inf, 0, 0, 0, 0, p_inf, 3, -100, p_inf],
  1051. [nan, p_inf, 0, 0, 0, 0, p_inf, 4, 0, nan],
  1052. ], 'cs_6_0', ''' struct SDotOp {
  1053. float4 input1;
  1054. float4 input2;
  1055. float o_dot2;
  1056. float o_dot3;
  1057. float o_dot4;
  1058. };
  1059. RWStructuredBuffer<SDotOp> g_buf : register(u0);
  1060. [numthreads(8,8,1)]
  1061. void main(uint GI : SV_GroupIndex) {
  1062. SDotOp l = g_buf[GI];
  1063. l.o_dot2 = dot(l.input1.xy, l.input2.xy);
  1064. l.o_dot3 = dot(l.input1.xyz, l.input2.xyz);
  1065. l.o_dot4 = dot(l.input1.xyzw, l.input2.xyzw);
  1066. g_buf[GI] = l;
  1067. };''')
  1068. # Quaternary
  1069. # Msad4 intrinsic calls both Bfi and Msad. Currently this is the only way to call bfi instruction from HLSL
  1070. add_test_case('Bfi', ['Bfi', 'Msad'], 'epsilon', 0,
  1071. [["0xA100B2C3", "0x00000000", "0xFFFF01C1", "0xFFFFFFFF"], [
  1072. "0xD7B0C372, 0x4F57C2A3", "0xFFFFFFFF, 0x00000000",
  1073. "0x38A03AEF, 0x38194DA3", "0xFFFFFFFF, 0x00000000"
  1074. ], ["1,2,3,4", "1,2,3,4", "0,0,0,0", "10,10,10,10"]],
  1075. [['153,6,92,113', '1,2,3,4', '397,585,358,707', '10,265,520,775']],
  1076. 'cs_6_0', ''' struct SMsad4 {
  1077. uint ref;
  1078. uint2 source;
  1079. uint4 accum;
  1080. uint4 result;
  1081. };
  1082. RWStructuredBuffer<SMsad4> g_buf : register(u0);
  1083. [numthreads(8,8,1)]
  1084. void main(uint GI : SV_GroupIndex) {
  1085. SMsad4 l = g_buf[GI];
  1086. l.result = msad4(l.ref, l.source, l.accum);
  1087. g_buf[GI] = l;
  1088. };''')
  1089. # Wave Active Tests
  1090. add_test_case('WaveActiveSum', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1091. [['1', '2', '3', '4'], ['0'], ['2', '4', '8', '-64']], [],
  1092. 'cs_6_0', get_shader_text("wave op int", "WaveActiveSum"))
  1093. add_test_case('WaveActiveProduct', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1094. [['1', '2', '3', '4'], ['0'], ['1', '2', '4', '-64']], [],
  1095. 'cs_6_0', get_shader_text("wave op int", "WaveActiveProduct"))
  1096. add_test_case('WaveActiveCountBits', ['WaveAllBitCount', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1097. [['1', '2', '3', '4'], ['0'], ['1', '10', '-4', '-64'],
  1098. ['-100', '-1000', '300']], [], 'cs_6_0',
  1099. get_shader_text("wave op int count", "WaveActiveCountBits"))
  1100. add_test_case('WaveActiveMax', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1101. [['1', '2', '3', '4'], ['0'], ['1', '10', '-4', '-64'],
  1102. ['-100', '-1000', '300']], [], 'cs_6_0',
  1103. get_shader_text("wave op int", "WaveActiveMax"))
  1104. add_test_case('WaveActiveMin', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1105. [['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], ['0'],
  1106. ['1', '10', '-4', '-64'], ['-100', '-1000', '300']], [],
  1107. 'cs_6_0', get_shader_text("wave op int", "WaveActiveMin"))
  1108. add_test_case('WaveActiveAllEqual', ['WaveActiveAllEqual'], 'Epsilon', 0,
  1109. [['1', '2', '3', '4', '1', '1', '1', '1'], ['3'], ['-10']],
  1110. [], 'cs_6_0', get_shader_text("wave op int", "WaveActiveAllEqual"))
  1111. add_test_case('WaveActiveAnyTrue', ['WaveAnyTrue', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1112. [['1', '0', '1', '0', '1'], ['1'], ['0']], [], 'cs_6_0',
  1113. get_shader_text("wave op int", "WaveActiveAnyTrue"))
  1114. add_test_case('WaveActiveAllTrue', ['WaveAllTrue', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1115. [['1', '0', '1', '0', '1'], ['1'], ['1']], [], 'cs_6_0',
  1116. get_shader_text("wave op int", "WaveActiveAllTrue"))
  1117. add_test_case('WaveActiveUSum', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1118. [['1', '2', '3', '4'], ['0'], ['2', '4', '8', '64']], [],
  1119. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveSum"))
  1120. add_test_case('WaveActiveUProduct', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1121. [['1', '2', '3', '4'], ['0'], ['1', '2', '4', '64']], [],
  1122. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveProduct"))
  1123. add_test_case('WaveActiveUMax', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1124. [['1', '2', '3', '4'], ['0'], ['1', '10', '4', '64']], [],
  1125. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveMax"))
  1126. add_test_case('WaveActiveUMin', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1127. [['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], ['0'],
  1128. ['1', '10', '4', '64']], [], 'cs_6_0',
  1129. get_shader_text("wave op uint", "WaveActiveMin"))
  1130. add_test_case('WaveActiveBitOr', ['WaveActiveBit'], 'Epsilon', 0, [[
  1131. '0xe0000000', '0x0d000000', '0x00b00000', '0x00070000', '0x0000e000',
  1132. '0x00000d00', '0x000000b0', '0x00000007'
  1133. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1134. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1135. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitOr"))
  1136. add_test_case('WaveActiveBitAnd', ['WaveActiveBit'], 'Epsilon', 0, [[
  1137. '0xefffffff', '0xfdffffff', '0xffbfffff', '0xfff7ffff', '0xffffefff',
  1138. '0xfffffdff', '0xffffffbf', '0xfffffff7'
  1139. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1140. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1141. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitAnd"))
  1142. add_test_case('WaveActiveBitXor', ['WaveActiveBit'], 'Epsilon', 0, [[
  1143. '0xe0000000', '0x0d000000', '0x00b00000', '0x00070000', '0x0000e000',
  1144. '0x00000d00', '0x000000b0', '0x00000007'
  1145. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1146. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1147. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitXor"))
  1148. add_test_case('WavePrefixCountBits', ['WavePrefixBitCount'], 'Epsilon', 0,
  1149. [['1', '2', '3', '4', '5'], ['0'], ['1', '10', '-4', '-64'],
  1150. ['-100', '-1000', '300']], [], 'cs_6_0',
  1151. get_shader_text("wave op int count", "WavePrefixCountBits"))
  1152. add_test_case('WavePrefixSum', ['WavePrefixOp'], 'Epsilon', 0,
  1153. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '-64', '128']],
  1154. [], 'cs_6_0', get_shader_text("wave op int", "WavePrefixSum"))
  1155. add_test_case('WavePrefixProduct', ['WavePrefixOp'], 'Epsilon', 0,
  1156. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '-64', '128']],
  1157. [], 'cs_6_0', get_shader_text("wave op int", "WavePrefixProduct"))
  1158. add_test_case('WavePrefixUSum', ['WavePrefixOp'], 'Epsilon', 0,
  1159. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '128']], [],
  1160. 'cs_6_0', get_shader_text("wave op uint", "WavePrefixSum"))
  1161. add_test_case('WavePrefixUProduct', ['WavePrefixOp'], 'Epsilon', 0,
  1162. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '128']], [],
  1163. 'cs_6_0', get_shader_text("wave op uint", "WavePrefixProduct"))
# Generate the XML file for the execution tests using the data-driven method.
# TODO: ElementTree does not generate formatted XML. Currently the XML file is
# checked in after being run through the VS Code formatter. Implement an XML
# formatter, or import a formatter library, and use that instead.
  1167. def generate_parameter_types(table, num_inputs, num_outputs, has_known_warp_issue=False):
  1168. param_types = ET.SubElement(table, "ParameterTypes")
  1169. ET.SubElement(
  1170. param_types, "ParameterType", attrib={
  1171. "Name": "ShaderOp.Target"
  1172. }).text = "String"
  1173. ET.SubElement(
  1174. param_types, "ParameterType", attrib={
  1175. "Name": "ShaderOp.Arguments"
  1176. }).text = "String"
  1177. ET.SubElement(
  1178. param_types, "ParameterType", attrib={
  1179. "Name": "ShaderOp.Text"
  1180. }).text = "String"
  1181. ET.SubElement(
  1182. param_types, "ParameterType", attrib={
  1183. "Name": "Validation.Type"
  1184. }).text = "String"
  1185. ET.SubElement(
  1186. param_types, "ParameterType", attrib={
  1187. "Name": "Validation.Tolerance"
  1188. }).text = "double"
  1189. for i in range(0, num_inputs):
  1190. ET.SubElement(
  1191. param_types,
  1192. "ParameterType",
  1193. attrib={
  1194. "Name": 'Validation.Input{}'.format(i + 1),
  1195. 'Array': 'true'
  1196. }).text = "String"
  1197. for i in range(0, num_outputs):
  1198. ET.SubElement(
  1199. param_types,
  1200. "ParameterType",
  1201. attrib={
  1202. "Name": 'Validation.Expected{}'.format(i + 1),
  1203. 'Array': 'true'
  1204. }).text = "String"
  1205. if has_known_warp_issue:
  1206. ET.SubElement(param_types, "ParameterType", attrib={"Name":"Warp.Version"}).text = "unsigned int"
  1207. def generate_parameter_types_wave(table):
  1208. param_types = ET.SubElement(table, "ParameterTypes")
  1209. ET.SubElement(
  1210. param_types, "ParameterType", attrib={
  1211. "Name": "ShaderOp.Target"
  1212. }).text = "String"
  1213. ET.SubElement(
  1214. param_types, "ParameterType", attrib={
  1215. "Name": "ShaderOp.Text"
  1216. }).text = "String"
  1217. ET.SubElement(
  1218. param_types,
  1219. "ParameterType",
  1220. attrib={
  1221. "Name": "Validation.NumInputSet"
  1222. }).text = "String"
  1223. ET.SubElement(
  1224. param_types,
  1225. "ParameterType",
  1226. attrib={
  1227. "Name": "Validation.InputSet1",
  1228. "Array": "true"
  1229. }).text = "String"
  1230. ET.SubElement(
  1231. param_types,
  1232. "ParameterType",
  1233. attrib={
  1234. "Name": "Validation.InputSet2",
  1235. "Array": "true"
  1236. }).text = "String"
  1237. ET.SubElement(
  1238. param_types,
  1239. "ParameterType",
  1240. attrib={
  1241. "Name": "Validation.InputSet3",
  1242. "Array": "true"
  1243. }).text = "String"
  1244. ET.SubElement(
  1245. param_types,
  1246. "ParameterType",
  1247. attrib={
  1248. "Name": "Validation.InputSet4",
  1249. "Array": "true"
  1250. }).text = "String"
  1251. def generate_parameter_types_msad(table):
  1252. param_types = ET.SubElement(table, "ParameterTypes")
  1253. ET.SubElement(
  1254. param_types, "ParameterType", attrib={
  1255. "Name": "ShaderOp.Text"
  1256. }).text = "String"
  1257. ET.SubElement(
  1258. param_types, "ParameterType", attrib={
  1259. "Name": "Validation.Tolerance"
  1260. }).text = "int"
  1261. ET.SubElement(
  1262. param_types,
  1263. "ParameterType",
  1264. attrib={
  1265. "Name": "Validation.Input1",
  1266. "Array": "true"
  1267. }).text = "unsigned int"
  1268. ET.SubElement(
  1269. param_types,
  1270. "ParameterType",
  1271. attrib={
  1272. "Name": "Validation.Input2",
  1273. "Array": "true"
  1274. }).text = "String"
  1275. ET.SubElement(
  1276. param_types,
  1277. "ParameterType",
  1278. attrib={
  1279. "Name": "Validation.Input3",
  1280. "Array": "true"
  1281. }).text = "String"
  1282. ET.SubElement(
  1283. param_types,
  1284. "ParameterType",
  1285. attrib={
  1286. "Name": "Validation.Expected1",
  1287. "Array": "true"
  1288. }).text = "String"
  1289. def generate_row(table, case):
  1290. row = ET.SubElement(table, "Row", {"Name": case.test_name})
  1291. ET.SubElement(row, "Parameter", {
  1292. "Name": "Validation.Type"
  1293. }).text = case.validation_type
  1294. ET.SubElement(row, "Parameter", {
  1295. "Name": "Validation.Tolerance"
  1296. }).text = str(case.validation_tolerance)
  1297. ET.SubElement(row, "Parameter", {
  1298. "Name": "ShaderOp.Text"
  1299. }).text = case.shader_text
  1300. ET.SubElement(row, "Parameter", {
  1301. "Name": "ShaderOp.Target"
  1302. }).text = case.shader_target
  1303. for i in range(len(case.input_lists)):
  1304. inputs = ET.SubElement(row, "Parameter", {
  1305. "Name": "Validation.Input{}".format(i + 1)
  1306. })
  1307. for val in case.input_lists[i]:
  1308. ET.SubElement(inputs, "Value").text = str(val)
  1309. for i in range(len(case.output_lists)):
  1310. outputs = ET.SubElement(row, "Parameter", {
  1311. "Name": "Validation.Expected{}".format(i + 1)
  1312. })
  1313. for val in case.output_lists[i]:
  1314. ET.SubElement(outputs, "Value").text = str(val)
  1315. # Optional parameters
  1316. if case.warp_version > 0:
  1317. ET.SubElement(row, "Parameter", {"Name":"Warp.Version"}).text = str(case.warp_version)
  1318. if case.shader_arguments != "":
  1319. ET.SubElement(row, "Parameter", {"Name":"ShaderOp.Arguments"}).text = case.shader_arguments
  1320. def generate_row_wave(table, case):
  1321. row = ET.SubElement(table, "Row", {"Name": case.test_name})
  1322. ET.SubElement(row, "Parameter", {
  1323. "Name": "ShaderOp.Name"
  1324. }).text = case.test_name
  1325. ET.SubElement(row, "Parameter", {
  1326. "Name": "ShaderOp.Text"
  1327. }).text = case.shader_text
  1328. ET.SubElement(row, "Parameter", {
  1329. "Name": "Validation.NumInputSet"
  1330. }).text = str(len(case.input_lists))
  1331. for i in range(len(case.input_lists)):
  1332. inputs = ET.SubElement(row, "Parameter", {
  1333. "Name": "Validation.InputSet{}".format(i + 1)
  1334. })
  1335. for val in case.input_lists[i]:
  1336. ET.SubElement(inputs, "Value").text = str(val)
  1337. def generate_table_for_taef():
  1338. with open("..\\..\\tools\\clang\\unittests\\HLSL\\ShaderOpArithTable.xml",
  1339. 'w') as f:
  1340. tree = ET.ElementTree()
  1341. root = ET.Element('Data')
  1342. # Create tables
  1343. generate_parameter_types(
  1344. ET.SubElement(root, "Table", attrib={
  1345. "Id": "UnaryFloatOpTable"
  1346. }), 1, 1, True)
  1347. generate_parameter_types(
  1348. ET.SubElement(root, "Table", attrib={
  1349. "Id": "BinaryFloatOpTable"
  1350. }), 2, 2)
  1351. generate_parameter_types(
  1352. ET.SubElement(root, "Table", attrib={
  1353. "Id": "TertiaryFloatOpTable"
  1354. }), 3, 1)
  1355. generate_parameter_types(
  1356. ET.SubElement(root, "Table", attrib={
  1357. "Id": "UnaryHalfOpTable"
  1358. }), 1, 1, True)
  1359. generate_parameter_types(
  1360. ET.SubElement(root, "Table", attrib={
  1361. "Id": "BinaryHalfOpTable"
  1362. }), 2, 2)
  1363. generate_parameter_types(
  1364. ET.SubElement(root, "Table", attrib={
  1365. "Id": "TertiaryHalfOpTable"
  1366. }), 3, 1)
  1367. generate_parameter_types(
  1368. ET.SubElement(root, "Table", attrib={
  1369. "Id": "UnaryIntOpTable"
  1370. }), 1, 1)
  1371. generate_parameter_types(
  1372. ET.SubElement(root, "Table", attrib={
  1373. "Id": "BinaryIntOpTable"
  1374. }), 2, 2)
  1375. generate_parameter_types(
  1376. ET.SubElement(root, "Table", attrib={
  1377. "Id": "TertiaryIntOpTable"
  1378. }), 3, 1)
  1379. generate_parameter_types(
  1380. ET.SubElement(root, "Table", attrib={
  1381. "Id": "UnaryInt16OpTable"
  1382. }), 1, 1)
  1383. generate_parameter_types(
  1384. ET.SubElement(root, "Table", attrib={
  1385. "Id": "BinaryInt16OpTable"
  1386. }), 2, 2)
  1387. generate_parameter_types(
  1388. ET.SubElement(root, "Table", attrib={
  1389. "Id": "TertiaryInt16OpTable"
  1390. }), 3, 1)
  1391. generate_parameter_types(
  1392. ET.SubElement(root, "Table", attrib={
  1393. "Id": "UnaryUintOpTable"
  1394. }), 1, 1)
  1395. generate_parameter_types(
  1396. ET.SubElement(root, "Table", attrib={
  1397. "Id": "BinaryUintOpTable"
  1398. }), 2, 2)
  1399. generate_parameter_types(
  1400. ET.SubElement(root, "Table", attrib={
  1401. "Id": "TertiaryUintOpTable"
  1402. }), 3, 1)
  1403. generate_parameter_types(
  1404. ET.SubElement(root, "Table", attrib={
  1405. "Id": "UnaryUint16OpTable"
  1406. }), 1, 1)
  1407. generate_parameter_types(
  1408. ET.SubElement(root, "Table", attrib={
  1409. "Id": "BinaryUint16OpTable"
  1410. }), 2, 2)
  1411. generate_parameter_types(
  1412. ET.SubElement(root, "Table", attrib={
  1413. "Id": "TertiaryUint16OpTable"
  1414. }), 3, 1)
  1415. generate_parameter_types(
  1416. ET.SubElement(root, "Table", attrib={
  1417. "Id": "DotOpTable"
  1418. }), 2, 3)
  1419. generate_parameter_types_msad(
  1420. ET.SubElement(root, "Table", attrib={
  1421. "Id": "Msad4Table"
  1422. }))
  1423. generate_parameter_types_wave(
  1424. ET.SubElement(
  1425. root, "Table", attrib={
  1426. "Id": "WaveIntrinsicsActiveIntTable"
  1427. }))
  1428. generate_parameter_types_wave(
  1429. ET.SubElement(
  1430. root, "Table", attrib={
  1431. "Id": "WaveIntrinsicsActiveUintTable"
  1432. }))
  1433. generate_parameter_types_wave(
  1434. ET.SubElement(
  1435. root, "Table", attrib={
  1436. "Id": "WaveIntrinsicsPrefixIntTable"
  1437. }))
  1438. generate_parameter_types_wave(
  1439. ET.SubElement(
  1440. root, "Table", attrib={
  1441. "Id": "WaveIntrinsicsPrefixUintTable"
  1442. }))
  1443. generate_parameter_types(
  1444. ET.SubElement(
  1445. root, "Table", attrib={
  1446. "Id": "DenormBinaryFloatOpTable"
  1447. }), 2, 1)
  1448. generate_parameter_types(
  1449. ET.SubElement(
  1450. root, "Table", attrib={
  1451. "Id": "DenormTertiaryFloatOpTable"
  1452. }), 3, 1)
  1453. for case in g_test_cases.values():
  1454. cur_inst = case.insts[0]
  1455. if cur_inst.is_cast or cur_inst.category.startswith("Unary"):
  1456. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1457. generate_row(
  1458. root.find("./Table[@Id='UnaryFloatOpTable']"),
  1459. case)
  1460. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1461. generate_row(root.find("./Table[@Id='UnaryHalfOpTable']"),case)
  1462. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1463. if cur_inst.category.startswith("Unary int"):
  1464. generate_row(
  1465. root.find("./Table[@Id='UnaryIntOpTable']"),
  1466. case)
  1467. elif cur_inst.category.startswith("Unary uint"):
  1468. generate_row(
  1469. root.find("./Table[@Id='UnaryUintOpTable']"),
  1470. case)
  1471. else:
  1472. print("unknown op: " + cur_inst.name)
  1473. print(cur_inst.dxil_class)
  1474. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1475. if cur_inst.category.startswith("Unary int"):
  1476. generate_row(
  1477. root.find("./Table[@Id='UnaryInt16OpTable']"),
  1478. case)
  1479. elif cur_inst.category.startswith("Unary uint"):
  1480. generate_row(
  1481. root.find("./Table[@Id='UnaryUint16OpTable']"),
  1482. case)
  1483. else:
  1484. print("unknown op: " + cur_inst.name)
  1485. print(cur_inst.dxil_class)
  1486. elif cur_inst.is_binary or cur_inst.category.startswith(
  1487. "Binary"):
  1488. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1489. if case.test_name in g_denorm_tests: # for denorm tests
  1490. generate_row(
  1491. root.find("./Table[@Id='DenormBinaryFloatOpTable']"),
  1492. case)
  1493. else:
  1494. generate_row(
  1495. root.find("./Table[@Id='BinaryFloatOpTable']"),
  1496. case)
  1497. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1498. generate_row(root.find("./Table[@Id='BinaryHalfOpTable']"),case)
  1499. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1500. if cur_inst.category.startswith("Binary int"):
  1501. if case.test_name in ['UAdd', 'USub', 'UMul']: # Add, Sub, Mul use same operations for int and uint.
  1502. generate_row(
  1503. root.find("./Table[@Id='BinaryUintOpTable']"),
  1504. case)
  1505. else:
  1506. generate_row(
  1507. root.find("./Table[@Id='BinaryIntOpTable']"),
  1508. case)
  1509. elif cur_inst.category.startswith("Binary uint"):
  1510. generate_row(
  1511. root.find("./Table[@Id='BinaryUintOpTable']"),
  1512. case)
  1513. else:
  1514. print("unknown op: " + cur_inst.name)
  1515. print(cur_inst.dxil_class)
  1516. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1517. if cur_inst.category.startswith("Binary int"):
  1518. if case.test_name in ['UAdd', 'USub', 'UMul']: # Add, Sub, Mul use same operations for int and uint.
  1519. generate_row(
  1520. root.find("./Table[@Id='BinaryUint16OpTable']"),
  1521. case)
  1522. else:
  1523. generate_row(
  1524. root.find("./Table[@Id='BinaryInt16OpTable']"),
  1525. case)
  1526. elif cur_inst.category.startswith("Binary uint"):
  1527. generate_row(
  1528. root.find("./Table[@Id='BinaryUint16OpTable']"),
  1529. case)
  1530. else:
  1531. print("unknown op: " + cur_inst.name)
  1532. print(cur_inst.dxil_class)
  1533. elif cur_inst.category.startswith("Tertiary"):
  1534. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1535. if case.test_name in g_denorm_tests: # for denorm tests
  1536. generate_row(
  1537. root.find("./Table[@Id='DenormTertiaryFloatOpTable']"),case)
  1538. else:
  1539. generate_row(
  1540. root.find("./Table[@Id='TertiaryFloatOpTable']"),case)
  1541. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1542. generate_row(root.find("./Table[@Id='TertiaryHalfOpTable']"),case)
  1543. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1544. if cur_inst.category.startswith("Tertiary int"):
  1545. generate_row(
  1546. root.find("./Table[@Id='TertiaryIntOpTable']"),
  1547. case)
  1548. elif cur_inst.category.startswith("Tertiary uint"):
  1549. generate_row(
  1550. root.find(
  1551. "./Table[@Id='TertiaryUintOpTable']"),
  1552. case)
  1553. else:
  1554. print("unknown op: " + cur_inst.name)
  1555. print(cur_inst.dxil_class)
  1556. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1557. if cur_inst.category.startswith("Tertiary int"):
  1558. generate_row(
  1559. root.find("./Table[@Id='TertiaryInt16OpTable']"),
  1560. case)
  1561. elif cur_inst.category.startswith("Tertiary uint"):
  1562. generate_row(
  1563. root.find(
  1564. "./Table[@Id='TertiaryUint16OpTable']"),
  1565. case)
  1566. else:
  1567. print("unknown op: " + cur_inst.name)
  1568. print(cur_inst.dxil_class)
  1569. elif cur_inst.category.startswith("Quaternary"):
  1570. if cur_inst.name == "Bfi":
  1571. generate_row(
  1572. root.find("./Table[@Id='Msad4Table']"), case)
  1573. else:
  1574. print("unknown op: " + cur_inst.name)
  1575. print(cur_inst.dxil_class)
  1576. elif cur_inst.category == "Dot":
  1577. generate_row(root.find("./Table[@Id='DotOpTable']"), case)
  1578. elif cur_inst.dxil_class in ["WaveActiveOp", "WaveAllOp","WaveActiveAllEqual","WaveAnyTrue","WaveAllTrue"]:
  1579. if case.test_name.startswith("WaveActiveU"):
  1580. generate_row_wave(
  1581. root.find(
  1582. "./Table[@Id='WaveIntrinsicsActiveUintTable']"
  1583. ), case)
  1584. else:
  1585. generate_row_wave(
  1586. root.find(
  1587. "./Table[@Id='WaveIntrinsicsActiveIntTable']"),
  1588. case)
  1589. elif cur_inst.dxil_class == "WaveActiveBit":
  1590. generate_row_wave(
  1591. root.find(
  1592. "./Table[@Id='WaveIntrinsicsActiveUintTable']"),
  1593. case)
  1594. elif cur_inst.dxil_class == "WavePrefixOp":
  1595. if case.test_name.startswith("WavePrefixU"):
  1596. generate_row_wave(
  1597. root.find(
  1598. "./Table[@Id='WaveIntrinsicsPrefixUintTable']"
  1599. ), case)
  1600. else:
  1601. generate_row_wave(
  1602. root.find(
  1603. "./Table[@Id='WaveIntrinsicsPrefixIntTable']"),
  1604. case)
  1605. else:
  1606. print("unknown op: " + cur_inst.name)
  1607. print(cur_inst.dxil_class)
  1608. tree._setroot(root)
  1609. tree.write(f)
  1610. f.close()
  1611. def print_untested_inst():
  1612. lst = []
  1613. for name in [node.inst.name for node in g_instruction_nodes.values() if len(node.test_cases) == 0]:
  1614. lst += [name]
  1615. lst.sort()
  1616. for name in lst:
  1617. print(name)
  1618. print("Total uncovered dxil ops: " + str(len(lst)))
  1619. print("Total covered dxil ops: " + str(len(g_instruction_nodes)-len(lst)))
  1620. # name to instruction dict
  1621. g_instruction_nodes = {}
  1622. # test name to test case dict
  1623. g_test_cases = {}
  1624. if __name__ == "__main__":
  1625. db = get_db_dxil()
  1626. for inst in db.instr:
  1627. g_instruction_nodes[inst.name] = inst_node(inst)
  1628. add_test_cases()
  1629. args = vars(parser.parse_args())
  1630. mode = args['mode']
  1631. if mode == "info":
  1632. print_untested_inst()
  1633. elif mode == "gen-xml":
  1634. generate_table_for_taef()
  1635. else:
  1636. print("unknown mode: " + mode)
  1637. exit(1)
  1638. exit(0)