hctdb_test.py 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002
  1. # This file is distributed under the University of Illinois Open Source License. See LICENSE.TXT for details
  2. ###############################################################################
  3. # This file contains driver test information for DXIL operations #
  4. ###############################################################################
  5. from hctdb import *
  6. import xml.etree.ElementTree as ET
  7. import argparse
  8. parser = argparse.ArgumentParser(description="contains information about dxil op test cases.")
  9. parser.add_argument('mode', help="'gen-xml' or 'info'")
  10. g_db_dxil = None
  11. def get_db_dxil():
  12. global g_db_dxil
  13. if g_db_dxil is None:
  14. g_db_dxil = db_dxil()
  15. return g_db_dxil
  16. """
  17. This class represents a test case for instructions for driver testings
  18. DXIL instructions and test cases are two disjoint sets where each instruction can have multiple test cases,
  19. and each test case can cover different DXIL instructions. So these two sets form a bipartite graph.
  20. test_name: Test case identifier. Must be unique for each test case.
  21. insts: dxil instructions
  22. validation_type: validation type for test
  23. epsilon: absolute difference check
  24. ulp: units in last place check
  25. relative: relative error check
  26. validation_tolerance: tolerance value for a given test
  27. inputs: testing inputs
  28. outputs: expected outputs for each input
  29. shader_target: target for testing
  30. shader_text: hlsl file that is used for testing dxil op
  31. """
  32. class test_case(object):
  33. def __init__(self, test_name, insts, validation_type, validation_tolerance,
  34. input_lists, output_lists, shader_target, shader_text, **kwargs):
  35. self.test_name = test_name
  36. self.validation_type = validation_type
  37. self.validation_tolerance = validation_tolerance
  38. self.input_lists = input_lists
  39. self.output_lists = output_lists
  40. self.shader_target = shader_target
  41. self.shader_text = shader_text
  42. self.insts = insts # list of instructions each test case cover
  43. self.warp_version = -1 # known warp version that works
  44. self.shader_arguments = ""
  45. for k,v in kwargs.items():
  46. setattr(self, k, v)
  47. # Wrapper for each DXIL instruction
  48. class inst_node(object):
  49. def __init__(self, inst):
  50. self.inst = inst
  51. self.test_cases = [] # list of test_case
  52. def add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  53. input_lists, output_lists, shader_target, shader_text, **kwargs):
  54. insts = []
  55. for inst_name in inst_names:
  56. assert (inst_name in g_instruction_nodes)
  57. insts += [g_instruction_nodes[inst_name].inst]
  58. case = test_case(test_name, insts, validation_type,
  59. validation_tolerance, input_lists, output_lists,
  60. shader_target, shader_text, **kwargs)
  61. g_test_cases[test_name] = case
  62. # update instruction nodes
  63. for inst_name in inst_names:
  64. g_instruction_nodes[inst_name].test_cases += [case]
  65. def add_test_case_int(test_name, inst_names, validation_type, validation_tolerance,
  66. input_lists, output_lists, shader_key, shader_op_name, **kwargs):
  67. add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  68. input_lists, output_lists, "cs_6_0", get_shader_text(shader_key, shader_op_name), **kwargs)
  69. input_lists_16, output_lists_16 = input_lists, output_lists
  70. if "input_16" in kwargs:
  71. input_lists_16 = kwargs["input_16"]
  72. if "output_16" in kwargs:
  73. output_lists_16 = kwargs["output_16"]
  74. add_test_case(test_name + "Bit16", inst_names, validation_type, validation_tolerance,
  75. input_lists_16, output_lists_16, "cs_6_2", get_shader_text(shader_key.replace("int","int16_t"), shader_op_name),
  76. shader_arguments="-enable-16bit-types", **kwargs)
  77. def add_test_case_float_half(test_name, inst_names, validation_type, validation_tolerance,
  78. float_input_lists, float_output_lists, shader_key, shader_op_name, **kwargs):
  79. add_test_case(test_name, inst_names, validation_type, validation_tolerance,
  80. float_input_lists, float_output_lists, "cs_6_0", get_shader_text(shader_key, shader_op_name), **kwargs)
  81. # if half test cases are different from float input lists, use those lists instead for half testings
  82. half_input_lists, half_output_lists, half_validation_type, half_validation_tolerance = float_input_lists, float_output_lists, validation_type, validation_tolerance
  83. if "half_inputs" in kwargs:
  84. half_input_lists = kwargs["half_inputs"]
  85. if "half_outputs" in kwargs:
  86. half_output_lists = kwargs["half_outputs"]
  87. if "half_validation_type" in kwargs:
  88. half_validation_type = kwargs["half_validation_type"]
  89. if "half_validation_tolerance" in kwargs:
  90. half_validation_tolerance = kwargs["half_validation_tolerance"]
  91. # skip relative error test check for half for now
  92. if validation_type != "Relative":
  93. add_test_case(test_name + "Half", inst_names, half_validation_type, half_validation_tolerance,
  94. half_input_lists, half_output_lists, "cs_6_2",
  95. get_shader_text(shader_key.replace("float","half"), shader_op_name), shader_arguments="-enable-16bit-types", **kwargs)
  96. def add_test_case_denorm(test_name, inst_names, validation_type, validation_tolerance, input_lists,
  97. output_lists_ftz, output_lists_preserve, shader_target, shader_text, **kwargs):
  98. add_test_case(test_name + "FTZ", inst_names, validation_type, validation_tolerance, input_lists,
  99. output_lists_ftz, shader_target, shader_text, shader_arguments="-denorm ftz")
  100. add_test_case(test_name + "Preserve", inst_names, validation_type, validation_tolerance, input_lists,
  101. output_lists_preserve, shader_target, shader_text, shader_arguments="-denorm preserve")
  102. # we can expect the same output for "any" and "preserve" mode. We should make sure that for validation zero are accepted outputs for denormal outputs.
  103. add_test_case(test_name + "Any", inst_names, validation_type, validation_tolerance, input_lists,
  104. output_lists_preserve + output_lists_ftz, shader_target, shader_text, shader_arguments="-denorm any")
  105. g_shader_texts = {
  106. "unary int": ''' struct SUnaryIntOp {
  107. int input;
  108. int output;
  109. };
  110. RWStructuredBuffer<SUnaryIntOp> g_buf : register(u0);
  111. [numthreads(8,8,1)]
  112. void main(uint GI : SV_GroupIndex) {
  113. SUnaryIntOp l = g_buf[GI];
  114. l.output = %s(l.input);
  115. g_buf[GI] = l;
  116. };''',
  117. "unary int16_t": ''' struct SUnaryInt16Op {
  118. int16_t input;
  119. int16_t output;
  120. };
  121. RWStructuredBuffer<SUnaryInt16Op> g_buf : register(u0);
  122. [numthreads(8,8,1)]
  123. void main(uint GI : SV_GroupIndex) {
  124. SUnaryInt16Op l = g_buf[GI];
  125. l.output = %s(l.input);
  126. g_buf[GI] = l;
  127. };''',
  128. "unary uint": ''' struct SUnaryUintOp {
  129. uint input;
  130. uint output;
  131. };
  132. RWStructuredBuffer<SUnaryUintOp> g_buf : register(u0);
  133. [numthreads(8,8,1)]
  134. void main(uint GI : SV_GroupIndex) {
  135. SUnaryUintOp l = g_buf[GI];
  136. l.output = %s(l.input);
  137. g_buf[GI] = l;
  138. };''',
  139. "unary uint16_t": ''' struct SUnaryUint16Op {
  140. uint16_t input;
  141. uint16_t output;
  142. };
  143. RWStructuredBuffer<SUnaryUint16Op> g_buf : register(u0);
  144. [numthreads(8,8,1)]
  145. void main(uint GI : SV_GroupIndex) {
  146. SUnaryUint16Op l = g_buf[GI];
  147. l.output = %s(l.input);
  148. g_buf[GI] = l;
  149. };''',
  150. "unary float": ''' struct SUnaryFPOp {
  151. float input;
  152. float output;
  153. };
  154. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  155. [numthreads(8,8,1)]
  156. void main(uint GI : SV_GroupIndex) {
  157. SUnaryFPOp l = g_buf[GI];
  158. l.output = %s(l.input);
  159. g_buf[GI] = l;
  160. };''',
  161. "unary float bool": ''' struct SUnaryFPOp {
  162. float input;
  163. float output;
  164. };
  165. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  166. [numthreads(8,8,1)]
  167. void main(uint GI : SV_GroupIndex) {
  168. SUnaryFPOp l = g_buf[GI];
  169. if (%s(l.input))
  170. l.output = 1;
  171. else
  172. l.output = 0;
  173. g_buf[GI] = l;
  174. };''',
  175. "unary half": ''' struct SUnaryFPOp {
  176. float16_t input;
  177. float16_t output;
  178. };
  179. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  180. [numthreads(8,8,1)]
  181. void main(uint GI : SV_GroupIndex) {
  182. SUnaryFPOp l = g_buf[GI];
  183. l.output = %s(l.input);
  184. g_buf[GI] = l;
  185. };''',
  186. "unary half bool": ''' struct SUnaryFPOp {
  187. float16_t input;
  188. float16_t output;
  189. };
  190. RWStructuredBuffer<SUnaryFPOp> g_buf : register(u0);
  191. [numthreads(8,8,1)]
  192. void main(uint GI : SV_GroupIndex) {
  193. SUnaryFPOp l = g_buf[GI];
  194. if (%s(l.input))
  195. l.output = 1;
  196. else
  197. l.output = 0;
  198. g_buf[GI] = l;
  199. };''',
  200. "binary int": ''' struct SBinaryIntOp {
  201. int input1;
  202. int input2;
  203. int output1;
  204. int output2;
  205. };
  206. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  207. [numthreads(8,8,1)]
  208. void main(uint GI : SV_GroupIndex) {
  209. SBinaryIntOp l = g_buf[GI];
  210. l.output1 = l.input1 %s l.input2;
  211. g_buf[GI] = l;
  212. };''',
  213. "binary int16_t": ''' struct SBinaryInt16Op {
  214. int16_t input1;
  215. int16_t input2;
  216. int16_t output1;
  217. int16_t output2;
  218. };
  219. RWStructuredBuffer<SBinaryInt16Op> g_buf : register(u0);
  220. [numthreads(8,8,1)]
  221. void main(uint GI : SV_GroupIndex) {
  222. SBinaryInt16Op l = g_buf[GI];
  223. l.output1 = l.input1 %s l.input2;
  224. g_buf[GI] = l;
  225. };''',
  226. "binary int call": ''' struct SBinaryIntOp {
  227. int input1;
  228. int input2;
  229. int output1;
  230. int output2;
  231. };
  232. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  233. [numthreads(8,8,1)]
  234. void main(uint GI : SV_GroupIndex) {
  235. SBinaryIntOp l = g_buf[GI];
  236. l.output1 = %s(l.input1,l.input2);
  237. g_buf[GI] = l;
  238. };''',
  239. "binary int16_t call": ''' struct SBinaryInt16Op {
  240. int16_t input1;
  241. int16_t input2;
  242. int16_t output1;
  243. int16_t output2;
  244. };
  245. RWStructuredBuffer<SBinaryInt16Op> g_buf : register(u0);
  246. [numthreads(8,8,1)]
  247. void main(uint GI : SV_GroupIndex) {
  248. SBinaryInt16Op l = g_buf[GI];
  249. l.output1 = %s(l.input1,l.input2);
  250. g_buf[GI] = l;
  251. };''',
  252. "binary uint": ''' struct SBinaryUintOp {
  253. uint input1;
  254. uint input2;
  255. uint output1;
  256. uint output2;
  257. };
  258. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  259. [numthreads(8,8,1)]
  260. void main(uint GI : SV_GroupIndex) {
  261. SBinaryUintOp l = g_buf[GI];
  262. l.output1 = l.input1 %s l.input2;
  263. g_buf[GI] = l;
  264. };''',
  265. "binary uint16_t": ''' struct SBinaryUint16Op {
  266. uint16_t input1;
  267. uint16_t input2;
  268. uint16_t output1;
  269. uint16_t output2;
  270. };
  271. RWStructuredBuffer<SBinaryUint16Op> g_buf : register(u0);
  272. [numthreads(8,8,1)]
  273. void main(uint GI : SV_GroupIndex) {
  274. SBinaryUint16Op l = g_buf[GI];
  275. l.output1 = l.input1 %s l.input2;
  276. g_buf[GI] = l;
  277. };''',
  278. "binary uint call": ''' struct SBinaryUintOp {
  279. uint input1;
  280. uint input2;
  281. uint output1;
  282. uint output2;
  283. };
  284. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  285. [numthreads(8,8,1)]
  286. void main(uint GI : SV_GroupIndex) {
  287. SBinaryUintOp l = g_buf[GI];
  288. l.output1 = %s(l.input1,l.input2);
  289. g_buf[GI] = l;
  290. };''',
  291. "binary uint16_t call": ''' struct SBinaryUint16Op {
  292. uint16_t input1;
  293. uint16_t input2;
  294. uint16_t output1;
  295. uint16_t output2;
  296. };
  297. RWStructuredBuffer<SBinaryUint16Op> g_buf : register(u0);
  298. [numthreads(8,8,1)]
  299. void main(uint GI : SV_GroupIndex) {
  300. SBinaryUint16Op l = g_buf[GI];
  301. l.output1 = %s(l.input1,l.input2);
  302. g_buf[GI] = l;
  303. };''',
  304. "binary float": ''' struct SBinaryFPOp {
  305. float input1;
  306. float input2;
  307. float output1;
  308. float output2;
  309. };
  310. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  311. [numthreads(8,8,1)]
  312. void main(uint GI : SV_GroupIndex) {
  313. SBinaryFPOp l = g_buf[GI];
  314. l.output1 = l.input1 %s l.input2;
  315. g_buf[GI] = l;
  316. };''',
  317. "binary float call": ''' struct SBinaryFPOp {
  318. float input1;
  319. float input2;
  320. float output1;
  321. float output2;
  322. };
  323. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  324. [numthreads(8,8,1)]
  325. void main(uint GI : SV_GroupIndex) {
  326. SBinaryFPOp l = g_buf[GI];
  327. l.output1 = %s(l.input1,l.input2);
  328. g_buf[GI] = l;
  329. };''',
  330. "binary half": ''' struct SBinaryFPOp {
  331. half input1;
  332. half input2;
  333. half output1;
  334. half output2;
  335. };
  336. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  337. [numthreads(8,8,1)]
  338. void main(uint GI : SV_GroupIndex) {
  339. SBinaryFPOp l = g_buf[GI];
  340. l.output1 = l.input1 %s l.input2;
  341. g_buf[GI] = l;
  342. };''',
  343. "binary half call": ''' struct SBinaryFPOp {
  344. half input1;
  345. half input2;
  346. half output1;
  347. half output2;
  348. };
  349. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  350. [numthreads(8,8,1)]
  351. void main(uint GI : SV_GroupIndex) {
  352. SBinaryFPOp l = g_buf[GI];
  353. l.output1 = %s(l.input1,l.input2);
  354. g_buf[GI] = l;
  355. };''',
  356. "tertiary int": ''' struct STertiaryIntOp {
  357. int input1;
  358. int input2;
  359. int input3;
  360. int output;
  361. };
  362. RWStructuredBuffer<STertiaryIntOp> g_buf : register(u0);
  363. [numthreads(8,8,1)]
  364. void main(uint GI : SV_GroupIndex) {
  365. STertiaryIntOp l = g_buf[GI];
  366. l.output = %s(l.input1, l.input2, l.input3);
  367. g_buf[GI] = l;
  368. };''',
  369. "tertiary int16_t": ''' struct STertiaryInt16Op {
  370. int16_t input1;
  371. int16_t input2;
  372. int16_t input3;
  373. int16_t output;
  374. };
  375. RWStructuredBuffer<STertiaryInt16Op> g_buf : register(u0);
  376. [numthreads(8,8,1)]
  377. void main(uint GI : SV_GroupIndex) {
  378. STertiaryInt16Op l = g_buf[GI];
  379. l.output = %s(l.input1, l.input2, l.input3);
  380. g_buf[GI] = l;
  381. };''',
  382. "tertiary uint": ''' struct STertiaryUintOp {
  383. uint input1;
  384. uint input2;
  385. uint input3;
  386. uint output;
  387. };
  388. RWStructuredBuffer<STertiaryUintOp> g_buf : register(u0);
  389. [numthreads(8,8,1)]
  390. void main(uint GI : SV_GroupIndex) {
  391. STertiaryUintOp l = g_buf[GI];
  392. l.output = %s(l.input1, l.input2, l.input3);
  393. g_buf[GI] = l;
  394. };''',
  395. "tertiary uint16_t": ''' struct STertiaryUint16Op {
  396. uint16_t input1;
  397. uint16_t input2;
  398. uint16_t input3;
  399. uint16_t output;
  400. };
  401. RWStructuredBuffer<STertiaryUint16Op> g_buf : register(u0);
  402. [numthreads(8,8,1)]
  403. void main(uint GI : SV_GroupIndex) {
  404. STertiaryUint16Op l = g_buf[GI];
  405. l.output = %s(l.input1, l.input2, l.input3);
  406. g_buf[GI] = l;
  407. };''',
  408. "tertiary float": ''' struct STertiaryFloatOp {
  409. float input1;
  410. float input2;
  411. float input3;
  412. float output;
  413. };
  414. RWStructuredBuffer<STertiaryFloatOp> g_buf : register(u0);
  415. [numthreads(8,8,1)]
  416. void main(uint GI : SV_GroupIndex) {
  417. STertiaryFloatOp l = g_buf[GI];
  418. l.output = %s(l.input1, l.input2, l.input3);
  419. g_buf[GI] = l;
  420. };''',
  421. 'tertiary half': ''' struct STertiaryHalfOp {
  422. half input1;
  423. half input2;
  424. half input3;
  425. half output;
  426. };
  427. RWStructuredBuffer<STertiaryHalfOp> g_buf : register(u0);
  428. [numthreads(8,8,1)]
  429. void main(uint GI : SV_GroupIndex) {
  430. STertiaryHalfOp l = g_buf[GI];
  431. l.output = %s(l.input1, l.input2, l.input3);
  432. g_buf[GI] = l;
  433. };''',
  434. "wave op int" :''' struct PerThreadData {
  435. uint firstLaneId;
  436. uint laneIndex;
  437. int mask;
  438. int input;
  439. int output;
  440. };
  441. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  442. [numthreads(8,12,1)]
  443. void main(uint GI : SV_GroupIndex) {
  444. PerThreadData pts = g_sb[GI];
  445. pts.firstLaneId = WaveReadLaneFirst(GI);
  446. pts.laneIndex = WaveGetLaneIndex();
  447. if (pts.mask != 0) {
  448. pts.output = %s(pts.input);
  449. }
  450. else {
  451. pts.output = %s(pts.input);
  452. }
  453. g_sb[GI] = pts;
  454. };''',
  455. "wave op uint" :''' struct PerThreadData {
  456. uint firstLaneId;
  457. uint laneIndex;
  458. int mask;
  459. uint input;
  460. uint output;
  461. };
  462. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  463. [numthreads(8,12,1)]
  464. void main(uint GI : SV_GroupIndex) {
  465. PerThreadData pts = g_sb[GI];
  466. pts.firstLaneId = WaveReadLaneFirst(GI);
  467. pts.laneIndex = WaveGetLaneIndex();
  468. if (pts.mask != 0) {
  469. pts.output = %s(pts.input);
  470. }
  471. else {
  472. pts.output = %s(pts.input);
  473. }
  474. g_sb[GI] = pts;
  475. };''',
  476. "wave op int count": ''' struct PerThreadData {
  477. uint firstLaneId;
  478. uint laneIndex;
  479. int mask;
  480. int input;
  481. int output;
  482. };
  483. RWStructuredBuffer<PerThreadData> g_sb : register(u0);
  484. [numthreads(8,12,1)]
  485. void main(uint GI : SV_GroupIndex) {
  486. PerThreadData pts = g_sb[GI];
  487. pts.firstLaneId = WaveReadLaneFirst(GI);
  488. pts.laneIndex = WaveGetLaneIndex();
  489. if (pts.mask != 0) {
  490. pts.output = %s(pts.input > 3);
  491. }
  492. else {
  493. pts.output = %s(pts.input > 3);
  494. }
  495. g_sb[GI] = pts;
  496. };''',
  497. "wave op multi prefix int": ''' struct ThreadData {
  498. uint key;
  499. uint firstLaneId;
  500. uint laneId;
  501. uint mask;
  502. int value;
  503. int result;
  504. };
  505. RWStructuredBuffer<ThreadData> g_buffer : register(u0);
  506. [numthreads(8, 12, 1)]
  507. void main(uint id : SV_GroupIndex) {
  508. ThreadData data = g_buffer[id];
  509. data.firstLaneId = WaveReadLaneFirst(id);
  510. data.laneId = WaveGetLaneIndex();
  511. if (data.mask != 0) {
  512. uint4 mask = WaveMatch(data.key);
  513. data.result = %s(data.value, mask);
  514. } else {
  515. uint4 mask = WaveMatch(data.key);
  516. data.result = %s(data.value, mask);
  517. }
  518. g_buffer[id] = data;
  519. }''',
  520. "wave op multi prefix uint": ''' struct ThreadData {
  521. uint key;
  522. uint firstLaneId;
  523. uint laneId;
  524. uint mask;
  525. uint value;
  526. uint result;
  527. };
  528. RWStructuredBuffer<ThreadData> g_buffer : register(u0);
  529. [numthreads(8, 12, 1)]
  530. void main(uint id : SV_GroupIndex) {
  531. ThreadData data = g_buffer[id];
  532. data.firstLaneId = WaveReadLaneFirst(id);
  533. data.laneId = WaveGetLaneIndex();
  534. if (data.mask != 0) {
  535. uint4 mask = WaveMatch(data.key);
  536. data.result = %s(data.value, mask);
  537. } else {
  538. uint4 mask = WaveMatch(data.key);
  539. data.result = %s(data.value, mask);
  540. }
  541. g_buffer[id] = data;
  542. }'''
  543. }
  544. def get_shader_text(op_type, op_call):
  545. assert(op_type in g_shader_texts)
  546. if op_type.startswith("wave op"):
  547. return g_shader_texts[op_type] % (op_call, op_call)
  548. return g_shader_texts[op_type] % (op_call)
  549. g_denorm_tests = ["FAddDenormAny", "FAddDenormFTZ", "FAddDenormPreserve",
  550. "FSubDenormAny", "FSubDenormFTZ", "FSubDenormPreserve",
  551. "FMulDenormAny", "FMulDenormFTZ", "FMulDenormPreserve",
  552. "FDivDenormAny", "FDivDenormFTZ", "FDivDenormPreserve",
  553. "FMadDenormAny", "FMadDenormFTZ", "FMadDenormPreserve",
  554. "FAbsDenormAny", "FAbsDenormFTZ", "FAbsDenormPreserve",
  555. "FMinDenormAny", "FMinDenormFTZ", "FMinDenormPreserve",
  556. "FMaxDenormAny", "FMaxDenormFTZ", "FMaxDenormPreserve"]
# This is a collection of driver test cases, grouped per instruction.
# Warning: when a test case needs a signed 32-bit integer, pass negative
# numbers as decimal values rather than in hexadecimal representation.
# For some reason, TAEF does not handle the hexadecimal form properly.
# For half values, hex is preferable: the test framework reads a decimal string
# as a float and converts it to float16, possibly losing precision, whereas
# hex values are read as they are.
  563. def add_test_cases():
  564. nan = float('nan')
  565. p_inf = float('inf')
  566. n_inf = float('-inf')
  567. p_denorm = float('1e-38')
  568. n_denorm = float('-1e-38')
  569. # Unary Float
  570. add_test_case_float_half('Sin', ['Sin'], 'Epsilon', 0.0008, [[
  571. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  572. '314.16'
  573. ]], [[
  574. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN', '-0.0007346401',
  575. '0.0007346401'
  576. ]], "unary float", "sin", half_validation_tolerance=0.003, half_inputs=[[
  577. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf',
  578. '0.6279297', '1.255859', '1.884766', '2.511719', '3.140625',
  579. '3.769531', '4.398438', '5.023438', '5.652344', '6.281250'
  580. ]], half_outputs=[[
  581. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN',
  582. '0.58747065', '0.95081574', '0.95111507', '0.58904284', '0.00096773',
  583. '-0.58747751', '-0.95112079', '-0.95201313', '-0.58982444', '-0.00193545'
  584. ]])
  585. add_test_case_float_half('Cos', ['Cos'], 'Epsilon', 0.0008, [[
  586. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  587. '314.16'
  588. ]], [[
  589. 'NaN', 'NaN', '1.0', '1.0', '1.0', '1.0', 'NaN', '0.99999973015',
  590. '0.99999973015'
  591. ]], "unary float", "cos", half_validation_tolerance=0.003, half_inputs=[[
  592. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf',
  593. '0.6279297', '1.255859', '1.884766', '2.511719', '3.140625',
  594. '3.769531', '4.398438', '5.023438', '5.652344', '6.281250'
  595. ]], half_outputs=[[
  596. 'NaN', 'NaN', '1.0', '1.0', '1.0', '1.0', 'NaN',
  597. '0.80924553', '0.30975693', '-0.30883664', '-0.80810183', '-0.99999952',
  598. '-0.80924052', '-0.30881903', '0.30605716', '0.80753154', '0.99999809'
  599. ]])
  600. add_test_case_float_half('Tan', ['Tan'], 'Epsilon', 0.0008, [[
  601. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-314.16',
  602. '314.16'
  603. ]], [[
  604. 'NaN', 'NaN', '-0.0', '-0.0', '0.0', '0.0', 'NaN', '-0.000735',
  605. '0.000735'
  606. ]], "unary float", "tan", half_validation_tolerance=0.016, half_inputs=[[
  607. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf',
  608. '0.6279297', '1.255859', '1.884766', '2.511719', '3.140625',
  609. '3.769531', '4.398438', '5.652344', '6.281250'
  610. ]], half_outputs=[[
  611. 'NaN', 'NaN', '-0', '-0', '0', '0', 'NaN',
  612. '0.72594857', '3.06955433', '-3.07967043', '-0.72892153', '-0.00096773',
  613. '0.72596157', '3.07986474', '-0.7304042', '-0.00193546'
  614. ]])
  615. add_test_case_float_half('Hcos', ['Hcos'], 'Epsilon', 0.0008,
  616. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  617. 'NaN', 'Inf', '1.0', '1.0', '1.0', '1.0', 'Inf', '1.543081',
  618. '1.543081'
  619. ]], "unary float", "cosh", half_validation_type='ulp', half_validation_tolerance=2)
  620. add_test_case_float_half('Hsin', ['Hsin'], 'Epsilon', 0.0008,
  621. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  622. 'NaN', '-Inf', '0.0', '0.0', '0.0', '0.0', 'Inf', '1.175201',
  623. '-1.175201'
  624. ]], "unary float", "sinh")
  625. add_test_case_float_half('Htan', ['Htan'], 'Epsilon', 0.0008,
  626. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  627. 'NaN', '-1', '-0.0', '-0.0', '0.0', '0.0', '1', '0.761594',
  628. '-0.761594'
  629. ]], "unary float", "tanh", warp_version=16202)
  630. add_test_case_float_half('Acos', ['Acos'], 'Epsilon', 0.0008, [[
  631. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1', '1.5',
  632. '-1.5'
  633. ]], [[
  634. 'NaN', 'NaN', '1.570796', '1.570796', '1.570796', '1.570796', 'NaN',
  635. '0', '3.1415926', 'NaN', 'NaN'
  636. ]], "unary float", "acos")
  637. add_test_case_float_half('Asin', ['Asin'], 'Epsilon', 0.0008, [[
  638. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1', '1.5',
  639. '-1.5'
  640. ]], [[
  641. 'NaN', 'NaN', '0.0', '0.0', '0.0', '0.0', 'NaN', '1.570796',
  642. '-1.570796', 'NaN', 'NaN'
  643. ]], "unary float", "asin")
  644. add_test_case_float_half('Atan', ['Atan'], 'Epsilon', 0.0008,
  645. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1', '-1']], [[
  646. 'NaN', '-1.570796', '0.0', '0.0', '0.0', '0.0', '1.570796',
  647. '0.785398163', '-0.785398163'
  648. ]], "unary float", "atan", warp_version=16202)
  649. add_test_case_float_half('Exp', ['Exp'], 'Relative', 21,
  650. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '10']],
  651. [['NaN', '0', '1', '1', '1', '1', 'Inf', '0.367879441', '22026.46579']
  652. ], "unary float", "exp")
  653. add_test_case_float_half('Frc', ['Frc'], 'Epsilon', 0.0008, [[
  654. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '2.718280',
  655. '1000.599976', '-7.389'
  656. ]], [[
  657. 'NaN', 'NaN', '0', '0', '0', '0', 'NaN', '0', '0.718280', '0.599976',
  658. '0.611'
  659. ]], "unary float", "frac",
  660. half_inputs=[['NaN', '-Inf', '0x03FF', '-0', '0', 'Inf', '-1', '2.719',
  661. '1000.5', '0xC764']],
  662. half_outputs=[[
  663. 'NaN', 'NaN', '0x03FF', '0', '0', 'NaN', '0', '0.719', '0.5',
  664. '0x38E1']])
  665. add_test_case_float_half('Log', ['Log'], 'Relative', 21, [[
  666. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1',
  667. '2.718281828', '7.389056', '100'
  668. ]], [[
  669. 'NaN', 'NaN', '-Inf', '-Inf', '-Inf', '-Inf', 'Inf', 'NaN', '1.0',
  670. '1.99999998', '4.6051701'
  671. ]],"unary float", "log", half_inputs=[[
  672. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1',
  673. '2.719', '7.39', '100'
  674. ]], half_outputs=[[
  675. 'NaN', 'NaN', '-Inf', '-Inf', '-Inf', '-Inf', 'Inf', 'NaN', '1.0',
  676. '2', '4.605'
  677. ]])
  678. add_test_case_float_half('Sqrt', ['Sqrt'], 'ulp', 1, [[
  679. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '2',
  680. '16.0', '256.0'
  681. ]], [[
  682. 'NaN', 'NaN', '-0', '-0', '0', '0', 'Inf', 'NaN', '1.41421356237',
  683. '4.0', '16.0'
  684. ]], "unary float", "sqrt",
  685. half_inputs=[['NaN', '-Inf', '-denorm', '-0', '0', '0x03FF', 'Inf', '-1', '2', '16.0', '256.0']],
  686. half_outputs=[['NaN', 'NaN', 'NaN', '-0', '0', '0x1FFF', 'Inf', 'NaN', '1.41421', '4.0', '16.0']])
  687. add_test_case_float_half('Rsqrt', ['Rsqrt'], 'ulp', 1, [[
  688. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '-1', '16.0',
  689. '256.0', '65536.0'
  690. ]], [[
  691. 'NaN', 'NaN', '-Inf', '-Inf', 'Inf', 'Inf', '0', 'NaN', '0.25',
  692. '0.0625', '0.00390625'
  693. ]], "unary float", "rsqrt", half_inputs=[[
  694. 'NaN', '-Inf', '-denorm', '-0', '0', '0x03FF', 'Inf', '-1', '16.0',
  695. '256.0', '0x7bff'
  696. ]], half_outputs=[[
  697. 'NaN', 'NaN', 'NaN', '-Inf', 'Inf', '0x5801', '0', 'NaN', '0.25',
  698. '0.0625', '0x1C00'
  699. ]])
  700. add_test_case_float_half('Round_ne', ['Round_ne'], 'Epsilon', 0, [[
  701. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  702. '10.5', '10.6', '11.5', '-10.0', '-10.4', '-10.5', '-10.6'
  703. ]], [[
  704. 'NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  705. '11.0', '12.0', '-10.0', '-10.0', '-10.0', '-11.0'
  706. ]], "unary float", "round")
  707. add_test_case_float_half('Round_ni', ['Round_ni'], 'Epsilon', 0, [[
  708. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  709. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6'
  710. ]], [[
  711. 'NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  712. '10.0', '-10.0', '-11.0', '-11.0', '-11.0'
  713. ]], "unary float", "floor", half_inputs=[[
  714. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  715. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6'
  716. ]], half_outputs=[[
  717. 'NaN', '-Inf', '-1', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  718. '10.0', '-10.0', '-11.0', '-11.0', '-11.0'
  719. ]])
  720. add_test_case_float_half('Round_pi', ['Round_pi'], 'Epsilon', 0,
  721. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  722. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6']],
  723. [['NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '11.0', '11.0',
  724. '11.0', '-10.0', '-10.0', '-10.0', '-10.0']], "unary float", "ceil",
  725. half_inputs=[['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  726. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6']],
  727. half_outputs=[['NaN', '-Inf', '-0', '-0', '0', '1', 'Inf', '10.0', '11.0', '11.0',
  728. '11.0', '-10.0', '-10.0', '-10.0', '-10.0']])
  729. add_test_case_float_half('Round_z', ['Round_z'], 'Epsilon', 0,
  730. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '10.0', '10.4',
  731. '10.5', '10.6', '-10.0', '-10.4', '-10.5', '-10.6']],
  732. [['NaN', '-Inf', '-0', '-0', '0', '0', 'Inf', '10.0', '10.0', '10.0',
  733. '10.0', '-10.0', '-10.0', '-10.0', '-10.0']], "unary float", "trunc")
  734. add_test_case_float_half('IsNaN', ['IsNaN'], 'Epsilon', 0,
  735. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  736. ], [['1', '0', '0', '0', '0', '0', '0', '0', '0']], "unary float bool", "isnan")
  737. add_test_case_float_half('IsInf', ['IsInf'], 'Epsilon', 0,
  738. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  739. ], [['0', '1', '0', '0', '0', '0', '1', '0', '0']], "unary float bool", "isinf")
  740. add_test_case_float_half('IsFinite', ['IsFinite'], 'Epsilon', 0,
  741. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  742. ], [['0', '0', '1', '1', '1', '1', '0', '1', '1']], "unary float bool", "isfinite", warp_version=16202)
  743. add_test_case_float_half('FAbs', ['FAbs'], 'Epsilon', 0,
  744. [['NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0']
  745. ], [['NaN', 'Inf', 'denorm', '0', '0', 'denorm', 'Inf', '1', '1']], "unary float", "abs")
  746. # Binary Float
  747. add_test_case('FMin', ['FMin','FMax'], 'epsilon', 0, [[
  748. '-inf', '-inf', '-inf', '-inf', 'inf', 'inf', 'inf', 'inf', 'NaN',
  749. 'NaN', 'NaN', 'NaN', '1.0', '1.0', '-1.0', '-1.0', '1.0'
  750. ], [
  751. '-inf', 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-inf',
  752. 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-1.0'
  753. ]], [[
  754. '-inf', '-inf', '-inf', '-inf', '-inf', 'inf', '1.0', 'inf', '-inf',
  755. 'inf', '1.0', 'NaN', '-inf', '1.0', '-1.0', '-1.0', '-1.0'
  756. ], [
  757. '-inf', 'inf', '1.0', '-inf', 'inf', 'inf', 'inf', 'inf', '-inf',
  758. 'inf', '1.0', 'NaN', '1.0', 'inf', '1.0', '-1.0', '1.0'
  759. ]], 'cs_6_0', ''' struct SBinaryFPOp {
  760. float input1;
  761. float input2;
  762. float output1;
  763. float output2;
  764. };
  765. RWStructuredBuffer<SBinaryFPOp> g_buf : register(u0);
  766. [numthreads(8,8,1)]
  767. void main(uint GI : SV_GroupIndex) {
  768. SBinaryFPOp l = g_buf[GI];
  769. l.output1 = min(l.input1, l.input2);
  770. l.output2 = max(l.input1, l.input2);
  771. g_buf[GI] = l;
  772. };''')
  773. add_test_case('FMinHalf', ['FMin','FMax'], 'epsilon', 0, [[
  774. '-inf', '-inf', '-inf', '-inf', 'inf', 'inf', 'inf', 'inf', 'NaN',
  775. 'NaN', 'NaN', 'NaN', '1.0', '1.0', '-1.0', '-1.0', '1.0'
  776. ], [
  777. '-inf', 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-inf',
  778. 'inf', '1.0', 'NaN', '-inf', 'inf', '1.0', 'NaN', '-1.0'
  779. ]], [[
  780. '-inf', '-inf', '-inf', '-inf', '-inf', 'inf', '1.0', 'inf', '-inf',
  781. 'inf', '1.0', 'NaN', '-inf', '1.0', '-1.0', '-1.0', '-1.0'
  782. ], [
  783. '-inf', 'inf', '1.0', '-inf', 'inf', 'inf', 'inf', 'inf', '-inf',
  784. 'inf', '1.0', 'NaN', '1.0', 'inf', '1.0', '-1.0', '1.0'
  785. ]], 'cs_6_2', ''' struct SBinaryHalfOp {
  786. half input1;
  787. half input2;
  788. half output1;
  789. half output2;
  790. };
  791. RWStructuredBuffer<SBinaryHalfOp> g_buf : register(u0);
  792. [numthreads(8,8,1)]
  793. void main(uint GI : SV_GroupIndex) {
  794. SBinaryHalfOp l = g_buf[GI];
  795. l.output1 = min(l.input1, l.input2);
  796. l.output2 = max(l.input1, l.input2);
  797. g_buf[GI] = l;
  798. };''', shader_arguments="-enable-16bit-types")
  799. add_test_case_float_half('FAdd', ['FAdd'], 'ulp', 1, [['-1.0', '1.0', '32.5', '1.0000001000'],['4', '5.5', '334.7', '0.5000001000']], [['3.0', '6.5', '367.2', '1.5000002000']],
  800. "binary float", "+")
  801. add_test_case_float_half('FSub', ['FSub'], 'ulp', 1, [['-1.0', '5.5', '32.5', '1.0000001000'],['4', '1.25', '334.7', '0.5000001000']], [['-5', '4.25', '-302.2', '0.5000']],
  802. "binary float", "-")
  803. add_test_case_float_half('FMul', ['FMul'], 'ulp', 1, [['-1.0', '5.5', '1.0000001'],['4', '1.25', '2.0']], [['-4.0', '6.875', '2.0000002']],
  804. "binary float", "*")
  805. add_test_case_float_half('FDiv', ['FDiv'], 'ulp', 1, [['-1.0', '5.5', '1.0000001'],['4', '1.25', '2.0']], [['-0.25', '4.4', '0.50000006']],
  806. "binary float", "/")
  807. # Denorm Binary Float
  808. add_test_case_denorm('FAddDenorm', ['FAdd'], 'ulp', 1,
  809. [['0x007E0000', '0x00200000', '0x007E0000', '0x007E0000'],['0x007E0000','0x00200000', '0x807E0000', '0x800E0000']],
  810. [['0','0', '0', '0']],
  811. [['0x00FC0000','0x00400000', '0', '0x00700000']],
  812. 'cs_6_2', get_shader_text("binary float", "+"))
  813. add_test_case_denorm('FSubDenorm', ['FSub'], 'ulp', 1,
  814. [['0x007E0000', '0x007F0000', '0x00FF0000', '0x007A0000'],['0x007E0000', '0x807F0000', '0x00800000', '0']],
  815. [['0x0', '0', '0', '0']],
  816. [['0x0', '0x00FE0000', '0x007F0000', '0x007A0000']],
  817. 'cs_6_2', get_shader_text("binary float", "-"))
  818. add_test_case_denorm('FDivDenorm', ['FDiv'], 'ulp', 1,
  819. [['0x007F0000', '0x807F0000', '0x20000000', '0x00800000'],['1', '4', '0x607F0000', '0x40000000']],
  820. [['0', '0', '0', '0']],
  821. [['0x007F0000', '0x801FC000', '0x00101010', '0x00400000']],
  822. 'cs_6_2', get_shader_text("binary float", "/"))
  823. add_test_case_denorm('FMulDenorm', ['FMul'], 'ulp', 1,
  824. [['0x00000300', '0x007F0000', '0x007F0000', '0x001E0000', '0x00000300'],['128', '1', '0x007F0000', '20', '0x78000000']],
  825. [['0', '0', '0', '0', '0']],
  826. [['0x00018000','0x007F0000', '0', '0x01960000', '0x32400000']],
  827. 'cs_6_2', get_shader_text("binary float", "*"))
  828. # Tertiary Float
  829. add_test_case_float_half('FMad', ['FMad'], 'ulp', 1, [[
  830. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  831. '0', '1', '1.5'
  832. ], [
  833. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  834. '0', '1', '10'
  835. ], [
  836. 'NaN', '-Inf', '-denorm', '-0', '0', 'denorm', 'Inf', '1.0', '-1.0',
  837. '1', '0', '-5.5'
  838. ]], [['NaN', 'NaN', '0', '0', '0', '0', 'Inf', '2', '0', '1', '1', '9.5']],
  839. "tertiary float", "mad",
  840. half_inputs=[[
  841. 'NaN', '-Inf', '0x03FF', '-0', '0', 'Inf', '1.0', '-1.0',
  842. '0', '1', '1.5'
  843. ], [
  844. 'NaN', '-Inf', '1', '-0', '0', 'Inf', '1.0', '-1.0',
  845. '0', '1', '10'
  846. ], [
  847. 'NaN', '-Inf', '0x03FF', '-0', '0', 'Inf', '1.0', '-1.0',
  848. '1', '0', '-5.5'
  849. ]],
  850. half_outputs=[['NaN', 'NaN', '0x07FE', '0', '0', 'Inf', '2', '0', '1', '1', '9.5']])
  851. # Denorm Tertiary Float
  852. add_test_case_denorm('FMadDenorm', ['FMad'], 'ulp', 1,
  853. [['0x80780000', '0x80780000', '0x00780000'],
  854. ['1', '2', '2'],
  855. ['0x80780000', '0x00800000', '0x00800000']],
  856. [['0', '0x00800000', '0x00800000']],
  857. [['0x80F00000', '0x80700000', '0x01380000']],
  858. 'cs_6_2', get_shader_text("tertiary float", "mad"))
  859. # Unary Int
  860. int8_min, int8_max = '-128', '127'
  861. int16_min, int16_max = '-32768', '32767'
  862. int32_min, int32_max = '-2147483648', '2147483647'
  863. uint16_max = '65535'
  864. uint32_max = '4294967295'
  865. add_test_case_int('Bfrev', ['Bfrev'], 'Epsilon', 0, [[
  866. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  867. int32_max
  868. ]], [[
  869. '1', '65535', '536870911', '-1', '0', int32_min, '268435456',
  870. '32768', '-2'
  871. ]], "unary int", "reversebits",
  872. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  873. output_16=[['1', '255', '8191', '-1', '0', int16_min, '4096', '128', '-2']])
  874. # firstbit_shi (s for signed) returns the
  875. # first 0 from the MSB if the number is negative,
  876. # else the first 1 from the MSB.
  877. # all the variants of the instruction return ~0 if no match was found
  878. add_test_case_int('FirstbitSHi', ['FirstbitSHi'], 'Epsilon', 0, [[
  879. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  880. int32_max
  881. ]], [['30', '15', '2', '-1', '-1', '0', '3', '16', '30']],
  882. "unary int", "firstbithigh",
  883. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  884. output_16=[['14', '7', '2', '-1', '-1', '0', '3', '8', '14']])
  885. add_test_case_int('FirstbitLo', ['FirstbitLo'], 'Epsilon', 0, [[
  886. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  887. int32_max
  888. ]], [['31', '16', '3', '0', '-1', '0', '3', '16', '0']],
  889. "unary int", "firstbitlow",
  890. input_16=[[int16_min, '-256', '-8', '-1', '0', '1', '8', '256', int16_max]],
  891. output_16=[['15', '8', '3', '0', '-1', '0', '3', '8', '0']])
  892. # TODO: there is a known bug in countbits when passing in immediate values.
  893. # Fix this later
  894. add_test_case('Countbits', ['Countbits'], 'Epsilon', 0, [[
  895. int32_min, '-65536', '-8', '-1', '0', '1', '8', '65536',
  896. int32_max
  897. ]], [['1', '16', '29', '32', '0', '1', '1', '1', '31']],
  898. "cs_6_0", get_shader_text("unary int", "countbits"))
  899. # Unary uint
  900. add_test_case_int('FirstbitHi', ['FirstbitHi'], 'Epsilon', 0,
  901. [['0', '1', '8', '65536', int32_max, uint32_max]],
  902. [['-1', '0', '3', '16', '30', '31']],
  903. "unary uint", "firstbithigh",
  904. input_16=[['0', '1', '8', uint16_max]],
  905. output_16=[['-1', '0', '3', '15']])
  906. # Binary Int
  907. add_test_case_int('IAdd', ['Add'], 'Epsilon', 0,
  908. [[int32_min, '-10', '0', '0', '10', int32_max, '486'],
  909. ['0', '10', '-10', '10', '10', '0', '54238']],
  910. [[int32_min, '0', '-10', '10', '20', int32_max, '54724']],
  911. "binary int", "+",
  912. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  913. ['0', '10', '-3114', '272', '15', '0']],
  914. output_16=[[int16_min, '0', '-3114', '272', '25', int16_max]])
  915. add_test_case_int('ISub', ['Sub'], 'Epsilon', 0,
  916. [[int32_min, '-10', '0', '0', '10', int32_max, '486'],
  917. ['0', '10', '-10', '10', '10', '0', '54238']],
  918. [[int32_min, '-20', '10', '-10', '0', int32_max, '-53752']],
  919. "binary int", "-",
  920. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  921. ['0', '10', '-3114', '272', '15', '0']],
  922. output_16=[[int16_min, '-20', '3114', '-272', '-5', int16_max]])
  923. add_test_case_int('IMax', ['IMax'], 'Epsilon', 0,
  924. [[int32_min, '-10', '0', '0', '10', int32_max],
  925. ['0', '10', '-10', '10', '10', '0']],
  926. [['0', '10', '0', '10', '10', int32_max]],
  927. "binary int call", "max",
  928. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  929. ['0', '10', '-3114', '272', '15', '0']],
  930. output_16=[['0', '10', '0', '272', '15', int16_max]])
  931. add_test_case_int('IMin', ['IMin'], 'Epsilon', 0,
  932. [[int32_min, '-10', '0', '0', '10', int32_max],
  933. ['0', '10', '-10', '10', '10', '0']],
  934. [[int32_min, '-10', '-10', '0', '10', '0']],
  935. "binary int call", "min",
  936. input_16=[[int16_min, '-10', '0', '0', '10', int16_max],
  937. ['0', '10', '-3114', '272', '15', '0']],
  938. output_16=[[int16_min, '-10', '-3114', '0', '10', '0']])
  939. add_test_case_int('IMul', ['Mul'], 'Epsilon', 0, [
  940. [ int32_min, '-10', '-1', '0', '1', '10', '10000', int32_max, int32_max ],
  941. ['-10', '-10', '10', '0', '256', '4', '10001', '0', int32_max]],
  942. [['0', '100', '-10', '0', '256', '40', '100010000', '0', '1']],
  943. "binary int", "*",
  944. input_16=[[ int16_min, '-10', '-1', '0', '1', '10', int16_max],
  945. ['-10', '-10', '10', '0', '256', '4', '0']],
  946. output_16=[['0', '100', '-10', '0', '256', '40', '0']])
  947. add_test_case('IDiv', ['SDiv', 'SRem'], 'Epsilon', 0,
  948. [['1', '1', '10', '10000', int32_max, int32_max, '-1'],
  949. ['1', '256', '4', '10001', '2', int32_max, '1']],
  950. [['1', '0', '2', '0', '1073741823', '1', '-1'],
  951. ['0', '1', '2', '10000', '1', '0', '0']], "cs_6_0",
  952. ''' struct SBinaryIntOp {
  953. int input1;
  954. int input2;
  955. int output1;
  956. int output2;
  957. };
  958. RWStructuredBuffer<SBinaryIntOp> g_buf : register(u0);
  959. [numthreads(8,8,1)]
  960. void main(uint GI : SV_GroupIndex) {
  961. SBinaryIntOp l = g_buf[GI];
  962. l.output1 = l.input1 / l.input2;
  963. l.output2 = l.input1 % l.input2;
  964. g_buf[GI] = l;
  965. };''')
  966. add_test_case_int('Shl', ['Shl'], 'Epsilon', 0,
  967. [['1', '1', '0x1010', '0xa', '-1', '0x12341234', '-1'],
  968. ['0', '259', '4', '2', '0', '15', '3']],
  969. [['0x1', '0x8', '0x10100', '0x28', '-1','0x091a0000', '-8']],
  970. "binary int", "<<",
  971. input_16=[['1', '1', '0x0101', '0xa', '-1', '0x1234', '-1'],
  972. ['0', '259', '4', '2', '0', '13', '3']],
  973. output_16=[['0x1', '0x8', '0x1010', '0x28', '-1','0x8000', '-8']])
  974. add_test_case_int("LShr", ['LShr'], 'Epsilon', 0,
  975. [['1', '1', '0xffff', '0x7fffffff', '0x70001234', '0x12340ab3', '0x7fffffff'],
  976. ['0', '1', '4', '30', '15', '16', '1']],
  977. [['1', '0', '0xfff', '1', '0xe000', '0x1234', '0x3fffffff']],
  978. "binary int", ">>",
  979. input_16=[['1', '1', '0x7fff', '0x7fff'],
  980. ['0', '1', '4', '14']],
  981. output_16=[['1', '0', '0x07ff', '1']]
  982. )
  983. add_test_case_int("And", ['And'], 'Epsilon', 0,
  984. [['0x1', '0x01', '0x7fff0000', '0x33333333', '0x137f', '0x12345678', '0xa341', '-1'],
  985. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '-1', '0x3471', '-1']],
  986. [['0x1', '0x00', '0x0', '0x22222222', '0x0', '0x12345678', '0x2041', '-1']],
  987. "binary int", "&",
  988. input_16=[['0x1', '0x01', '0x7fff', '0x3333', '0x137f', '0x1234', '0xa341', '-1'],
  989. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '-1', '0x3471', '-1']],
  990. output_16=[['0x1', '0x00', '0x0', '0x2222', '0x0', '0x1234', '0x2041', '-1']],
  991. )
  992. add_test_case_int("Or", ['Or'], 'Epsilon', 0,
  993. [['0x1', '0x01', '0x7fff0000', '0x11111111', '0x137f', '0x0', '0x12345678', '0xa341', '-1'],
  994. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '0x0', '0x00000000', '0x3471', '-1']],
  995. [['0x1', '0xf1', '0x7fffffff', '0x33333333', '0xffff', '0x0', '0x12345678', '0xb771', '-1']],
  996. "binary int", "|",
  997. input_16=[['0x1', '0x01', '0x7fff', '0x3333', '0x137f', '0x1234', '0xa341', '-1'],
  998. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '0xffff', '0x3471', '-1']],
  999. output_16=[['0x1', '0xf1', '0x7fff', '0x3333', '0xffff', '0xffff', '0xb771', '-1']],
  1000. )
  1001. add_test_case_int("Xor", ['Xor'], 'Epsilon', 0,
  1002. [['0x1', '0x01', '0x7fff0000', '0x11111111', '0x137f', '0x0', '0x12345678', '0xa341', '-1'],
  1003. ['0x1', '0xf0', '0x0000ffff', '0x22222222', '0xec80', '0x0', '0x00000000', '0x3471', '-1']],
  1004. [['0x0', '0xf1', '0x7fffffff', '0x33333333', '0xffff', '0x0', '0x12345678', '0x9730', '0x00000000']],
  1005. "binary int", "^",
  1006. input_16=[['0x1', '0x01', '0x7fff', '0x1111', '0x137f', '0x0', '0x1234', '0xa341', '-1'],
  1007. ['0x1', '0xf0', '0x0000', '0x2222', '0xec80', '0x0', '0x0000', '0x3471', '-1']],
  1008. output_16=[['0x0', '0xf1', '0x7fff', '0x3333', '0xffff', '0x0', '0x1234', '0x9730', '0x0000']],
  1009. )
  1010. # Binary Uint
  1011. add_test_case_int('UAdd', ['Add'], 'Epsilon', 0,
  1012. [['2147483648', '4294967285', '0', '0', '10', int32_max, '486'],
  1013. ['0', '10', '0', '10', '10', '0', '54238']],
  1014. [['2147483648', uint32_max, '0', '10', '20', int32_max, '54724']],
  1015. "binary uint", "+",
  1016. input_16=[['323', '0xfff5', '0', '0', '10', uint16_max, '486'],
  1017. ['0', '10', '0', '10', '10', '0', '334']],
  1018. output_16=[['323', uint16_max, '0', '10', '20', uint16_max, '820']])
  1019. add_test_case_int('USub', ['Sub'], 'Epsilon', 0,
  1020. [['2147483648', uint32_max, '0', '0', '30', int32_max, '54724'],
  1021. ['0', '10', '0', '10', '10', '0', '54238']],
  1022. [['2147483648', '4294967285', '0', '4294967286', '20', int32_max, '486']],
  1023. "binary uint", "-",
  1024. input_16=[['323', uint16_max, '0', '0', '10', uint16_max, '486'],
  1025. ['0', '10', '0', '10', '10', '0', '334']],
  1026. output_16=[['323', '0xfff5', '0', '-10', '0', uint16_max, '152']])
  1027. add_test_case_int('UMax', ['UMax'], 'Epsilon', 0,
  1028. [['0', '0', '10', '10000', int32_max, uint32_max],
  1029. ['0', '256', '4', '10001', '0', uint32_max]],
  1030. [['0', '256', '10', '10001', int32_max, uint32_max]],
  1031. "binary uint call", "max",
  1032. input_16=[['0', '0', '10', '10000', int16_max, uint16_max],
  1033. ['0', '256', '4', '10001', '0', uint16_max]],
  1034. output_16=[['0', '256', '10', '10001', int16_max, uint16_max]])
  1035. add_test_case_int('UMin', ['UMin'], 'Epsilon', 0,
  1036. [['0', '0', '10', '10000', int32_max, uint32_max],
  1037. ['0', '256', '4', '10001', '0', uint32_max]],
  1038. [['0', '0', '4', '10000', '0', uint32_max]],
  1039. "binary uint call", "min",
  1040. input_16=[['0', '0', '10', '10000', int16_max, uint16_max],
  1041. ['0', '256', '4', '10001', '0', uint16_max]],
  1042. output_16=[['0', '0', '4', '10000', '0', uint16_max]])
  1043. add_test_case_int('UMul', ['Mul'], 'Epsilon', 0,
  1044. [['0', '1', '10', '10000', int32_max],
  1045. ['0', '256', '4', '10001', '0']],
  1046. [['0', '256', '40', '100010000', '0']],
  1047. "binary uint", "*",
  1048. input_16=[['0', '0', '10', '100', int16_max],
  1049. ['0', '256', '4', '101', '0']],
  1050. output_16=[['0', '0', '40', '10100', '0']])
  1051. add_test_case('UDiv', ['UDiv', 'URem'], 'Epsilon', 0,
  1052. [['1', '1', '10', '10000', int32_max, int32_max, '0xffffffff'],
  1053. ['0', '256', '4', '10001', '0', int32_max, '1']],
  1054. [['0xffffffff', '0', '2', '0', '0xffffffff', '1', '0xffffffff'],
  1055. ['0xffffffff', '1', '2', '10000', '0xffffffff', '0', '0']], 'cs_6_0',
  1056. ''' struct SBinaryUintOp {
  1057. uint input1;
  1058. uint input2;
  1059. uint output1;
  1060. uint output2;
  1061. };
  1062. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  1063. [numthreads(8,8,1)]
  1064. void main(uint GI : SV_GroupIndex) {
  1065. SBinaryUintOp l = g_buf[GI];
  1066. l.output1 = l.input1 / l.input2;
  1067. l.output2 = l.input1 % l.input2;
  1068. g_buf[GI] = l;
  1069. };''')
  1070. add_test_case('UAddc', ['UAddc'], 'Epsilon', 0,
  1071. [['1', '1', '10000', '0x80000000', '0x7fffffff', '0xffffffff'],
  1072. ['0', '256', '10001', '1', '0x7fffffff', '0x7fffffff']],
  1073. [['2', '2', '20000', '0', '0xfffffffe', '0xfffffffe'],
  1074. ['0', '512', '20002', '3', '0xfffffffe', '0xffffffff']], 'cs_6_0',
  1075. ''' struct SBinaryUintOp {
  1076. uint input1;
  1077. uint input2;
  1078. uint output1;
  1079. uint output2;
  1080. };
  1081. RWStructuredBuffer<SBinaryUintOp> g_buf : register(u0);
  1082. [numthreads(8,8,1)]
  1083. void main(uint GI : SV_GroupIndex) {
  1084. SBinaryUintOp l = g_buf[GI];
  1085. uint2 x = uint2(l.input1, l.input2);
  1086. uint2 y = AddUint64(x, x);
  1087. l.output1 = y.x;
  1088. l.output2 = y.y;
  1089. g_buf[GI] = l;
  1090. };''')
  1091. # Tertiary Int
  1092. add_test_case_int('IMad', ['IMad'], 'epsilon', 0, [[
  1093. '-2147483647', '-256', '-1', '0', '1', '2', '16', int32_max, '1',
  1094. '-1', '1', '10'
  1095. ], ['1', '-256', '-1', '0', '1', '3', '16', '0', '1', '-1', '10', '100'], [
  1096. '0', '0', '0', '0', '1', '3', '1', '255', '2147483646', '-2147483647',
  1097. '-10', '-2000'
  1098. ]], [[
  1099. '-2147483647', '65536', '1', '0', '2', '9', '257', '255', int32_max,
  1100. '-2147483646', '0', '-1000'
  1101. ]], "tertiary int", "mad",
  1102. input_16=[[int16_min, '-256', '-1', '0', '1', '2', '16', int16_max],
  1103. ['1','8','-1', '0', '1', '3', '16','1'],
  1104. ['0', '0', '1', '3', '250', '-30', int16_min, '-50']],
  1105. output_16=[[int16_min, '-2048', '2', '3', '251', '-24', '-32512', '32717']]
  1106. )
  1107. add_test_case_int('UMad', ['UMad'], 'epsilon', 0,
  1108. [['0', '1', '2', '16', int32_max, '0', '10'], [
  1109. '0', '1', '2', '16', '1', '0', '10'
  1110. ], ['0', '0', '1', '15', '0', '10', '10']],
  1111. [['0', '1', '5', '271', int32_max, '10', '110']],
  1112. "tertiary uint", "mad",
  1113. input_16=[['0', '1', '2', '16', int16_max, '0', '10'], [
  1114. '0', '1', '2', '16', '1', '0', '10'
  1115. ], ['0', '0', '1', '15', '0', '10', '10']],
  1116. output_16=[['0', '1', '5', '271', int16_max, '10', '110']],
  1117. )
  1118. # Dot
  1119. add_test_case('Dot', ['Dot2', 'Dot3', 'Dot4'], 'epsilon', 0.008, [[
  1120. 'NaN,NaN,NaN,NaN', '-Inf,-Inf,-Inf,-Inf',
  1121. '-denorm,-denorm,-denorm,-denorm', '-0,-0,-0,-0', '0,0,0,0',
  1122. 'denorm,denorm,denorm,denorm', 'Inf,Inf,Inf,Inf', '1,1,1,1',
  1123. '-10,0,0,10', 'Inf,Inf,Inf,-Inf'
  1124. ], [
  1125. 'NaN,NaN,NaN,NaN', '-Inf,-Inf,-Inf,-Inf',
  1126. '-denorm,-denorm,-denorm,-denorm', '-0,-0,-0,-0', '0,0,0,0',
  1127. 'denorm,denorm,denorm,denorm', 'Inf,Inf,Inf,Inf', '1,1,1,1',
  1128. '10,0,0,10', 'Inf,Inf,Inf,Inf'
  1129. ]], [
  1130. [nan, p_inf, 0, 0, 0, 0, p_inf, 2, -100, p_inf],
  1131. [nan, p_inf, 0, 0, 0, 0, p_inf, 3, -100, p_inf],
  1132. [nan, p_inf, 0, 0, 0, 0, p_inf, 4, 0, nan],
  1133. ], 'cs_6_0', ''' struct SDotOp {
  1134. float4 input1;
  1135. float4 input2;
  1136. float o_dot2;
  1137. float o_dot3;
  1138. float o_dot4;
  1139. };
  1140. RWStructuredBuffer<SDotOp> g_buf : register(u0);
  1141. [numthreads(8,8,1)]
  1142. void main(uint GI : SV_GroupIndex) {
  1143. SDotOp l = g_buf[GI];
  1144. l.o_dot2 = dot(l.input1.xy, l.input2.xy);
  1145. l.o_dot3 = dot(l.input1.xyz, l.input2.xyz);
  1146. l.o_dot4 = dot(l.input1.xyzw, l.input2.xyzw);
  1147. g_buf[GI] = l;
  1148. };''')
  1149. # Dot2AddHalf
  1150. add_test_case('Dot2AddHalf', ['Dot2AddHalf'], 'epsilon', 0.008, [[
  1151. '1,2', '1,-2', '1,2', '-1,2', '1,2', '-1,2', '1,2', '-1,-2',
  1152. '65504,1', '-65504,1', '1,65504', '1,-65504', 'inf,inf',
  1153. 'denorm,denorm', '-denorm,-denorm', 'nan,nan'
  1154. ], [
  1155. '3,4', '-3,4', '3,4', '3,-4', '3,4', '-3,4', '3,4', '-3,-4',
  1156. '1,65504', '1,-65504', '65504,1', '-65504,1', 'inf,inf',
  1157. 'denorm,denorm', '-denorm,-denorm', 'nan,nan'
  1158. ], [
  1159. '0', '0', '10', '10', '-5', '-5', '-30', '-30', '0', '0',
  1160. '10000000', '-10000000', 'inf', 'denorm', '-denorm',
  1161. 'nan'
  1162. ]], [
  1163. [11, -11, 21, -1, 6, 6, -19, -19, 131008, -131008, 10131008,
  1164. -10131008, p_inf, 0, 0, nan],
  1165. ], 'cs_6_4', ''' struct SDot2AddHalfOp {
  1166. half2 input1;
  1167. half2 input2;
  1168. float acc;
  1169. float result;
  1170. };
  1171. RWStructuredBuffer<SDot2AddHalfOp> g_buf : register(u0);
  1172. [numthreads(8,8,1)]
  1173. void main(uint GI : SV_GroupIndex) {
  1174. SDot2AddHalfOp l = g_buf[GI];
  1175. l.result = dot2add(l.input1, l.input2, l.acc);
  1176. g_buf[GI] = l;
  1177. };''', shader_arguments='-enable-16bit-types')
  1178. # Dot4AddI8Packed
  1179. add_test_case('Dot4AddI8Packed', ['Dot4AddI8Packed'], 'epsilon', 0, [[
  1180. '0x00000102', '0x00000102', '0x00000102', '0x00000102',
  1181. '0XFFFFFFFF', '0x80808080', '0x80808080', '0x807F807F',
  1182. '0x7F7F7F7F', '0x80808080'
  1183. ], [
  1184. '0x00000304', '0x00000304', '0x00000304', '0x00000304',
  1185. '0xFFFFFFFF', '0x01010101', '0x7F7F7F7F', '0x807F807F',
  1186. '0x7F7F7F7F', '0x80808080'
  1187. ], [
  1188. '0', '10', '-5', '-30', '0', '0', '0', '0', '0', '0'
  1189. ]], [
  1190. [11, 21, 6, -19, 4, -512, -65024, 65026, 64516, 65536],
  1191. ], 'cs_6_4', ''' struct SDot4AddI8PackedOp {
  1192. dword input1;
  1193. dword input2;
  1194. int acc;
  1195. int result;
  1196. };
  1197. RWStructuredBuffer<SDot4AddI8PackedOp> g_buf : register(u0);
  1198. [numthreads(8,8,1)]
  1199. void main(uint GI : SV_GroupIndex) {
  1200. SDot4AddI8PackedOp l = g_buf[GI];
  1201. l.result = dot4add_i8packed(l.input1, l.input2, l.acc);
  1202. g_buf[GI] = l;
  1203. };''')
  1204. # Dot4AddU8Packed
  1205. add_test_case('Dot4AddU8Packed', ['Dot4AddU8Packed'], 'epsilon', 0, [[
  1206. '0x00000102', '0x00000102', '0x01234567', '0xFFFFFFFF',
  1207. '0xFFFFFFFF'
  1208. ], [
  1209. '0x00000304', '0x00000304', '0x23456789', '0xFFFFFFFF',
  1210. '0xFFFFFFFF'
  1211. ], [
  1212. '0', '10', '10000', '0', '3000000000'
  1213. ]], [
  1214. [11, 21, 33668, 260100, 3000260100],
  1215. ], 'cs_6_4', ''' struct SDot4AddU8PackedOp {
  1216. dword input1;
  1217. dword input2;
  1218. dword acc;
  1219. dword result;
  1220. };
  1221. RWStructuredBuffer<SDot4AddU8PackedOp> g_buf : register(u0);
  1222. [numthreads(8,8,1)]
  1223. void main(uint GI : SV_GroupIndex) {
  1224. SDot4AddU8PackedOp l = g_buf[GI];
  1225. l.result = dot4add_u8packed(l.input1, l.input2, l.acc);
  1226. g_buf[GI] = l;
  1227. };''')
  1228. # Quaternary
  1229. # Msad4 intrinsic calls both Bfi and Msad. Currently this is the only way to call bfi instruction from HLSL
  1230. add_test_case('Bfi', ['Bfi', 'Msad'], 'epsilon', 0,
  1231. [["0xA100B2C3", "0x00000000", "0xFFFF01C1", "0xFFFFFFFF"], [
  1232. "0xD7B0C372, 0x4F57C2A3", "0xFFFFFFFF, 0x00000000",
  1233. "0x38A03AEF, 0x38194DA3", "0xFFFFFFFF, 0x00000000"
  1234. ], ["1,2,3,4", "1,2,3,4", "0,0,0,0", "10,10,10,10"]],
  1235. [['153,6,92,113', '1,2,3,4', '397,585,358,707', '10,265,520,775']],
  1236. 'cs_6_0', ''' struct SMsad4 {
  1237. uint ref;
  1238. uint2 source;
  1239. uint4 accum;
  1240. uint4 result;
  1241. };
  1242. RWStructuredBuffer<SMsad4> g_buf : register(u0);
  1243. [numthreads(8,8,1)]
  1244. void main(uint GI : SV_GroupIndex) {
  1245. SMsad4 l = g_buf[GI];
  1246. l.result = msad4(l.ref, l.source, l.accum);
  1247. g_buf[GI] = l;
  1248. };''')
  1249. # Wave Active Tests
  1250. add_test_case('WaveActiveSum', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1251. [['1', '2', '3', '4'], ['0'], ['2', '4', '8', '-64']], [],
  1252. 'cs_6_0', get_shader_text("wave op int", "WaveActiveSum"))
  1253. add_test_case('WaveActiveProduct', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1254. [['1', '2', '3', '4'], ['0'], ['1', '2', '4', '-64']], [],
  1255. 'cs_6_0', get_shader_text("wave op int", "WaveActiveProduct"))
  1256. add_test_case('WaveActiveCountBits', ['WaveAllBitCount', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1257. [['1', '2', '3', '4'], ['0'], ['1', '10', '-4', '-64'],
  1258. ['-100', '-1000', '300']], [], 'cs_6_0',
  1259. get_shader_text("wave op int count", "WaveActiveCountBits"))
  1260. add_test_case('WaveActiveMax', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1261. [['1', '2', '3', '4'], ['0'], ['1', '10', '-4', '-64'],
  1262. ['-100', '-1000', '300']], [], 'cs_6_0',
  1263. get_shader_text("wave op int", "WaveActiveMax"))
  1264. add_test_case('WaveActiveMin', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1265. [['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], ['0'],
  1266. ['1', '10', '-4', '-64'], ['-100', '-1000', '300']], [],
  1267. 'cs_6_0', get_shader_text("wave op int", "WaveActiveMin"))
  1268. add_test_case('WaveActiveAllEqual', ['WaveActiveAllEqual'], 'Epsilon', 0,
  1269. [['1', '2', '3', '4', '1', '1', '1', '1'], ['3'], ['-10']],
  1270. [], 'cs_6_0', get_shader_text("wave op int", "WaveActiveAllEqual"))
  1271. add_test_case('WaveActiveAnyTrue', ['WaveAnyTrue', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1272. [['1', '0', '1', '0', '1'], ['1'], ['0']], [], 'cs_6_0',
  1273. get_shader_text("wave op int", "WaveActiveAnyTrue"))
  1274. add_test_case('WaveActiveAllTrue', ['WaveAllTrue', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1275. [['1', '0', '1', '0', '1'], ['1'], ['1']], [], 'cs_6_0',
  1276. get_shader_text("wave op int", "WaveActiveAllTrue"))
  1277. add_test_case('WaveActiveUSum', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1278. [['1', '2', '3', '4'], ['0'], ['2', '4', '8', '64']], [],
  1279. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveSum"))
  1280. add_test_case('WaveActiveUProduct', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1281. [['1', '2', '3', '4'], ['0'], ['1', '2', '4', '64']], [],
  1282. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveProduct"))
  1283. add_test_case('WaveActiveUMax', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1284. [['1', '2', '3', '4'], ['0'], ['1', '10', '4', '64']], [],
  1285. 'cs_6_0', get_shader_text("wave op uint", "WaveActiveMax"))
  1286. add_test_case('WaveActiveUMin', ['WaveActiveOp', 'WaveReadLaneFirst', 'WaveReadLaneAt'], 'Epsilon', 0,
  1287. [['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], ['0'],
  1288. ['1', '10', '4', '64']], [], 'cs_6_0',
  1289. get_shader_text("wave op uint", "WaveActiveMin"))
  1290. add_test_case('WaveActiveBitOr', ['WaveActiveBit'], 'Epsilon', 0, [[
  1291. '0xe0000000', '0x0d000000', '0x00b00000', '0x00070000', '0x0000e000',
  1292. '0x00000d00', '0x000000b0', '0x00000007'
  1293. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1294. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1295. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitOr"))
  1296. add_test_case('WaveActiveBitAnd', ['WaveActiveBit'], 'Epsilon', 0, [[
  1297. '0xefffffff', '0xfdffffff', '0xffbfffff', '0xfff7ffff', '0xffffefff',
  1298. '0xfffffdff', '0xffffffbf', '0xfffffff7'
  1299. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1300. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1301. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitAnd"))
  1302. add_test_case('WaveActiveBitXor', ['WaveActiveBit'], 'Epsilon', 0, [[
  1303. '0xe0000000', '0x0d000000', '0x00b00000', '0x00070000', '0x0000e000',
  1304. '0x00000d00', '0x000000b0', '0x00000007'
  1305. ], ['0xedb7edb7', '0xdb7edb7e', '0xb7edb7ed', '0x7edb7edb'], [
  1306. '0x12481248', '0x24812481', '0x48124812', '0x81248124'
  1307. ], ['0x00000000', '0xffffffff']], [], 'cs_6_0', get_shader_text("wave op uint", "WaveActiveBitXor"))
  1308. add_test_case('WavePrefixCountBits', ['WavePrefixBitCount'], 'Epsilon', 0,
  1309. [['1', '2', '3', '4', '5'], ['0'], ['1', '10', '-4', '-64'],
  1310. ['-100', '-1000', '300']], [], 'cs_6_0',
  1311. get_shader_text("wave op int count", "WavePrefixCountBits"))
  1312. add_test_case('WavePrefixSum', ['WavePrefixOp'], 'Epsilon', 0,
  1313. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '-64', '128']],
  1314. [], 'cs_6_0', get_shader_text("wave op int", "WavePrefixSum"))
  1315. add_test_case('WavePrefixProduct', ['WavePrefixOp'], 'Epsilon', 0,
  1316. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '-64', '128']],
  1317. [], 'cs_6_0', get_shader_text("wave op int", "WavePrefixProduct"))
  1318. add_test_case('WavePrefixUSum', ['WavePrefixOp'], 'Epsilon', 0,
  1319. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '128']], [],
  1320. 'cs_6_0', get_shader_text("wave op uint", "WavePrefixSum"))
  1321. add_test_case('WavePrefixUProduct', ['WavePrefixOp'], 'Epsilon', 0,
  1322. [['1', '2', '3', '4', '5'], ['0', '1'], ['1', '2', '4', '128']], [],
  1323. 'cs_6_0', get_shader_text("wave op uint", "WavePrefixProduct"))
  1324. #Wave Multi Prefix Tests
  1325. add_test_case('WaveMultiPrefixBitAnd', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1326. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1327. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixBitAnd"))
  1328. add_test_case('WaveMultiPrefixBitOr', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1329. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1330. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixBitOr"))
  1331. add_test_case('WaveMultiPrefixBitXor', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1332. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1333. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixBitXor"))
  1334. add_test_case('WaveMultiPrefixSum', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1335. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1336. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixSum"))
  1337. add_test_case('WaveMultiPrefixProduct', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1338. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1339. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixProduct"))
  1340. add_test_case('WaveMultiPrefixCountBits', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1341. [['0', '3', '1', '5', '4'], ['0', '42', '0', '64', '11', '76', '90', '111', '0', '0', '79', '34']], [],
  1342. 'cs_6_5', get_shader_text("wave op multi prefix int", "WaveMultiPrefixCountBits"))
  1343. add_test_case('WaveMultiPrefixUBitAnd', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1344. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1345. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixBitAnd"))
  1346. add_test_case('WaveMultiPrefixUBitOr', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1347. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1348. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixBitOr"))
  1349. add_test_case('WaveMultiPrefixUBitXor', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1350. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1351. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixBitXor"))
  1352. add_test_case('WaveMultiPrefixUSum', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1353. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1354. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixSum"))
  1355. add_test_case('WaveMultiPrefixUProduct', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1356. [['0', '3', '1', '5', '4'], ['10', '42', '1', '64', '11', '76', '90', '111', '9', '6', '79', '34']], [],
  1357. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixProduct"))
  1358. add_test_case('WaveMultiPrefixUCountBits', ['WaveMultiPrefixOp'], 'Epsilon', 0,
  1359. [['0', '3', '1', '5', '4'], ['0', '42', '0', '64', '11', '76', '90', '111', '0', '0', '79', '34']], [],
  1360. 'cs_6_5', get_shader_text("wave op multi prefix uint", "WaveMultiPrefixCountBits"))
# Generate the XML file for the data-driven execution tests.
# TODO: ElementTree does not generate formatted XML. Currently the XML file is checked in after being run through the VS Code formatter.
# Implement an XML formatter, or import a formatter library, and use that instead.
  1364. def generate_parameter_types(table, num_inputs, num_outputs, has_known_warp_issue=False):
  1365. param_types = ET.SubElement(table, "ParameterTypes")
  1366. ET.SubElement(
  1367. param_types, "ParameterType", attrib={
  1368. "Name": "ShaderOp.Target"
  1369. }).text = "String"
  1370. ET.SubElement(
  1371. param_types, "ParameterType", attrib={
  1372. "Name": "ShaderOp.Arguments"
  1373. }).text = "String"
  1374. ET.SubElement(
  1375. param_types, "ParameterType", attrib={
  1376. "Name": "ShaderOp.Text"
  1377. }).text = "String"
  1378. ET.SubElement(
  1379. param_types, "ParameterType", attrib={
  1380. "Name": "Validation.Type"
  1381. }).text = "String"
  1382. ET.SubElement(
  1383. param_types, "ParameterType", attrib={
  1384. "Name": "Validation.Tolerance"
  1385. }).text = "double"
  1386. for i in range(0, num_inputs):
  1387. ET.SubElement(
  1388. param_types,
  1389. "ParameterType",
  1390. attrib={
  1391. "Name": 'Validation.Input{}'.format(i + 1),
  1392. 'Array': 'true'
  1393. }).text = "String"
  1394. for i in range(0, num_outputs):
  1395. ET.SubElement(
  1396. param_types,
  1397. "ParameterType",
  1398. attrib={
  1399. "Name": 'Validation.Expected{}'.format(i + 1),
  1400. 'Array': 'true'
  1401. }).text = "String"
  1402. if has_known_warp_issue:
  1403. ET.SubElement(param_types, "ParameterType", attrib={"Name":"Warp.Version"}).text = "unsigned int"
  1404. def generate_parameter_types_wave(table):
  1405. param_types = ET.SubElement(table, "ParameterTypes")
  1406. ET.SubElement(
  1407. param_types, "ParameterType", attrib={
  1408. "Name": "ShaderOp.Target"
  1409. }).text = "String"
  1410. ET.SubElement(
  1411. param_types, "ParameterType", attrib={
  1412. "Name": "ShaderOp.Text"
  1413. }).text = "String"
  1414. ET.SubElement(
  1415. param_types,
  1416. "ParameterType",
  1417. attrib={
  1418. "Name": "Validation.NumInputSet"
  1419. }).text = "String"
  1420. ET.SubElement(
  1421. param_types,
  1422. "ParameterType",
  1423. attrib={
  1424. "Name": "Validation.InputSet1",
  1425. "Array": "true"
  1426. }).text = "String"
  1427. ET.SubElement(
  1428. param_types,
  1429. "ParameterType",
  1430. attrib={
  1431. "Name": "Validation.InputSet2",
  1432. "Array": "true"
  1433. }).text = "String"
  1434. ET.SubElement(
  1435. param_types,
  1436. "ParameterType",
  1437. attrib={
  1438. "Name": "Validation.InputSet3",
  1439. "Array": "true"
  1440. }).text = "String"
  1441. ET.SubElement(
  1442. param_types,
  1443. "ParameterType",
  1444. attrib={
  1445. "Name": "Validation.InputSet4",
  1446. "Array": "true"
  1447. }).text = "String"
  1448. def generate_parameter_types_wave_multi_prefix(table):
  1449. param_types = ET.SubElement(table, "ParameterTypes")
  1450. ET.SubElement(
  1451. param_types, "ParameterType", attrib={
  1452. "Name": "ShaderOp.Target"
  1453. }).text = "String"
  1454. ET.SubElement(
  1455. param_types, "ParameterType", attrib={
  1456. "Name": "ShaderOp.Text"
  1457. }).text = "String"
  1458. ET.SubElement(
  1459. param_types,
  1460. "ParameterType",
  1461. attrib={
  1462. "Name": "Validation.Keys",
  1463. "Array": "true"
  1464. }).text = "String"
  1465. ET.SubElement(
  1466. param_types,
  1467. "ParameterType",
  1468. attrib={
  1469. "Name": "Validation.Values",
  1470. "Array": "true"
  1471. }).text = "String"
  1472. def generate_parameter_types_msad(table):
  1473. param_types = ET.SubElement(table, "ParameterTypes")
  1474. ET.SubElement(
  1475. param_types, "ParameterType", attrib={
  1476. "Name": "ShaderOp.Text"
  1477. }).text = "String"
  1478. ET.SubElement(
  1479. param_types, "ParameterType", attrib={
  1480. "Name": "Validation.Tolerance"
  1481. }).text = "int"
  1482. ET.SubElement(
  1483. param_types,
  1484. "ParameterType",
  1485. attrib={
  1486. "Name": "Validation.Input1",
  1487. "Array": "true"
  1488. }).text = "unsigned int"
  1489. ET.SubElement(
  1490. param_types,
  1491. "ParameterType",
  1492. attrib={
  1493. "Name": "Validation.Input2",
  1494. "Array": "true"
  1495. }).text = "String"
  1496. ET.SubElement(
  1497. param_types,
  1498. "ParameterType",
  1499. attrib={
  1500. "Name": "Validation.Input3",
  1501. "Array": "true"
  1502. }).text = "String"
  1503. ET.SubElement(
  1504. param_types,
  1505. "ParameterType",
  1506. attrib={
  1507. "Name": "Validation.Expected1",
  1508. "Array": "true"
  1509. }).text = "String"
  1510. def generate_row(table, case):
  1511. row = ET.SubElement(table, "Row", {"Name": case.test_name})
  1512. ET.SubElement(row, "Parameter", {
  1513. "Name": "Validation.Type"
  1514. }).text = case.validation_type
  1515. ET.SubElement(row, "Parameter", {
  1516. "Name": "Validation.Tolerance"
  1517. }).text = str(case.validation_tolerance)
  1518. ET.SubElement(row, "Parameter", {
  1519. "Name": "ShaderOp.Text"
  1520. }).text = case.shader_text
  1521. ET.SubElement(row, "Parameter", {
  1522. "Name": "ShaderOp.Target"
  1523. }).text = case.shader_target
  1524. for i in range(len(case.input_lists)):
  1525. inputs = ET.SubElement(row, "Parameter", {
  1526. "Name": "Validation.Input{}".format(i + 1)
  1527. })
  1528. for val in case.input_lists[i]:
  1529. ET.SubElement(inputs, "Value").text = str(val)
  1530. for i in range(len(case.output_lists)):
  1531. outputs = ET.SubElement(row, "Parameter", {
  1532. "Name": "Validation.Expected{}".format(i + 1)
  1533. })
  1534. for val in case.output_lists[i]:
  1535. ET.SubElement(outputs, "Value").text = str(val)
  1536. # Optional parameters
  1537. if case.warp_version > 0:
  1538. ET.SubElement(row, "Parameter", {"Name":"Warp.Version"}).text = str(case.warp_version)
  1539. if case.shader_arguments != "":
  1540. ET.SubElement(row, "Parameter", {"Name":"ShaderOp.Arguments"}).text = case.shader_arguments
  1541. def generate_row_wave(table, case):
  1542. row = ET.SubElement(table, "Row", {"Name": case.test_name})
  1543. ET.SubElement(row, "Parameter", {
  1544. "Name": "ShaderOp.Name"
  1545. }).text = case.test_name
  1546. ET.SubElement(row, "Parameter", {
  1547. "Name": "ShaderOp.Text"
  1548. }).text = case.shader_text
  1549. ET.SubElement(row, "Parameter", {
  1550. "Name": "Validation.NumInputSet"
  1551. }).text = str(len(case.input_lists))
  1552. for i in range(len(case.input_lists)):
  1553. inputs = ET.SubElement(row, "Parameter", {
  1554. "Name": "Validation.InputSet{}".format(i + 1)
  1555. })
  1556. for val in case.input_lists[i]:
  1557. ET.SubElement(inputs, "Value").text = str(val)
  1558. def generate_row_wave_multi(table, case):
  1559. row = ET.SubElement(table, "Row", {"Name": case.test_name})
  1560. ET.SubElement(row, "Parameter", {
  1561. "Name": "ShaderOp.Name"
  1562. }).text = case.test_name
  1563. ET.SubElement(row, "Parameter", {
  1564. "Name": "ShaderOp.Target"
  1565. }).text = case.shader_target
  1566. ET.SubElement(row, "Parameter", {
  1567. "Name": "ShaderOp.Text"
  1568. }).text = case.shader_text
  1569. inputs = ET.SubElement(row, "Parameter", {
  1570. "Name": "Validation.Keys"
  1571. })
  1572. for val in case.input_lists[0]:
  1573. ET.SubElement(inputs, "Value").text = str(val)
  1574. inputs = ET.SubElement(row, "Parameter", {
  1575. "Name": "Validation.Values"
  1576. })
  1577. for val in case.input_lists[1]:
  1578. ET.SubElement(inputs, "Value").text = str(val)
  1579. def generate_table_for_taef():
  1580. with open("..\\..\\tools\\clang\\unittests\\HLSL\\ShaderOpArithTable.xml",
  1581. 'w') as f:
  1582. tree = ET.ElementTree()
  1583. root = ET.Element('Data')
  1584. # Create tables
  1585. generate_parameter_types(
  1586. ET.SubElement(root, "Table", attrib={
  1587. "Id": "UnaryFloatOpTable"
  1588. }), 1, 1, True)
  1589. generate_parameter_types(
  1590. ET.SubElement(root, "Table", attrib={
  1591. "Id": "BinaryFloatOpTable"
  1592. }), 2, 2)
  1593. generate_parameter_types(
  1594. ET.SubElement(root, "Table", attrib={
  1595. "Id": "TertiaryFloatOpTable"
  1596. }), 3, 1)
  1597. generate_parameter_types(
  1598. ET.SubElement(root, "Table", attrib={
  1599. "Id": "UnaryHalfOpTable"
  1600. }), 1, 1, True)
  1601. generate_parameter_types(
  1602. ET.SubElement(root, "Table", attrib={
  1603. "Id": "BinaryHalfOpTable"
  1604. }), 2, 2)
  1605. generate_parameter_types(
  1606. ET.SubElement(root, "Table", attrib={
  1607. "Id": "TertiaryHalfOpTable"
  1608. }), 3, 1)
  1609. generate_parameter_types(
  1610. ET.SubElement(root, "Table", attrib={
  1611. "Id": "UnaryIntOpTable"
  1612. }), 1, 1)
  1613. generate_parameter_types(
  1614. ET.SubElement(root, "Table", attrib={
  1615. "Id": "BinaryIntOpTable"
  1616. }), 2, 2)
  1617. generate_parameter_types(
  1618. ET.SubElement(root, "Table", attrib={
  1619. "Id": "TertiaryIntOpTable"
  1620. }), 3, 1)
  1621. generate_parameter_types(
  1622. ET.SubElement(root, "Table", attrib={
  1623. "Id": "UnaryInt16OpTable"
  1624. }), 1, 1)
  1625. generate_parameter_types(
  1626. ET.SubElement(root, "Table", attrib={
  1627. "Id": "BinaryInt16OpTable"
  1628. }), 2, 2)
  1629. generate_parameter_types(
  1630. ET.SubElement(root, "Table", attrib={
  1631. "Id": "TertiaryInt16OpTable"
  1632. }), 3, 1)
  1633. generate_parameter_types(
  1634. ET.SubElement(root, "Table", attrib={
  1635. "Id": "UnaryUintOpTable"
  1636. }), 1, 1)
  1637. generate_parameter_types(
  1638. ET.SubElement(root, "Table", attrib={
  1639. "Id": "BinaryUintOpTable"
  1640. }), 2, 2)
  1641. generate_parameter_types(
  1642. ET.SubElement(root, "Table", attrib={
  1643. "Id": "TertiaryUintOpTable"
  1644. }), 3, 1)
  1645. generate_parameter_types(
  1646. ET.SubElement(root, "Table", attrib={
  1647. "Id": "UnaryUint16OpTable"
  1648. }), 1, 1)
  1649. generate_parameter_types(
  1650. ET.SubElement(root, "Table", attrib={
  1651. "Id": "BinaryUint16OpTable"
  1652. }), 2, 2)
  1653. generate_parameter_types(
  1654. ET.SubElement(root, "Table", attrib={
  1655. "Id": "TertiaryUint16OpTable"
  1656. }), 3, 1)
  1657. generate_parameter_types(
  1658. ET.SubElement(root, "Table", attrib={
  1659. "Id": "DotOpTable"
  1660. }), 2, 3)
  1661. generate_parameter_types(
  1662. ET.SubElement(root, "Table", attrib={
  1663. "Id": "Dot2AddHalfOpTable"
  1664. }), 3, 1)
  1665. generate_parameter_types(
  1666. ET.SubElement(root, "Table", attrib={
  1667. "Id": "Dot4AddI8PackedOpTable"
  1668. }), 3, 1)
  1669. generate_parameter_types(
  1670. ET.SubElement(root, "Table", attrib={
  1671. "Id": "Dot4AddU8PackedOpTable"
  1672. }), 3, 1)
  1673. generate_parameter_types_msad(
  1674. ET.SubElement(root, "Table", attrib={
  1675. "Id": "Msad4Table"
  1676. }))
  1677. generate_parameter_types_wave(
  1678. ET.SubElement(
  1679. root, "Table", attrib={
  1680. "Id": "WaveIntrinsicsActiveIntTable"
  1681. }))
  1682. generate_parameter_types_wave(
  1683. ET.SubElement(
  1684. root, "Table", attrib={
  1685. "Id": "WaveIntrinsicsActiveUintTable"
  1686. }))
  1687. generate_parameter_types_wave(
  1688. ET.SubElement(
  1689. root, "Table", attrib={
  1690. "Id": "WaveIntrinsicsPrefixIntTable"
  1691. }))
  1692. generate_parameter_types_wave(
  1693. ET.SubElement(
  1694. root, "Table", attrib={
  1695. "Id": "WaveIntrinsicsPrefixUintTable"
  1696. }))
  1697. generate_parameter_types_wave_multi_prefix(
  1698. ET.SubElement(
  1699. root, "Table", attrib={
  1700. "Id": "WaveIntrinsicsMultiPrefixIntTable"
  1701. }))
  1702. generate_parameter_types_wave_multi_prefix(
  1703. ET.SubElement(
  1704. root, "Table", attrib={
  1705. "Id": "WaveIntrinsicsMultiPrefixUintTable"
  1706. }))
  1707. generate_parameter_types(
  1708. ET.SubElement(
  1709. root, "Table", attrib={
  1710. "Id": "DenormBinaryFloatOpTable"
  1711. }), 2, 2) # 2 sets of expected values for any mode
  1712. generate_parameter_types(
  1713. ET.SubElement(
  1714. root, "Table", attrib={
  1715. "Id": "DenormTertiaryFloatOpTable"
  1716. }), 3, 2)
  1717. for case in g_test_cases.values():
  1718. cur_inst = case.insts[0]
  1719. if cur_inst.is_cast or cur_inst.category.startswith("Unary"):
  1720. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1721. generate_row(
  1722. root.find("./Table[@Id='UnaryFloatOpTable']"),
  1723. case)
  1724. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1725. generate_row(root.find("./Table[@Id='UnaryHalfOpTable']"),case)
  1726. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1727. if cur_inst.category.startswith("Unary int"):
  1728. generate_row(
  1729. root.find("./Table[@Id='UnaryIntOpTable']"),
  1730. case)
  1731. elif cur_inst.category.startswith("Unary uint"):
  1732. generate_row(
  1733. root.find("./Table[@Id='UnaryUintOpTable']"),
  1734. case)
  1735. else:
  1736. print("unknown op: " + cur_inst.name)
  1737. print(cur_inst.dxil_class)
  1738. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1739. if cur_inst.category.startswith("Unary int"):
  1740. generate_row(
  1741. root.find("./Table[@Id='UnaryInt16OpTable']"),
  1742. case)
  1743. elif cur_inst.category.startswith("Unary uint"):
  1744. generate_row(
  1745. root.find("./Table[@Id='UnaryUint16OpTable']"),
  1746. case)
  1747. else:
  1748. print("unknown op: " + cur_inst.name)
  1749. print(cur_inst.dxil_class)
  1750. elif cur_inst.is_binary or cur_inst.category.startswith(
  1751. "Binary"):
  1752. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1753. if case.test_name in g_denorm_tests: # for denorm tests
  1754. generate_row(
  1755. root.find("./Table[@Id='DenormBinaryFloatOpTable']"),
  1756. case)
  1757. else:
  1758. generate_row(
  1759. root.find("./Table[@Id='BinaryFloatOpTable']"),
  1760. case)
  1761. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1762. generate_row(root.find("./Table[@Id='BinaryHalfOpTable']"),case)
  1763. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1764. if cur_inst.category.startswith("Binary int"):
  1765. if case.test_name in ['UAdd', 'USub', 'UMul']: # Add, Sub, Mul use same operations for int and uint.
  1766. generate_row(
  1767. root.find("./Table[@Id='BinaryUintOpTable']"),
  1768. case)
  1769. else:
  1770. generate_row(
  1771. root.find("./Table[@Id='BinaryIntOpTable']"),
  1772. case)
  1773. elif cur_inst.category.startswith("Binary uint"):
  1774. generate_row(
  1775. root.find("./Table[@Id='BinaryUintOpTable']"),
  1776. case)
  1777. else:
  1778. print("unknown op: " + cur_inst.name)
  1779. print(cur_inst.dxil_class)
  1780. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1781. if cur_inst.category.startswith("Binary int"):
  1782. if case.test_name in ['UAdd', 'USub', 'UMul']: # Add, Sub, Mul use same operations for int and uint.
  1783. generate_row(
  1784. root.find("./Table[@Id='BinaryUint16OpTable']"),
  1785. case)
  1786. else:
  1787. generate_row(
  1788. root.find("./Table[@Id='BinaryInt16OpTable']"),
  1789. case)
  1790. elif cur_inst.category.startswith("Binary uint"):
  1791. generate_row(
  1792. root.find("./Table[@Id='BinaryUint16OpTable']"),
  1793. case)
  1794. else:
  1795. print("unknown op: " + cur_inst.name)
  1796. print(cur_inst.dxil_class)
  1797. elif cur_inst.category.startswith("Tertiary"):
  1798. if "f" in cur_inst.oload_types and not "Half" in case.test_name:
  1799. if case.test_name in g_denorm_tests: # for denorm tests
  1800. generate_row(
  1801. root.find("./Table[@Id='DenormTertiaryFloatOpTable']"),case)
  1802. else:
  1803. generate_row(
  1804. root.find("./Table[@Id='TertiaryFloatOpTable']"),case)
  1805. if "h" in cur_inst.oload_types and "Half" in case.test_name:
  1806. generate_row(root.find("./Table[@Id='TertiaryHalfOpTable']"),case)
  1807. if "i" in cur_inst.oload_types and "Bit16" not in case.test_name:
  1808. if cur_inst.category.startswith("Tertiary int"):
  1809. generate_row(
  1810. root.find("./Table[@Id='TertiaryIntOpTable']"),
  1811. case)
  1812. elif cur_inst.category.startswith("Tertiary uint"):
  1813. generate_row(
  1814. root.find(
  1815. "./Table[@Id='TertiaryUintOpTable']"),
  1816. case)
  1817. else:
  1818. print("unknown op: " + cur_inst.name)
  1819. print(cur_inst.dxil_class)
  1820. if "w" in cur_inst.oload_types and "Bit16" in case.test_name:
  1821. if cur_inst.category.startswith("Tertiary int"):
  1822. generate_row(
  1823. root.find("./Table[@Id='TertiaryInt16OpTable']"),
  1824. case)
  1825. elif cur_inst.category.startswith("Tertiary uint"):
  1826. generate_row(
  1827. root.find(
  1828. "./Table[@Id='TertiaryUint16OpTable']"),
  1829. case)
  1830. else:
  1831. print("unknown op: " + cur_inst.name)
  1832. print(cur_inst.dxil_class)
  1833. elif cur_inst.category.startswith("Quaternary"):
  1834. if cur_inst.name == "Bfi":
  1835. generate_row(
  1836. root.find("./Table[@Id='Msad4Table']"), case)
  1837. else:
  1838. print("unknown op: " + cur_inst.name)
  1839. print(cur_inst.dxil_class)
  1840. elif cur_inst.category == "Dot":
  1841. generate_row(root.find("./Table[@Id='DotOpTable']"), case)
  1842. elif cur_inst.category == "Dot product with accumulate":
  1843. if cur_inst.name == "Dot2AddHalf":
  1844. generate_row(root.find("./Table[@Id='Dot2AddHalfOpTable']"), case)
  1845. elif cur_inst.name == "Dot4AddI8Packed":
  1846. generate_row(root.find("./Table[@Id='Dot4AddI8PackedOpTable']"), case)
  1847. elif cur_inst.name == "Dot4AddU8Packed":
  1848. generate_row(root.find("./Table[@Id='Dot4AddU8PackedOpTable']"), case)
  1849. else:
  1850. print("unknown op: " + cur_inst.name)
  1851. print(cur_inst.dxil_class)
  1852. elif cur_inst.dxil_class in ["WaveActiveOp", "WaveAllOp","WaveActiveAllEqual","WaveAnyTrue","WaveAllTrue"]:
  1853. if case.test_name.startswith("WaveActiveU"):
  1854. generate_row_wave(
  1855. root.find(
  1856. "./Table[@Id='WaveIntrinsicsActiveUintTable']"
  1857. ), case)
  1858. else:
  1859. generate_row_wave(
  1860. root.find(
  1861. "./Table[@Id='WaveIntrinsicsActiveIntTable']"),
  1862. case)
  1863. elif cur_inst.dxil_class == "WaveActiveBit":
  1864. generate_row_wave(
  1865. root.find(
  1866. "./Table[@Id='WaveIntrinsicsActiveUintTable']"),
  1867. case)
  1868. elif cur_inst.dxil_class == "WavePrefixOp":
  1869. if case.test_name.startswith("WavePrefixU"):
  1870. generate_row_wave(
  1871. root.find(
  1872. "./Table[@Id='WaveIntrinsicsPrefixUintTable']"
  1873. ), case)
  1874. else:
  1875. generate_row_wave(
  1876. root.find(
  1877. "./Table[@Id='WaveIntrinsicsPrefixIntTable']"),
  1878. case)
  1879. elif cur_inst.dxil_class == "WaveMultiPrefixOp":
  1880. if case.test_name.startswith("WaveMultiPrefixU"):
  1881. generate_row_wave_multi(
  1882. root.find(
  1883. "./Table[@Id='WaveIntrinsicsMultiPrefixUintTable']"
  1884. ), case)
  1885. else:
  1886. generate_row_wave_multi(
  1887. root.find(
  1888. "./Table[@Id='WaveIntrinsicsMultiPrefixIntTable']"),
  1889. case)
  1890. else:
  1891. print("unknown op: " + cur_inst.name)
  1892. print(cur_inst.dxil_class)
  1893. tree._setroot(root)
  1894. from xml.dom.minidom import parseString
  1895. pretty_xml = parseString(ET.tostring(root)).toprettyxml(indent=" ")
  1896. f.write(pretty_xml)
  1897. print("Saved file at: " + f.name)
  1898. f.close()
  1899. def print_untested_inst():
  1900. lst = []
  1901. for name in [node.inst.name for node in g_instruction_nodes.values() if len(node.test_cases) == 0]:
  1902. lst += [name]
  1903. lst.sort()
  1904. print("Untested dxil ops: ")
  1905. for name in lst:
  1906. print(name)
  1907. print("Total uncovered dxil ops: " + str(len(lst)))
  1908. print("Total covered dxil ops: " + str(len(g_instruction_nodes)-len(lst)))
  1909. # inst name to instruction dict
  1910. g_instruction_nodes = {}
  1911. # test name to test case dict
  1912. g_test_cases = {}
  1913. if __name__ == "__main__":
  1914. db = get_db_dxil()
  1915. for inst in db.instr:
  1916. g_instruction_nodes[inst.name] = inst_node(inst)
  1917. add_test_cases()
  1918. args = vars(parser.parse_args())
  1919. mode = args['mode']
  1920. if mode == "info":
  1921. print_untested_inst()
  1922. elif mode == "gen-xml":
  1923. generate_table_for_taef()
  1924. else:
  1925. print("unknown mode: " + mode)
  1926. exit(1)
  1927. exit(0)