val_arithmetics_test.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
// Tests for the validator rules on arithmetic instructions.
  15. #include <string>
  16. #include "gmock/gmock.h"
  17. #include "test/unit_spirv.h"
  18. #include "test/val/val_fixtures.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. using ::testing::HasSubstr;
  23. using ::testing::Not;
  24. using ValidateArithmetics = spvtest::ValidateBase<bool>;
  25. std::string GenerateCode(const std::string& main_body) {
  26. const std::string prefix =
  27. R"(
  28. OpCapability Shader
  29. OpCapability Int64
  30. OpCapability Float64
  31. OpCapability Matrix
  32. %ext_inst = OpExtInstImport "GLSL.std.450"
  33. OpMemoryModel Logical GLSL450
  34. OpEntryPoint Fragment %main "main"
  35. OpExecutionMode %main OriginUpperLeft
  36. %void = OpTypeVoid
  37. %func = OpTypeFunction %void
  38. %bool = OpTypeBool
  39. %f32 = OpTypeFloat 32
  40. %u32 = OpTypeInt 32 0
  41. %s32 = OpTypeInt 32 1
  42. %f64 = OpTypeFloat 64
  43. %u64 = OpTypeInt 64 0
  44. %s64 = OpTypeInt 64 1
  45. %boolvec2 = OpTypeVector %bool 2
  46. %s32vec2 = OpTypeVector %s32 2
  47. %u32vec2 = OpTypeVector %u32 2
  48. %u64vec2 = OpTypeVector %u64 2
  49. %f32vec2 = OpTypeVector %f32 2
  50. %f64vec2 = OpTypeVector %f64 2
  51. %boolvec3 = OpTypeVector %bool 3
  52. %u32vec3 = OpTypeVector %u32 3
  53. %u64vec3 = OpTypeVector %u64 3
  54. %s32vec3 = OpTypeVector %s32 3
  55. %f32vec3 = OpTypeVector %f32 3
  56. %f64vec3 = OpTypeVector %f64 3
  57. %boolvec4 = OpTypeVector %bool 4
  58. %u32vec4 = OpTypeVector %u32 4
  59. %u64vec4 = OpTypeVector %u64 4
  60. %s32vec4 = OpTypeVector %s32 4
  61. %f32vec4 = OpTypeVector %f32 4
  62. %f64vec4 = OpTypeVector %f64 4
  63. %f32mat22 = OpTypeMatrix %f32vec2 2
  64. %f32mat23 = OpTypeMatrix %f32vec2 3
  65. %f32mat32 = OpTypeMatrix %f32vec3 2
  66. %f32mat33 = OpTypeMatrix %f32vec3 3
  67. %f64mat22 = OpTypeMatrix %f64vec2 2
  68. %struct_f32_f32 = OpTypeStruct %f32 %f32
  69. %struct_u32_u32 = OpTypeStruct %u32 %u32
  70. %struct_u32_u32_u32 = OpTypeStruct %u32 %u32 %u32
  71. %struct_s32_s32 = OpTypeStruct %s32 %s32
  72. %struct_s32_u32 = OpTypeStruct %s32 %u32
  73. %struct_u32vec2_u32vec2 = OpTypeStruct %u32vec2 %u32vec2
  74. %struct_s32vec2_s32vec2 = OpTypeStruct %s32vec2 %s32vec2
  75. %f32_0 = OpConstant %f32 0
  76. %f32_1 = OpConstant %f32 1
  77. %f32_2 = OpConstant %f32 2
  78. %f32_3 = OpConstant %f32 3
  79. %f32_4 = OpConstant %f32 4
  80. %f32_pi = OpConstant %f32 3.14159
  81. %s32_0 = OpConstant %s32 0
  82. %s32_1 = OpConstant %s32 1
  83. %s32_2 = OpConstant %s32 2
  84. %s32_3 = OpConstant %s32 3
  85. %s32_4 = OpConstant %s32 4
  86. %s32_m1 = OpConstant %s32 -1
  87. %u32_0 = OpConstant %u32 0
  88. %u32_1 = OpConstant %u32 1
  89. %u32_2 = OpConstant %u32 2
  90. %u32_3 = OpConstant %u32 3
  91. %u32_4 = OpConstant %u32 4
  92. %f64_0 = OpConstant %f64 0
  93. %f64_1 = OpConstant %f64 1
  94. %f64_2 = OpConstant %f64 2
  95. %f64_3 = OpConstant %f64 3
  96. %f64_4 = OpConstant %f64 4
  97. %s64_0 = OpConstant %s64 0
  98. %s64_1 = OpConstant %s64 1
  99. %s64_2 = OpConstant %s64 2
  100. %s64_3 = OpConstant %s64 3
  101. %s64_4 = OpConstant %s64 4
  102. %s64_m1 = OpConstant %s64 -1
  103. %u64_0 = OpConstant %u64 0
  104. %u64_1 = OpConstant %u64 1
  105. %u64_2 = OpConstant %u64 2
  106. %u64_3 = OpConstant %u64 3
  107. %u64_4 = OpConstant %u64 4
  108. %u32vec2_01 = OpConstantComposite %u32vec2 %u32_0 %u32_1
  109. %u32vec2_12 = OpConstantComposite %u32vec2 %u32_1 %u32_2
  110. %u32vec3_012 = OpConstantComposite %u32vec3 %u32_0 %u32_1 %u32_2
  111. %u32vec3_123 = OpConstantComposite %u32vec3 %u32_1 %u32_2 %u32_3
  112. %u32vec4_0123 = OpConstantComposite %u32vec4 %u32_0 %u32_1 %u32_2 %u32_3
  113. %u32vec4_1234 = OpConstantComposite %u32vec4 %u32_1 %u32_2 %u32_3 %u32_4
  114. %s32vec2_01 = OpConstantComposite %s32vec2 %s32_0 %s32_1
  115. %s32vec2_12 = OpConstantComposite %s32vec2 %s32_1 %s32_2
  116. %s32vec3_012 = OpConstantComposite %s32vec3 %s32_0 %s32_1 %s32_2
  117. %s32vec3_123 = OpConstantComposite %s32vec3 %s32_1 %s32_2 %s32_3
  118. %s32vec4_0123 = OpConstantComposite %s32vec4 %s32_0 %s32_1 %s32_2 %s32_3
  119. %s32vec4_1234 = OpConstantComposite %s32vec4 %s32_1 %s32_2 %s32_3 %s32_4
  120. %f32vec2_01 = OpConstantComposite %f32vec2 %f32_0 %f32_1
  121. %f32vec2_12 = OpConstantComposite %f32vec2 %f32_1 %f32_2
  122. %f32vec3_012 = OpConstantComposite %f32vec3 %f32_0 %f32_1 %f32_2
  123. %f32vec3_123 = OpConstantComposite %f32vec3 %f32_1 %f32_2 %f32_3
  124. %f32vec4_0123 = OpConstantComposite %f32vec4 %f32_0 %f32_1 %f32_2 %f32_3
  125. %f32vec4_1234 = OpConstantComposite %f32vec4 %f32_1 %f32_2 %f32_3 %f32_4
  126. %f64vec2_01 = OpConstantComposite %f64vec2 %f64_0 %f64_1
  127. %f64vec2_12 = OpConstantComposite %f64vec2 %f64_1 %f64_2
  128. %f64vec3_012 = OpConstantComposite %f64vec3 %f64_0 %f64_1 %f64_2
  129. %f64vec3_123 = OpConstantComposite %f64vec3 %f64_1 %f64_2 %f64_3
  130. %f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
  131. %f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4
  132. %f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
  133. %f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
  134. %f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
  135. %f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123
  136. %f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12
  137. %main = OpFunction %void None %func
  138. %main_entry = OpLabel)";
  139. const std::string suffix =
  140. R"(
  141. OpReturn
  142. OpFunctionEnd)";
  143. return prefix + main_body + suffix;
  144. }
  145. TEST_F(ValidateArithmetics, F32Success) {
  146. const std::string body = R"(
  147. %val1 = OpFMul %f32 %f32_0 %f32_1
  148. %val2 = OpFSub %f32 %f32_2 %f32_0
  149. %val3 = OpFAdd %f32 %val1 %val2
  150. %val4 = OpFNegate %f32 %val3
  151. %val5 = OpFDiv %f32 %val4 %val1
  152. %val6 = OpFRem %f32 %val4 %f32_2
  153. %val7 = OpFMod %f32 %val4 %f32_2
  154. )";
  155. CompileSuccessfully(GenerateCode(body).c_str());
  156. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  157. }
  158. TEST_F(ValidateArithmetics, F64Success) {
  159. const std::string body = R"(
  160. %val1 = OpFMul %f64 %f64_0 %f64_1
  161. %val2 = OpFSub %f64 %f64_2 %f64_0
  162. %val3 = OpFAdd %f64 %val1 %val2
  163. %val4 = OpFNegate %f64 %val3
  164. %val5 = OpFDiv %f64 %val4 %val1
  165. %val6 = OpFRem %f64 %val4 %f64_2
  166. %val7 = OpFMod %f64 %val4 %f64_2
  167. )";
  168. CompileSuccessfully(GenerateCode(body).c_str());
  169. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  170. }
  171. TEST_F(ValidateArithmetics, Int32Success) {
  172. const std::string body = R"(
  173. %val1 = OpIMul %u32 %s32_0 %u32_1
  174. %val2 = OpIMul %s32 %s32_2 %u32_1
  175. %val3 = OpIAdd %u32 %val1 %val2
  176. %val4 = OpIAdd %s32 %val1 %val2
  177. %val5 = OpISub %u32 %val3 %val4
  178. %val6 = OpISub %s32 %val4 %val3
  179. %val7 = OpSDiv %s32 %val4 %val3
  180. %val8 = OpSNegate %s32 %val7
  181. %val9 = OpSRem %s32 %val4 %val3
  182. %val10 = OpSMod %s32 %val4 %val3
  183. )";
  184. CompileSuccessfully(GenerateCode(body).c_str());
  185. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  186. }
  187. TEST_F(ValidateArithmetics, Int64Success) {
  188. const std::string body = R"(
  189. %val1 = OpIMul %u64 %s64_0 %u64_1
  190. %val2 = OpIMul %s64 %s64_2 %u64_1
  191. %val3 = OpIAdd %u64 %val1 %val2
  192. %val4 = OpIAdd %s64 %val1 %val2
  193. %val5 = OpISub %u64 %val3 %val4
  194. %val6 = OpISub %s64 %val4 %val3
  195. %val7 = OpSDiv %s64 %val4 %val3
  196. %val8 = OpSNegate %s64 %val7
  197. %val9 = OpSRem %s64 %val4 %val3
  198. %val10 = OpSMod %s64 %val4 %val3
  199. )";
  200. CompileSuccessfully(GenerateCode(body).c_str());
  201. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  202. }
  203. TEST_F(ValidateArithmetics, F32Vec2Success) {
  204. const std::string body = R"(
  205. %val1 = OpFMul %f32vec2 %f32vec2_01 %f32vec2_12
  206. %val2 = OpFSub %f32vec2 %f32vec2_12 %f32vec2_01
  207. %val3 = OpFAdd %f32vec2 %val1 %val2
  208. %val4 = OpFNegate %f32vec2 %val3
  209. %val5 = OpFDiv %f32vec2 %val4 %val1
  210. %val6 = OpFRem %f32vec2 %val4 %f32vec2_12
  211. %val7 = OpFMod %f32vec2 %val4 %f32vec2_12
  212. )";
  213. CompileSuccessfully(GenerateCode(body).c_str());
  214. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  215. }
  216. TEST_F(ValidateArithmetics, F64Vec2Success) {
  217. const std::string body = R"(
  218. %val1 = OpFMul %f64vec2 %f64vec2_01 %f64vec2_12
  219. %val2 = OpFSub %f64vec2 %f64vec2_12 %f64vec2_01
  220. %val3 = OpFAdd %f64vec2 %val1 %val2
  221. %val4 = OpFNegate %f64vec2 %val3
  222. %val5 = OpFDiv %f64vec2 %val4 %val1
  223. %val6 = OpFRem %f64vec2 %val4 %f64vec2_12
  224. %val7 = OpFMod %f64vec2 %val4 %f64vec2_12
  225. )";
  226. CompileSuccessfully(GenerateCode(body).c_str());
  227. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  228. }
  229. TEST_F(ValidateArithmetics, U32Vec2Success) {
  230. const std::string body = R"(
  231. %val1 = OpIMul %u32vec2 %u32vec2_01 %u32vec2_12
  232. %val2 = OpISub %u32vec2 %u32vec2_12 %u32vec2_01
  233. %val3 = OpIAdd %u32vec2 %val1 %val2
  234. %val4 = OpSNegate %u32vec2 %val3
  235. %val5 = OpSDiv %u32vec2 %val4 %val1
  236. %val6 = OpSRem %u32vec2 %val4 %u32vec2_12
  237. %val7 = OpSMod %u32vec2 %val4 %u32vec2_12
  238. )";
  239. CompileSuccessfully(GenerateCode(body).c_str());
  240. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  241. }
  242. TEST_F(ValidateArithmetics, FNegateTypeIdU32) {
  243. const std::string body = R"(
  244. %val = OpFNegate %u32 %u32_0
  245. )";
  246. CompileSuccessfully(GenerateCode(body).c_str());
  247. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  248. EXPECT_THAT(
  249. getDiagnosticString(),
  250. HasSubstr(
  251. "Expected floating scalar or vector type as Result Type: FNegate"));
  252. }
  253. TEST_F(ValidateArithmetics, FNegateTypeIdVec2U32) {
  254. const std::string body = R"(
  255. %val = OpFNegate %u32vec2 %u32vec2_01
  256. )";
  257. CompileSuccessfully(GenerateCode(body).c_str());
  258. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  259. EXPECT_THAT(
  260. getDiagnosticString(),
  261. HasSubstr(
  262. "Expected floating scalar or vector type as Result Type: FNegate"));
  263. }
  264. TEST_F(ValidateArithmetics, FNegateWrongOperand) {
  265. const std::string body = R"(
  266. %val = OpFNegate %f32 %u32_0
  267. )";
  268. CompileSuccessfully(GenerateCode(body).c_str());
  269. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  270. EXPECT_THAT(getDiagnosticString(),
  271. HasSubstr("Expected arithmetic operands to be of Result Type: "
  272. "FNegate operand index 2"));
  273. }
  274. TEST_F(ValidateArithmetics, FMulTypeIdU32) {
  275. const std::string body = R"(
  276. %val = OpFMul %u32 %u32_0 %u32_1
  277. )";
  278. CompileSuccessfully(GenerateCode(body).c_str());
  279. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  280. EXPECT_THAT(
  281. getDiagnosticString(),
  282. HasSubstr(
  283. "Expected floating scalar or vector type as Result Type: FMul"));
  284. }
  285. TEST_F(ValidateArithmetics, FMulTypeIdVec2U32) {
  286. const std::string body = R"(
  287. %val = OpFMul %u32vec2 %u32vec2_01 %u32vec2_12
  288. )";
  289. CompileSuccessfully(GenerateCode(body).c_str());
  290. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  291. EXPECT_THAT(
  292. getDiagnosticString(),
  293. HasSubstr(
  294. "Expected floating scalar or vector type as Result Type: FMul"));
  295. }
  296. TEST_F(ValidateArithmetics, FMulWrongOperand1) {
  297. const std::string body = R"(
  298. %val = OpFMul %f32 %u32_0 %f32_1
  299. )";
  300. CompileSuccessfully(GenerateCode(body).c_str());
  301. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  302. EXPECT_THAT(getDiagnosticString(),
  303. HasSubstr("Expected arithmetic operands to be of Result Type: "
  304. "FMul operand index 2"));
  305. }
  306. TEST_F(ValidateArithmetics, FMulWrongOperand2) {
  307. const std::string body = R"(
  308. %val = OpFMul %f32 %f32_0 %u32_1
  309. )";
  310. CompileSuccessfully(GenerateCode(body).c_str());
  311. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  312. EXPECT_THAT(getDiagnosticString(),
  313. HasSubstr("Expected arithmetic operands to be of Result Type: "
  314. "FMul operand index 3"));
  315. }
  316. TEST_F(ValidateArithmetics, FMulWrongVectorOperand1) {
  317. const std::string body = R"(
  318. %val = OpFMul %f64vec3 %f32vec3_123 %f64vec3_012
  319. )";
  320. CompileSuccessfully(GenerateCode(body).c_str());
  321. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  322. EXPECT_THAT(getDiagnosticString(),
  323. HasSubstr("Expected arithmetic operands to be of Result Type: "
  324. "FMul operand index 2"));
  325. }
  326. TEST_F(ValidateArithmetics, FMulWrongVectorOperand2) {
  327. const std::string body = R"(
  328. %val = OpFMul %f32vec3 %f32vec3_123 %f64vec3_012
  329. )";
  330. CompileSuccessfully(GenerateCode(body).c_str());
  331. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  332. EXPECT_THAT(getDiagnosticString(),
  333. HasSubstr("Expected arithmetic operands to be of Result Type: "
  334. "FMul operand index 3"));
  335. }
  336. TEST_F(ValidateArithmetics, IMulFloatTypeId) {
  337. const std::string body = R"(
  338. %val = OpIMul %f32 %u32_0 %s32_1
  339. )";
  340. CompileSuccessfully(GenerateCode(body).c_str());
  341. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  342. EXPECT_THAT(
  343. getDiagnosticString(),
  344. HasSubstr("Expected int scalar or vector type as Result Type: IMul"));
  345. }
  346. TEST_F(ValidateArithmetics, IMulFloatOperand1) {
  347. const std::string body = R"(
  348. %val = OpIMul %u32 %f32_0 %s32_1
  349. )";
  350. CompileSuccessfully(GenerateCode(body).c_str());
  351. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  352. EXPECT_THAT(getDiagnosticString(),
  353. HasSubstr("Expected int scalar or vector type as operand: "
  354. "IMul operand index 2"));
  355. }
  356. TEST_F(ValidateArithmetics, IMulFloatOperand2) {
  357. const std::string body = R"(
  358. %val = OpIMul %u32 %s32_0 %f32_1
  359. )";
  360. CompileSuccessfully(GenerateCode(body).c_str());
  361. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  362. EXPECT_THAT(getDiagnosticString(),
  363. HasSubstr("Expected int scalar or vector type as operand: "
  364. "IMul operand index 3"));
  365. }
  366. TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand1) {
  367. const std::string body = R"(
  368. %val = OpIMul %u64 %u32_0 %s64_1
  369. )";
  370. CompileSuccessfully(GenerateCode(body).c_str());
  371. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  372. EXPECT_THAT(
  373. getDiagnosticString(),
  374. HasSubstr("Expected arithmetic operands to have the same bit width "
  375. "as Result Type: IMul operand index 2"));
  376. }
  377. TEST_F(ValidateArithmetics, IMulWrongBitWidthOperand2) {
  378. const std::string body = R"(
  379. %val = OpIMul %u32 %u32_0 %s64_1
  380. )";
  381. CompileSuccessfully(GenerateCode(body).c_str());
  382. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  383. EXPECT_THAT(
  384. getDiagnosticString(),
  385. HasSubstr("Expected arithmetic operands to have the same bit width "
  386. "as Result Type: IMul operand index 3"));
  387. }
  388. TEST_F(ValidateArithmetics, IMulWrongBitWidthVector) {
  389. const std::string body = R"(
  390. %val = OpIMul %u64vec3 %u32vec3_012 %u32vec3_123
  391. )";
  392. CompileSuccessfully(GenerateCode(body).c_str());
  393. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  394. EXPECT_THAT(
  395. getDiagnosticString(),
  396. HasSubstr("Expected arithmetic operands to have the same bit width "
  397. "as Result Type: IMul operand index 2"));
  398. }
  399. TEST_F(ValidateArithmetics, IMulVectorScalarOperand1) {
  400. const std::string body = R"(
  401. %val = OpIMul %u32vec2 %u32_0 %u32vec2_01
  402. )";
  403. CompileSuccessfully(GenerateCode(body).c_str());
  404. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  405. EXPECT_THAT(
  406. getDiagnosticString(),
  407. HasSubstr("Expected arithmetic operands to have the same dimension "
  408. "as Result Type: IMul operand index 2"));
  409. }
  410. TEST_F(ValidateArithmetics, IMulVectorScalarOperand2) {
  411. const std::string body = R"(
  412. %val = OpIMul %u32vec2 %u32vec2_01 %u32_0
  413. )";
  414. CompileSuccessfully(GenerateCode(body).c_str());
  415. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  416. EXPECT_THAT(
  417. getDiagnosticString(),
  418. HasSubstr("Expected arithmetic operands to have the same dimension "
  419. "as Result Type: IMul operand index 3"));
  420. }
  421. TEST_F(ValidateArithmetics, IMulScalarVectorOperand1) {
  422. const std::string body = R"(
  423. %val = OpIMul %s32 %u32vec2_01 %u32_0
  424. )";
  425. CompileSuccessfully(GenerateCode(body).c_str());
  426. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  427. EXPECT_THAT(
  428. getDiagnosticString(),
  429. HasSubstr("Expected arithmetic operands to have the same dimension "
  430. "as Result Type: IMul operand index 2"));
  431. }
  432. TEST_F(ValidateArithmetics, IMulScalarVectorOperand2) {
  433. const std::string body = R"(
  434. %val = OpIMul %u32 %u32_0 %s32vec2_01
  435. )";
  436. CompileSuccessfully(GenerateCode(body).c_str());
  437. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  438. EXPECT_THAT(
  439. getDiagnosticString(),
  440. HasSubstr("Expected arithmetic operands to have the same dimension "
  441. "as Result Type: IMul operand index 3"));
  442. }
  443. TEST_F(ValidateArithmetics, SNegateFloat) {
  444. const std::string body = R"(
  445. %val = OpSNegate %s32 %f32_1
  446. )";
  447. CompileSuccessfully(GenerateCode(body).c_str());
  448. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  449. EXPECT_THAT(getDiagnosticString(),
  450. HasSubstr("Expected int scalar or vector type as operand: "
  451. "SNegate operand index 2"));
  452. }
  453. TEST_F(ValidateArithmetics, UDivFloatType) {
  454. const std::string body = R"(
  455. %val = OpUDiv %f32 %u32_2 %u32_1
  456. )";
  457. CompileSuccessfully(GenerateCode(body).c_str());
  458. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  459. EXPECT_THAT(
  460. getDiagnosticString(),
  461. HasSubstr(
  462. "Expected unsigned int scalar or vector type as Result Type: UDiv"));
  463. }
  464. TEST_F(ValidateArithmetics, UDivSignedIntType) {
  465. const std::string body = R"(
  466. %val = OpUDiv %s32 %u32_2 %u32_1
  467. )";
  468. CompileSuccessfully(GenerateCode(body).c_str());
  469. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  470. EXPECT_THAT(
  471. getDiagnosticString(),
  472. HasSubstr(
  473. "Expected unsigned int scalar or vector type as Result Type: UDiv"));
  474. }
  475. TEST_F(ValidateArithmetics, UDivWrongOperand1) {
  476. const std::string body = R"(
  477. %val = OpUDiv %u64 %f64_2 %u64_1
  478. )";
  479. CompileSuccessfully(GenerateCode(body).c_str());
  480. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  481. EXPECT_THAT(getDiagnosticString(),
  482. HasSubstr("Expected arithmetic operands to be of Result Type: "
  483. "UDiv operand index 2"));
  484. }
  485. TEST_F(ValidateArithmetics, UDivWrongOperand2) {
  486. const std::string body = R"(
  487. %val = OpUDiv %u64 %u64_2 %u32_1
  488. )";
  489. CompileSuccessfully(GenerateCode(body).c_str());
  490. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  491. EXPECT_THAT(getDiagnosticString(),
  492. HasSubstr("Expected arithmetic operands to be of Result Type: "
  493. "UDiv operand index 3"));
  494. }
  495. TEST_F(ValidateArithmetics, DotSuccess) {
  496. const std::string body = R"(
  497. %val = OpDot %f32 %f32vec2_01 %f32vec2_12
  498. )";
  499. CompileSuccessfully(GenerateCode(body).c_str());
  500. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  501. }
  502. TEST_F(ValidateArithmetics, DotWrongTypeId) {
  503. const std::string body = R"(
  504. %val = OpDot %u32 %u32vec2_01 %u32vec2_12
  505. )";
  506. CompileSuccessfully(GenerateCode(body).c_str());
  507. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  508. EXPECT_THAT(getDiagnosticString(),
  509. HasSubstr("Expected float scalar type as Result Type: Dot"));
  510. }
  511. TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
  512. const std::string body = R"(
  513. %val = OpDot %f32 %f32 %f32vec2_12
  514. )";
  515. CompileSuccessfully(GenerateCode(body).c_str());
  516. ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  517. EXPECT_THAT(getDiagnosticString(),
  518. HasSubstr("Operand '6[%float]' cannot be a "
  519. "type"));
  520. }
  521. TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
  522. const std::string body = R"(
  523. %val = OpDot %f32 %f32vec3_012 %f32_1
  524. )";
  525. CompileSuccessfully(GenerateCode(body).c_str());
  526. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  527. EXPECT_THAT(
  528. getDiagnosticString(),
  529. HasSubstr("Expected float vector as operand: Dot operand index 3"));
  530. }
  531. TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
  532. const std::string body = R"(
  533. %val = OpDot %f64 %f32vec2_01 %f64vec2_12
  534. )";
  535. CompileSuccessfully(GenerateCode(body).c_str());
  536. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  537. EXPECT_THAT(getDiagnosticString(),
  538. HasSubstr("Expected component type to be equal to Result Type: "
  539. "Dot operand index 2"));
  540. }
  541. TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
  542. const std::string body = R"(
  543. %val = OpDot %f32 %f32vec2_01 %f64vec2_12
  544. )";
  545. CompileSuccessfully(GenerateCode(body).c_str());
  546. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  547. EXPECT_THAT(getDiagnosticString(),
  548. HasSubstr("Expected component type to be equal to Result Type: "
  549. "Dot operand index 3"));
  550. }
  551. TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
  552. const std::string body = R"(
  553. %val = OpDot %f32 %f32vec2_01 %f32vec3_123
  554. )";
  555. CompileSuccessfully(GenerateCode(body).c_str());
  556. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  557. EXPECT_THAT(
  558. getDiagnosticString(),
  559. HasSubstr(
  560. "Expected operands to have the same number of components: Dot"));
  561. }
  562. TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
  563. const std::string body = R"(
  564. %val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
  565. )";
  566. CompileSuccessfully(GenerateCode(body).c_str());
  567. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  568. }
  569. TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
  570. const std::string body = R"(
  571. %val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
  572. )";
  573. CompileSuccessfully(GenerateCode(body).c_str());
  574. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  575. EXPECT_THAT(getDiagnosticString(),
  576. HasSubstr("Expected float vector type as Result Type: "
  577. "VectorTimesScalar"));
  578. }
  579. TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
  580. const std::string body = R"(
  581. %val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
  582. )";
  583. CompileSuccessfully(GenerateCode(body).c_str());
  584. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  585. EXPECT_THAT(
  586. getDiagnosticString(),
  587. HasSubstr("Expected vector operand type to be equal to Result Type: "
  588. "VectorTimesScalar"));
  589. }
  590. TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
  591. const std::string body = R"(
  592. %val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
  593. )";
  594. CompileSuccessfully(GenerateCode(body).c_str());
  595. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  596. EXPECT_THAT(
  597. getDiagnosticString(),
  598. HasSubstr("Expected scalar operand type to be equal to the component "
  599. "type of the vector operand: VectorTimesScalar"));
  600. }
  601. TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
  602. const std::string body = R"(
  603. %val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
  604. )";
  605. CompileSuccessfully(GenerateCode(body).c_str());
  606. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  607. }
  608. TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
  609. const std::string body = R"(
  610. %val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
  611. )";
  612. CompileSuccessfully(GenerateCode(body).c_str());
  613. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  614. EXPECT_THAT(getDiagnosticString(),
  615. HasSubstr("Expected float matrix type as Result Type: "
  616. "MatrixTimesScalar"));
  617. }
  618. TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
  619. const std::string body = R"(
  620. %val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
  621. )";
  622. CompileSuccessfully(GenerateCode(body).c_str());
  623. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  624. EXPECT_THAT(
  625. getDiagnosticString(),
  626. HasSubstr("Expected matrix operand type to be equal to Result Type: "
  627. "MatrixTimesScalar"));
  628. }
  629. TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
  630. const std::string body = R"(
  631. %val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
  632. )";
  633. CompileSuccessfully(GenerateCode(body).c_str());
  634. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  635. EXPECT_THAT(
  636. getDiagnosticString(),
  637. HasSubstr("Expected scalar operand type to be equal to the component "
  638. "type of the matrix operand: MatrixTimesScalar"));
  639. }
  640. TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
  641. const std::string body = R"(
  642. %val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
  643. )";
  644. CompileSuccessfully(GenerateCode(body).c_str());
  645. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  646. }
  647. TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
  648. const std::string body = R"(
  649. %val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
  650. )";
  651. CompileSuccessfully(GenerateCode(body).c_str());
  652. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  653. }
  654. TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
  655. const std::string body = R"(
  656. %val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
  657. )";
  658. CompileSuccessfully(GenerateCode(body).c_str());
  659. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  660. EXPECT_THAT(getDiagnosticString(),
  661. HasSubstr("Expected float vector type as Result Type: "
  662. "VectorTimesMatrix"));
  663. }
  664. TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
  665. const std::string body = R"(
  666. %val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
  667. )";
  668. CompileSuccessfully(GenerateCode(body).c_str());
  669. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  670. EXPECT_THAT(getDiagnosticString(),
  671. HasSubstr("Expected float vector type as left operand: "
  672. "VectorTimesMatrix"));
  673. }
  674. TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
  675. const std::string body = R"(
  676. %val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
  677. )";
  678. CompileSuccessfully(GenerateCode(body).c_str());
  679. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  680. EXPECT_THAT(
  681. getDiagnosticString(),
  682. HasSubstr(
  683. "Expected component types of Result Type and vector to be equal: "
  684. "VectorTimesMatrix"));
  685. }
  686. TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
  687. const std::string body = R"(
  688. %val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
  689. )";
  690. CompileSuccessfully(GenerateCode(body).c_str());
  691. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  692. EXPECT_THAT(getDiagnosticString(),
  693. HasSubstr("Expected float matrix type as right operand: "
  694. "VectorTimesMatrix"));
  695. }
  696. TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
  697. const std::string body = R"(
  698. %val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
  699. )";
  700. CompileSuccessfully(GenerateCode(body).c_str());
  701. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  702. EXPECT_THAT(
  703. getDiagnosticString(),
  704. HasSubstr(
  705. "Expected component types of Result Type and matrix to be equal: "
  706. "VectorTimesMatrix"));
  707. }
  708. TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
  709. const std::string body = R"(
  710. %val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
  711. )";
  712. CompileSuccessfully(GenerateCode(body).c_str());
  713. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  714. EXPECT_THAT(
  715. getDiagnosticString(),
  716. HasSubstr(
  717. "Expected number of columns of the matrix to be equal to Result Type "
  718. "vector size: VectorTimesMatrix"));
  719. }
  720. TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
  721. const std::string body = R"(
  722. %val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
  723. )";
  724. CompileSuccessfully(GenerateCode(body).c_str());
  725. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  726. EXPECT_THAT(
  727. getDiagnosticString(),
  728. HasSubstr(
  729. "Expected number of rows of the matrix to be equal to the vector "
  730. "operand size: VectorTimesMatrix"));
  731. }
  732. TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
  733. const std::string body = R"(
  734. %val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
  735. )";
  736. CompileSuccessfully(GenerateCode(body).c_str());
  737. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  738. }
  739. TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
  740. const std::string body = R"(
  741. %val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
  742. )";
  743. CompileSuccessfully(GenerateCode(body).c_str());
  744. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  745. }
  746. TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
  747. const std::string body = R"(
  748. %val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
  749. )";
  750. CompileSuccessfully(GenerateCode(body).c_str());
  751. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  752. EXPECT_THAT(getDiagnosticString(),
  753. HasSubstr("Expected float vector type as Result Type: "
  754. "MatrixTimesVector"));
  755. }
  756. TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
  757. const std::string body = R"(
  758. %val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
  759. )";
  760. CompileSuccessfully(GenerateCode(body).c_str());
  761. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  762. EXPECT_THAT(getDiagnosticString(),
  763. HasSubstr("Expected float matrix type as left operand: "
  764. "MatrixTimesVector"));
  765. }
  766. TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
  767. const std::string body = R"(
  768. %val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
  769. )";
  770. CompileSuccessfully(GenerateCode(body).c_str());
  771. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  772. EXPECT_THAT(
  773. getDiagnosticString(),
  774. HasSubstr(
  775. "Expected column type of the matrix to be equal to Result Type: "
  776. "MatrixTimesVector"));
  777. }
  778. TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
  779. const std::string body = R"(
  780. %val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
  781. )";
  782. CompileSuccessfully(GenerateCode(body).c_str());
  783. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  784. EXPECT_THAT(getDiagnosticString(),
  785. HasSubstr("Expected float vector type as right operand: "
  786. "MatrixTimesVector"));
  787. }
  788. TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
  789. const std::string body = R"(
  790. %val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
  791. )";
  792. CompileSuccessfully(GenerateCode(body).c_str());
  793. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  794. EXPECT_THAT(getDiagnosticString(),
  795. HasSubstr("Expected component types of the operands to be equal: "
  796. "MatrixTimesVector"));
  797. }
  798. TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
  799. const std::string body = R"(
  800. %val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
  801. )";
  802. CompileSuccessfully(GenerateCode(body).c_str());
  803. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  804. EXPECT_THAT(
  805. getDiagnosticString(),
  806. HasSubstr(
  807. "Expected number of columns of the matrix to be equal to the vector "
  808. "size: MatrixTimesVector"));
  809. }
  810. TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
  811. const std::string body = R"(
  812. %val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
  813. )";
  814. CompileSuccessfully(GenerateCode(body).c_str());
  815. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  816. }
  817. TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
  818. const std::string body = R"(
  819. %val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
  820. )";
  821. CompileSuccessfully(GenerateCode(body).c_str());
  822. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  823. }
  824. TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
  825. const std::string body = R"(
  826. %val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
  827. )";
  828. CompileSuccessfully(GenerateCode(body).c_str());
  829. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  830. }
  831. TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
  832. const std::string body = R"(
  833. %val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
  834. )";
  835. CompileSuccessfully(GenerateCode(body).c_str());
  836. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  837. EXPECT_THAT(
  838. getDiagnosticString(),
  839. HasSubstr(
  840. "Expected float matrix type as Result Type: MatrixTimesMatrix"));
  841. }
  842. TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
  843. const std::string body = R"(
  844. %val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
  845. )";
  846. CompileSuccessfully(GenerateCode(body).c_str());
  847. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  848. EXPECT_THAT(
  849. getDiagnosticString(),
  850. HasSubstr(
  851. "Expected float matrix type as left operand: MatrixTimesMatrix"));
  852. }
  853. TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
  854. const std::string body = R"(
  855. %val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
  856. )";
  857. CompileSuccessfully(GenerateCode(body).c_str());
  858. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  859. EXPECT_THAT(
  860. getDiagnosticString(),
  861. HasSubstr(
  862. "Expected float matrix type as right operand: MatrixTimesMatrix"));
  863. }
  864. TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
  865. const std::string body = R"(
  866. %val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
  867. )";
  868. CompileSuccessfully(GenerateCode(body).c_str());
  869. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  870. EXPECT_THAT(
  871. getDiagnosticString(),
  872. HasSubstr(
  873. "Expected column types of Result Type and left matrix to be equal: "
  874. "MatrixTimesMatrix"));
  875. }
  876. TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
  877. const std::string body = R"(
  878. %val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
  879. )";
  880. CompileSuccessfully(GenerateCode(body).c_str());
  881. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  882. EXPECT_THAT(getDiagnosticString(),
  883. HasSubstr("Expected component types of Result Type and right "
  884. "matrix to be equal: "
  885. "MatrixTimesMatrix"));
  886. }
  887. TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
  888. const std::string body = R"(
  889. %val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
  890. )";
  891. CompileSuccessfully(GenerateCode(body).c_str());
  892. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  893. EXPECT_THAT(getDiagnosticString(),
  894. HasSubstr("Expected number of columns of Result Type and right "
  895. "matrix to be equal: "
  896. "MatrixTimesMatrix"));
  897. }
  898. TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
  899. const std::string body = R"(
  900. %val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
  901. )";
  902. CompileSuccessfully(GenerateCode(body).c_str());
  903. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  904. EXPECT_THAT(getDiagnosticString(),
  905. HasSubstr("Expected number of columns of left matrix and number "
  906. "of rows of right "
  907. "matrix to be equal: MatrixTimesMatrix"));
  908. }
  909. TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
  910. const std::string body = R"(
  911. %val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
  912. )";
  913. CompileSuccessfully(GenerateCode(body).c_str());
  914. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  915. }
  916. TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
  917. const std::string body = R"(
  918. %val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
  919. )";
  920. CompileSuccessfully(GenerateCode(body).c_str());
  921. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  922. }
  923. TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
  924. const std::string body = R"(
  925. %val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
  926. )";
  927. CompileSuccessfully(GenerateCode(body).c_str());
  928. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  929. }
  930. TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
  931. const std::string body = R"(
  932. %val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
  933. )";
  934. CompileSuccessfully(GenerateCode(body).c_str());
  935. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  936. EXPECT_THAT(getDiagnosticString(),
  937. HasSubstr("Expected float matrix type as Result Type: "
  938. "OuterProduct"));
  939. }
  940. TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
  941. const std::string body = R"(
  942. %val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
  943. )";
  944. CompileSuccessfully(GenerateCode(body).c_str());
  945. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  946. EXPECT_THAT(
  947. getDiagnosticString(),
  948. HasSubstr("Expected column type of Result Type to be equal to the type "
  949. "of the left operand: OuterProduct"));
  950. }
  951. TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
  952. const std::string body = R"(
  953. %val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
  954. )";
  955. CompileSuccessfully(GenerateCode(body).c_str());
  956. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  957. EXPECT_THAT(
  958. getDiagnosticString(),
  959. HasSubstr("Expected float vector type as right operand: OuterProduct"));
  960. }
  961. TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
  962. const std::string body = R"(
  963. %val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
  964. )";
  965. CompileSuccessfully(GenerateCode(body).c_str());
  966. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  967. EXPECT_THAT(getDiagnosticString(),
  968. HasSubstr("Expected component types of the operands to be equal: "
  969. "OuterProduct"));
  970. }
  971. TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
  972. const std::string body = R"(
  973. %val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
  974. )";
  975. CompileSuccessfully(GenerateCode(body).c_str());
  976. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  977. EXPECT_THAT(
  978. getDiagnosticString(),
  979. HasSubstr("Expected number of columns of the matrix to be equal to the "
  980. "vector size of the right operand: OuterProduct"));
  981. }
  982. std::string GenerateCoopMatCode(const std::string& extra_types,
  983. const std::string& main_body) {
  984. const std::string prefix =
  985. R"(
  986. OpCapability Shader
  987. OpCapability Float16
  988. OpCapability CooperativeMatrixNV
  989. OpExtension "SPV_NV_cooperative_matrix"
  990. OpMemoryModel Logical GLSL450
  991. OpEntryPoint GLCompute %main "main"
  992. %void = OpTypeVoid
  993. %func = OpTypeFunction %void
  994. %bool = OpTypeBool
  995. %f16 = OpTypeFloat 16
  996. %f32 = OpTypeFloat 32
  997. %u32 = OpTypeInt 32 0
  998. %s32 = OpTypeInt 32 1
  999. %u32_8 = OpConstant %u32 8
  1000. %u32_16 = OpConstant %u32 16
  1001. %u32_4 = OpConstant %u32 4
  1002. %subgroup = OpConstant %u32 3
  1003. %f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
  1004. %u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
  1005. %s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
  1006. %f16_1 = OpConstant %f16 1
  1007. %f32_1 = OpConstant %f32 1
  1008. %u32_1 = OpConstant %u32 1
  1009. %s32_1 = OpConstant %s32 1
  1010. %f16mat_1 = OpConstantComposite %f16mat %f16_1
  1011. %u32mat_1 = OpConstantComposite %u32mat %u32_1
  1012. %s32mat_1 = OpConstantComposite %s32mat %s32_1
  1013. %u32_c1 = OpSpecConstant %u32 1
  1014. %u32_c2 = OpSpecConstant %u32 2
  1015. %f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
  1016. %f16matc_1 = OpConstantComposite %f16matc %f16_1
  1017. %mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
  1018. %mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
  1019. %mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
  1020. %f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
  1021. %f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
  1022. %f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";
  1023. const std::string func_begin =
  1024. R"(
  1025. %main = OpFunction %void None %func
  1026. %main_entry = OpLabel)";
  1027. const std::string suffix =
  1028. R"(
  1029. OpReturn
  1030. OpFunctionEnd)";
  1031. return prefix + extra_types + func_begin + main_body + suffix;
  1032. }
  1033. TEST_F(ValidateArithmetics, CoopMatSuccess) {
  1034. const std::string body = R"(
  1035. %val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
  1036. %val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
  1037. %val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
  1038. %val4 = OpFNegate %f16mat %f16mat_1
  1039. %val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
  1040. %val6 = OpISub %u32mat %u32mat_1 %u32mat_1
  1041. %val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
  1042. %val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
  1043. %val9 = OpISub %s32mat %s32mat_1 %s32mat_1
  1044. %val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
  1045. %val11 = OpSNegate %s32mat %s32mat_1
  1046. %val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
  1047. %val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
  1048. %val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
  1049. %val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
  1050. %val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
  1051. )";
  1052. CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  1053. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1054. }
  1055. TEST_F(ValidateArithmetics, CoopMatFMulFail) {
  1056. const std::string body = R"(
  1057. %val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
  1058. )";
  1059. CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  1060. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1061. EXPECT_THAT(
  1062. getDiagnosticString(),
  1063. HasSubstr(
  1064. "Expected floating scalar or vector type as Result Type: FMul"));
  1065. }
  1066. TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
  1067. const std::string body = R"(
  1068. %val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
  1069. )";
  1070. CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  1071. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1072. EXPECT_THAT(
  1073. getDiagnosticString(),
  1074. HasSubstr("Expected scalar operand type to be equal to the component "
  1075. "type of the matrix operand: MatrixTimesScalar"));
  1076. }
  1077. TEST_F(ValidateArithmetics, CoopMatScopeFail) {
  1078. const std::string types = R"(
  1079. %device = OpConstant %u32 1
  1080. %mat16x16_dv = OpTypeCooperativeMatrixNV %f16 %device %u32_16 %u32_16
  1081. %f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1
  1082. )";
  1083. const std::string body = R"(
  1084. %val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matdv_16x16_1
  1085. )";
  1086. CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
  1087. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1088. EXPECT_THAT(
  1089. getDiagnosticString(),
  1090. HasSubstr(
  1091. "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
  1092. }
  1093. TEST_F(ValidateArithmetics, CoopMatDimFail) {
  1094. const std::string body = R"(
  1095. %val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
  1096. )";
  1097. CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
  1098. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1099. EXPECT_THAT(
  1100. getDiagnosticString(),
  1101. HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
  1102. }
  1103. TEST_F(ValidateArithmetics, CoopMatComponentTypeNotScalarNumeric) {
  1104. const std::string types = R"(
  1105. %bad = OpTypeCooperativeMatrixNV %bool %subgroup %u32_8 %u32_8
  1106. )";
  1107. CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  1108. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1109. EXPECT_THAT(getDiagnosticString(),
  1110. HasSubstr("OpTypeCooperativeMatrix Component Type <id> "
  1111. "'4[%bool]' is not a scalar numerical type."));
  1112. }
  1113. TEST_F(ValidateArithmetics, CoopMatScopeNotConstantInt) {
  1114. const std::string types = R"(
  1115. %bad = OpTypeCooperativeMatrixNV %f16 %f32_1 %u32_8 %u32_8
  1116. )";
  1117. CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  1118. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1119. EXPECT_THAT(
  1120. getDiagnosticString(),
  1121. HasSubstr("OpTypeCooperativeMatrix Scope <id> '17[%float_1]' is not a "
  1122. "constant instruction with scalar integer type."));
  1123. }
  1124. TEST_F(ValidateArithmetics, CoopMatRowsNotConstantInt) {
  1125. const std::string types = R"(
  1126. %bad = OpTypeCooperativeMatrixNV %f16 %subgroup %f32_1 %u32_8
  1127. )";
  1128. CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  1129. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1130. EXPECT_THAT(
  1131. getDiagnosticString(),
  1132. HasSubstr("OpTypeCooperativeMatrix Rows <id> '17[%float_1]' is not a "
  1133. "constant instruction with scalar integer type."));
  1134. }
  1135. TEST_F(ValidateArithmetics, CoopMatColumnsNotConstantInt) {
  1136. const std::string types = R"(
  1137. %bad = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %f32_1
  1138. )";
  1139. CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
  1140. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1141. EXPECT_THAT(
  1142. getDiagnosticString(),
  1143. HasSubstr("OpTypeCooperativeMatrix Cols <id> '17[%float_1]' is not a "
  1144. "constant instruction with scalar integer type."));
  1145. }
  1146. TEST_F(ValidateArithmetics, IAddCarrySuccess) {
  1147. const std::string body = R"(
  1148. %val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
  1149. %val2 = OpIAddCarry %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
  1150. )";
  1151. CompileSuccessfully(GenerateCode(body).c_str());
  1152. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1153. }
  1154. TEST_F(ValidateArithmetics, IAddCarryResultTypeNotStruct) {
  1155. const std::string body = R"(
  1156. %val = OpIAddCarry %u32 %u32_0 %u32_1
  1157. )";
  1158. CompileSuccessfully(GenerateCode(body).c_str());
  1159. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1160. EXPECT_THAT(getDiagnosticString(),
  1161. HasSubstr("Expected a struct as Result Type: IAddCarry"));
  1162. }
  1163. TEST_F(ValidateArithmetics, IAddCarryResultTypeNotTwoMembers) {
  1164. const std::string body = R"(
  1165. %val = OpIAddCarry %struct_u32_u32_u32 %u32_0 %u32_1
  1166. )";
  1167. CompileSuccessfully(GenerateCode(body).c_str());
  1168. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1169. EXPECT_THAT(
  1170. getDiagnosticString(),
  1171. HasSubstr("Expected Result Type struct to have two members: IAddCarry"));
  1172. }
  1173. TEST_F(ValidateArithmetics, IAddCarryResultTypeMemberNotUnsignedInt) {
  1174. const std::string body = R"(
  1175. %val = OpIAddCarry %struct_s32_s32 %s32_0 %s32_1
  1176. )";
  1177. CompileSuccessfully(GenerateCode(body).c_str());
  1178. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1179. EXPECT_THAT(getDiagnosticString(),
  1180. HasSubstr("Expected Result Type struct member types to be "
  1181. "unsigned integer scalar "
  1182. "or vector: IAddCarry"));
  1183. }
  1184. TEST_F(ValidateArithmetics, IAddCarryWrongLeftOperand) {
  1185. const std::string body = R"(
  1186. %val = OpIAddCarry %struct_u32_u32 %s32_0 %u32_1
  1187. )";
  1188. CompileSuccessfully(GenerateCode(body).c_str());
  1189. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1190. EXPECT_THAT(getDiagnosticString(),
  1191. HasSubstr("Expected both operands to be of Result Type member "
  1192. "type: IAddCarry"));
  1193. }
  1194. TEST_F(ValidateArithmetics, IAddCarryWrongRightOperand) {
  1195. const std::string body = R"(
  1196. %val = OpIAddCarry %struct_u32_u32 %u32_0 %s32_1
  1197. )";
  1198. CompileSuccessfully(GenerateCode(body).c_str());
  1199. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1200. EXPECT_THAT(getDiagnosticString(),
  1201. HasSubstr("Expected both operands to be of Result Type member "
  1202. "type: IAddCarry"));
  1203. }
  1204. TEST_F(ValidateArithmetics, OpSMulExtendedSuccess) {
  1205. const std::string body = R"(
  1206. %val1 = OpSMulExtended %struct_u32_u32 %u32_0 %u32_1
  1207. %val2 = OpSMulExtended %struct_s32_s32 %s32_0 %s32_1
  1208. %val3 = OpSMulExtended %struct_u32vec2_u32vec2 %u32vec2_01 %u32vec2_12
  1209. %val4 = OpSMulExtended %struct_s32vec2_s32vec2 %s32vec2_01 %s32vec2_12
  1210. )";
  1211. CompileSuccessfully(GenerateCode(body).c_str());
  1212. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1213. }
  1214. TEST_F(ValidateArithmetics, SMulExtendedResultTypeMemberNotInt) {
  1215. const std::string body = R"(
  1216. %val = OpSMulExtended %struct_f32_f32 %f32_0 %f32_1
  1217. )";
  1218. CompileSuccessfully(GenerateCode(body).c_str());
  1219. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1220. EXPECT_THAT(
  1221. getDiagnosticString(),
  1222. HasSubstr("Expected Result Type struct member types to be integer scalar "
  1223. "or vector: SMulExtended"));
  1224. }
  1225. TEST_F(ValidateArithmetics, SMulExtendedResultTypeMembersNotIdentical) {
  1226. const std::string body = R"(
  1227. %val = OpSMulExtended %struct_s32_u32 %s32_0 %s32_1
  1228. )";
  1229. CompileSuccessfully(GenerateCode(body).c_str());
  1230. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1231. EXPECT_THAT(
  1232. getDiagnosticString(),
  1233. HasSubstr("Expected Result Type struct member types to be identical: "
  1234. "SMulExtended"));
  1235. }
  1236. std::string GenerateCoopMatKHRCode(const std::string& extra_types,
  1237. const std::string& main_body) {
  1238. const std::string prefix = R"(
  1239. OpCapability Shader
  1240. OpCapability Float16
  1241. OpCapability CooperativeMatrixKHR
  1242. OpCapability CooperativeMatrixReductionsNV
  1243. OpCapability CooperativeMatrixPerElementOperationsNV
  1244. OpExtension "SPV_KHR_cooperative_matrix"
  1245. OpExtension "SPV_NV_cooperative_matrix2"
  1246. OpExtension "SPV_KHR_vulkan_memory_model"
  1247. OpMemoryModel Logical GLSL450
  1248. OpEntryPoint GLCompute %main "main"
  1249. %void = OpTypeVoid
  1250. %func = OpTypeFunction %void
  1251. %bool = OpTypeBool
  1252. %f16 = OpTypeFloat 16
  1253. %f32 = OpTypeFloat 32
  1254. %u32 = OpTypeInt 32 0
  1255. %s32 = OpTypeInt 32 1
  1256. %u32_8 = OpConstant %u32 8
  1257. %u32_16 = OpConstant %u32 16
  1258. %u32_4 = OpConstant %u32 4
  1259. %subgroup = OpConstant %u32 3
  1260. %useA = OpConstant %u32 0
  1261. %useB = OpConstant %u32 1
  1262. %useC = OpConstant %u32 2
  1263. %f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
  1264. %u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useA
  1265. %s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useA
  1266. %f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useB
  1267. %u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useB
  1268. %s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useB
  1269. %f16matC = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useC
  1270. %f32matC = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useC
  1271. %u32matC = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useC
  1272. %s32matC = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useC
  1273. %f16_1 = OpConstant %f16 1
  1274. %f32_1 = OpConstant %f32 1
  1275. %u32_1 = OpConstant %u32 1
  1276. %s32_1 = OpConstant %s32 1
  1277. %f16mat_A_1 = OpConstantComposite %f16matA %f16_1
  1278. %u32mat_A_1 = OpConstantComposite %u32matA %u32_1
  1279. %s32mat_A_1 = OpConstantComposite %s32matA %s32_1
  1280. %f16mat_B_1 = OpConstantComposite %f16matB %f16_1
  1281. %u32mat_B_1 = OpConstantComposite %u32matB %u32_1
  1282. %s32mat_B_1 = OpConstantComposite %s32matB %s32_1
  1283. %f16mat_C_1 = OpConstantComposite %f16matC %f16_1
  1284. %u32mat_C_1 = OpConstantComposite %u32matC %u32_1
  1285. %s32mat_C_1 = OpConstantComposite %s32matC %s32_1
  1286. )";
  1287. const std::string func_begin = R"(
  1288. %main = OpFunction %void None %func
  1289. %main_entry = OpLabel)";
  1290. const std::string suffix = R"(
  1291. OpReturn
  1292. OpFunctionEnd)";
  1293. return prefix + extra_types + func_begin + main_body + suffix;
  1294. }
  1295. TEST_F(ValidateArithmetics, CoopMatKHRSuccess) {
  1296. const std::string body = R"(
  1297. %val1 = OpFAdd %f16matA %f16mat_A_1 %f16mat_A_1
  1298. %val2 = OpFSub %f16matA %f16mat_A_1 %f16mat_A_1
  1299. %val3 = OpFMul %f16matA %f16mat_A_1 %f16mat_A_1
  1300. %val4 = OpFDiv %f16matA %f16mat_A_1 %f16mat_A_1
  1301. %val5 = OpFNegate %f16matA %f16mat_A_1
  1302. %val6 = OpIAdd %u32matA %u32mat_A_1 %u32mat_A_1
  1303. %val7 = OpISub %u32matA %u32mat_A_1 %u32mat_A_1
  1304. %val8 = OpUDiv %u32matA %u32mat_A_1 %u32mat_A_1
  1305. %val9 = OpIAdd %s32matA %s32mat_A_1 %s32mat_A_1
  1306. %val10 = OpISub %s32matA %s32mat_A_1 %s32mat_A_1
  1307. %val11 = OpSDiv %s32matA %s32mat_A_1 %s32mat_A_1
  1308. %val12 = OpSNegate %s32matA %s32mat_A_1
  1309. %val13 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f16_1
  1310. %val14 = OpMatrixTimesScalar %u32matA %u32mat_A_1 %u32_1
  1311. %val15 = OpMatrixTimesScalar %s32matA %s32mat_A_1 %s32_1
  1312. %val16 = OpCooperativeMatrixMulAddKHR %f32matC %f16mat_A_1 %f16mat_B_1 %f16mat_C_1
  1313. %val17 = OpCooperativeMatrixMulAddKHR %s32matC %s32mat_A_1 %s32mat_B_1 %s32mat_C_1
  1314. MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR
  1315. %val18 = OpCooperativeMatrixMulAddKHR %u32matC %u32mat_A_1 %u32mat_B_1 %u32mat_C_1
  1316. )";
  1317. CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
  1318. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1319. }
  1320. TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
  1321. const std::string body = R"(
  1322. %val1 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f32_1
  1323. )";
  1324. CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
  1325. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1326. EXPECT_THAT(
  1327. getDiagnosticString(),
  1328. HasSubstr("Expected scalar operand type to be equal to the component "
  1329. "type of the matrix operand: MatrixTimesScalar"));
  1330. }
  1331. TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
  1332. const std::string types = R"(
  1333. %device = OpConstant %u32 1
  1334. %mat16x16_dv = OpTypeCooperativeMatrixKHR %f16 %device %u32_16 %u32_16 %useC
  1335. %f16matdv_16x16_1 = OpConstantComposite %mat16x16_dv %f16_1
  1336. )";
  1337. const std::string body = R"(
  1338. %val1 = OpFAdd %f16matA %f16matdv_16x16_1 %f16mat_A_1
  1339. )";
  1340. CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
  1341. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1342. EXPECT_THAT(
  1343. getDiagnosticString(),
  1344. HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
  1345. }
  1346. TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
  1347. const std::string types = R"(
  1348. %mat16x4 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_4 %useC
  1349. %mat16x4_C_1 = OpConstantComposite %mat16x4 %f16_1
  1350. )";
  1351. const std::string body = R"(
  1352. %val1 = OpCooperativeMatrixMulAddKHR %mat16x4 %f16mat_A_1 %f16mat_B_1 %mat16x4_C_1
  1353. )";
  1354. CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
  1355. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1356. EXPECT_THAT(
  1357. getDiagnosticString(),
  1358. HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
  1359. }
  1360. TEST_F(ValidateArithmetics, CoopMat2ReduceSuccess) {
  1361. const std::string extra_types = R"(
  1362. %f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC
  1363. %f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC
  1364. %f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC
  1365. %functy = OpTypeFunction %f16 %f16 %f16
  1366. %reducefunc = OpFunction %f16 None %functy
  1367. %x = OpFunctionParameter %f16
  1368. %y = OpFunctionParameter %f16
  1369. %entry2 = OpLabel
  1370. %sum = OpFAdd %f16 %x %y
  1371. OpReturnValue %sum
  1372. OpFunctionEnd
  1373. )";
  1374. const std::string body = R"(
  1375. %val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 2x2 %reducefunc
  1376. %val2 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Row %reducefunc
  1377. %val3 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Column %reducefunc
  1378. %val4 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc
  1379. %val5 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column %reducefunc
  1380. )";
  1381. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1382. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1383. }
  1384. TEST_F(ValidateArithmetics, CoopMat2Reduce2x2DimFail) {
  1385. const std::string extra_types = R"(
  1386. %functy = OpTypeFunction %f16 %f16 %f16
  1387. %reducefunc = OpFunction %f16 None %functy
  1388. %x = OpFunctionParameter %f16
  1389. %y = OpFunctionParameter %f16
  1390. %entry2 = OpLabel
  1391. %sum = OpFAdd %f16 %x %y
  1392. OpReturnValue %sum
  1393. OpFunctionEnd
  1394. )";
  1395. const std::string body = R"(
  1396. %val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 2x2 %reducefunc
  1397. )";
  1398. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1399. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1400. EXPECT_THAT(getDiagnosticString(),
  1401. HasSubstr("For Reduce2x2, result rows/cols must be half of "
  1402. "matrix rows/cols: CooperativeMatrixReduceNV"));
  1403. }
  1404. TEST_F(ValidateArithmetics, CoopMat2ReduceRowDimFail) {
  1405. const std::string extra_types = R"(
  1406. %f16matC8x16 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_16 %useC
  1407. %functy = OpTypeFunction %f16 %f16 %f16
  1408. %reducefunc = OpFunction %f16 None %functy
  1409. %x = OpFunctionParameter %f16
  1410. %y = OpFunctionParameter %f16
  1411. %entry2 = OpLabel
  1412. %sum = OpFAdd %f16 %x %y
  1413. OpReturnValue %sum
  1414. OpFunctionEnd
  1415. )";
  1416. const std::string body = R"(
  1417. %val1 = OpCooperativeMatrixReduceNV %f16matC8x16 %f16mat_C_1 Row %reducefunc
  1418. )";
  1419. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1420. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1421. EXPECT_THAT(getDiagnosticString(),
  1422. HasSubstr("For ReduceRow, result rows must match matrix rows: "
  1423. "CooperativeMatrixReduceNV"));
  1424. }
  1425. TEST_F(ValidateArithmetics, CoopMat2ReduceColDimFail) {
  1426. const std::string extra_types = R"(
  1427. %f16matC16x8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_8 %useC
  1428. %functy = OpTypeFunction %f16 %f16 %f16
  1429. %reducefunc = OpFunction %f16 None %functy
  1430. %x = OpFunctionParameter %f16
  1431. %y = OpFunctionParameter %f16
  1432. %entry2 = OpLabel
  1433. %sum = OpFAdd %f16 %x %y
  1434. OpReturnValue %sum
  1435. OpFunctionEnd
  1436. )";
  1437. const std::string body = R"(
  1438. %val1 = OpCooperativeMatrixReduceNV %f16matC16x8 %f16mat_C_1 Column %reducefunc
  1439. )";
  1440. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1441. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1442. EXPECT_THAT(getDiagnosticString(),
  1443. HasSubstr("For ReduceColumn, result cols must match matrix cols: "
  1444. "CooperativeMatrixReduceNV"));
  1445. }
  1446. TEST_F(ValidateArithmetics, CoopMat2ReduceMaskFail) {
  1447. const std::string extra_types = R"(
  1448. %f16matC8 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %useC
  1449. %functy = OpTypeFunction %f16 %f16 %f16
  1450. %reducefunc = OpFunction %f16 None %functy
  1451. %x = OpFunctionParameter %f16
  1452. %y = OpFunctionParameter %f16
  1453. %entry2 = OpLabel
  1454. %sum = OpFAdd %f16 %x %y
  1455. OpReturnValue %sum
  1456. OpFunctionEnd
  1457. )";
  1458. const std::string body = R"(
  1459. %val1 = OpCooperativeMatrixReduceNV %f16matC8 %f16mat_C_1 Row|Column|2x2 %reducefunc
  1460. )";
  1461. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1462. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1463. EXPECT_THAT(getDiagnosticString(),
  1464. HasSubstr("Reduce 2x2 must not be used with Row/Column: "
  1465. "CooperativeMatrixReduceNV"));
  1466. }
  1467. TEST_F(ValidateArithmetics, CoopMat2ReduceFuncTypeFail) {
  1468. const std::string extra_types = R"(
  1469. %functy = OpTypeFunction %f32 %f32 %f32
  1470. %reducefunc = OpFunction %f32 None %functy
  1471. %x = OpFunctionParameter %f32
  1472. %y = OpFunctionParameter %f32
  1473. %entry2 = OpLabel
  1474. %sum = OpFAdd %f32 %x %y
  1475. OpReturnValue %sum
  1476. OpFunctionEnd
  1477. )";
  1478. const std::string body = R"(
  1479. %val1 = OpCooperativeMatrixReduceNV %f16matC %f16mat_C_1 Row|Column %reducefunc
  1480. )";
  1481. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1482. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1483. EXPECT_THAT(getDiagnosticString(),
  1484. HasSubstr("CombineFunc return type and parameters must match "
  1485. "matrix component type: CooperativeMatrixReduceNV"));
  1486. }
  1487. TEST_F(ValidateArithmetics, CoopMat2PerElementOpSuccess) {
  1488. const std::string extra_types = R"(
  1489. %functy = OpTypeFunction %f16 %u32 %u32 %f16
  1490. %functy2 = OpTypeFunction %f16 %u32 %u32 %f16 %u32
  1491. %elemfunc = OpFunction %f16 None %functy
  1492. %row = OpFunctionParameter %u32
  1493. %col = OpFunctionParameter %u32
  1494. %el = OpFunctionParameter %f16
  1495. %entry2 = OpLabel
  1496. OpReturnValue %el
  1497. OpFunctionEnd
  1498. %elemfunc2 = OpFunction %f16 None %functy2
  1499. %row2 = OpFunctionParameter %u32
  1500. %col2 = OpFunctionParameter %u32
  1501. %el2 = OpFunctionParameter %f16
  1502. %x = OpFunctionParameter %u32
  1503. %entry3 = OpLabel
  1504. OpReturnValue %el2
  1505. OpFunctionEnd
  1506. )";
  1507. const std::string body = R"(
  1508. %val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
  1509. %val2 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc2 %f16_1
  1510. )";
  1511. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1512. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1513. }
  1514. TEST_F(ValidateArithmetics, CoopMat2PerElementOpElemTyFail) {
  1515. const std::string extra_types = R"(
  1516. %functy = OpTypeFunction %f32 %u32 %u32 %f32
  1517. %elemfunc = OpFunction %f32 None %functy
  1518. %row = OpFunctionParameter %u32
  1519. %col = OpFunctionParameter %u32
  1520. %el = OpFunctionParameter %f32
  1521. %entry2 = OpLabel
  1522. OpReturnValue %el
  1523. OpFunctionEnd
  1524. )";
  1525. const std::string body = R"(
  1526. %val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
  1527. )";
  1528. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1529. ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1530. EXPECT_THAT(getDiagnosticString(),
  1531. HasSubstr("must match matrix component type"));
  1532. }
  1533. TEST_F(ValidateArithmetics, CoopMat2PerElementOpRowTyFail) {
  1534. const std::string extra_types = R"(
  1535. %functy = OpTypeFunction %f16 %f16 %u32 %f16
  1536. %elemfunc = OpFunction %f16 None %functy
  1537. %row = OpFunctionParameter %f16
  1538. %col = OpFunctionParameter %u32
  1539. %el = OpFunctionParameter %f16
  1540. %entry2 = OpLabel
  1541. OpReturnValue %el
  1542. OpFunctionEnd
  1543. )";
  1544. const std::string body = R"(
  1545. %val1 = OpCooperativeMatrixPerElementOpNV %f16matC %f16mat_C_1 %elemfunc
  1546. )";
  1547. CompileSuccessfully(GenerateCoopMatKHRCode(extra_types, body).c_str());
  1548. ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1549. EXPECT_THAT(getDiagnosticString(), HasSubstr("must be a 32-bit integer"));
  1550. }
  1551. std::string GenerateCoopVecCode(const std::string& extra_types,
  1552. const std::string& main_body) {
  1553. const std::string prefix =
  1554. R"(
  1555. OpCapability Shader
  1556. OpCapability Float16
  1557. OpCapability CooperativeVectorNV
  1558. OpCapability ReplicatedCompositesEXT
  1559. OpExtension "SPV_NV_cooperative_vector"
  1560. OpExtension "SPV_EXT_replicated_composites"
  1561. %ext_inst = OpExtInstImport "GLSL.std.450"
  1562. OpMemoryModel Logical GLSL450
  1563. OpEntryPoint GLCompute %main "main"
  1564. %void = OpTypeVoid
  1565. %func = OpTypeFunction %void
  1566. %bool = OpTypeBool
  1567. %f16 = OpTypeFloat 16
  1568. %f32 = OpTypeFloat 32
  1569. %u32 = OpTypeInt 32 0
  1570. %s32 = OpTypeInt 32 1
  1571. %u32_8 = OpConstant %u32 8
  1572. %u32_16 = OpConstant %u32 16
  1573. %u32_4 = OpConstant %u32 4
  1574. %subgroup = OpConstant %u32 3
  1575. %f16vec = OpTypeCooperativeVectorNV %f16 %u32_8
  1576. %f16vec4 = OpTypeCooperativeVectorNV %f16 %u32_4
  1577. %u32vec = OpTypeCooperativeVectorNV %u32 %u32_8
  1578. %s32vec = OpTypeCooperativeVectorNV %s32 %u32_8
  1579. %f16_1 = OpConstant %f16 1
  1580. %f32_1 = OpConstant %f32 1
  1581. %u32_1 = OpConstant %u32 1
  1582. %s32_1 = OpConstant %s32 1
  1583. %f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1
  1584. %f16vec_1 = OpConstantComposite %f16vec %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1 %f16_1
  1585. %u32vec_1 = OpConstantComposite %u32vec %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1 %u32_1
  1586. %s32vec_1 = OpConstantComposite %s32vec %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1 %s32_1
  1587. %u32_c1 = OpSpecConstant %u32 1
  1588. %u32_c2 = OpSpecConstant %u32 2
  1589. %f16vecc = OpTypeCooperativeVectorNV %f16 %u32_c1
  1590. %f16vecc_1 = OpConstantCompositeReplicateEXT %f16vecc %f16_1
  1591. )";
  1592. const std::string func_begin =
  1593. R"(
  1594. %main = OpFunction %void None %func
  1595. %main_entry = OpLabel)";
  1596. const std::string suffix =
  1597. R"(
  1598. OpReturn
  1599. OpFunctionEnd)";
  1600. return prefix + extra_types + func_begin + main_body + suffix;
  1601. }
  1602. TEST_F(ValidateArithmetics, CoopVecSuccess) {
  1603. const std::string body = R"(
  1604. %val1 = OpFAdd %f16vec %f16vec_1 %f16vec_1
  1605. %val2 = OpFSub %f16vec %f16vec_1 %f16vec_1
  1606. %val3 = OpFDiv %f16vec %f16vec_1 %f16vec_1
  1607. %val4 = OpFNegate %f16vec %f16vec_1
  1608. %val5 = OpIAdd %u32vec %u32vec_1 %u32vec_1
  1609. %val6 = OpISub %u32vec %u32vec_1 %u32vec_1
  1610. %val7 = OpUDiv %u32vec %u32vec_1 %u32vec_1
  1611. %val8 = OpIAdd %s32vec %s32vec_1 %s32vec_1
  1612. %val9 = OpISub %s32vec %s32vec_1 %s32vec_1
  1613. %val10 = OpSDiv %s32vec %s32vec_1 %s32vec_1
  1614. %val11 = OpSNegate %s32vec %s32vec_1
  1615. %val12 = OpVectorTimesScalar %f16vec %f16vec_1 %f16_1
  1616. %val13 = OpExtInst %f16vec %ext_inst FMin %f16vec_1 %f16vec_1
  1617. %val14 = OpExtInst %f16vec %ext_inst FMax %f16vec_1 %f16vec_1
  1618. %val15 = OpExtInst %f16vec %ext_inst FClamp %f16vec_1 %f16vec_1 %f16vec_1
  1619. %val16 = OpExtInst %f16vec %ext_inst NClamp %f16vec_1 %f16vec_1 %f16vec_1
  1620. %val17 = OpExtInst %f16vec %ext_inst Step %f16vec_1 %f16vec_1
  1621. %val18 = OpExtInst %f16vec %ext_inst Exp %f16vec_1
  1622. %val19 = OpExtInst %f16vec %ext_inst Log %f16vec_1
  1623. %val20 = OpExtInst %f16vec %ext_inst Tanh %f16vec_1
  1624. %val21 = OpExtInst %f16vec %ext_inst Atan %f16vec_1
  1625. %val22 = OpExtInst %f16vec %ext_inst Fma %f16vec_1 %f16vec_1 %f16vec_1
  1626. %val23 = OpExtInst %u32vec %ext_inst UMin %u32vec_1 %u32vec_1
  1627. %val24 = OpExtInst %u32vec %ext_inst UMax %u32vec_1 %u32vec_1
  1628. %val25 = OpExtInst %u32vec %ext_inst UClamp %u32vec_1 %u32vec_1 %u32vec_1
  1629. %val26 = OpExtInst %s32vec %ext_inst SMin %s32vec_1 %s32vec_1
  1630. %val27 = OpExtInst %s32vec %ext_inst SMax %s32vec_1 %s32vec_1
  1631. %val28 = OpExtInst %s32vec %ext_inst SClamp %s32vec_1 %s32vec_1 %s32vec_1
  1632. %val29 = OpShiftRightLogical %u32vec %u32vec_1 %u32vec_1
  1633. %val30 = OpShiftRightArithmetic %u32vec %u32vec_1 %u32vec_1
  1634. %val31 = OpShiftLeftLogical %u32vec %u32vec_1 %u32vec_1
  1635. %val32 = OpBitwiseOr %u32vec %u32vec_1 %u32vec_1
  1636. %val33 = OpBitwiseXor %u32vec %u32vec_1 %u32vec_1
  1637. %val34 = OpBitwiseAnd %u32vec %u32vec_1 %u32vec_1
  1638. %val35 = OpNot %u32vec %u32vec_1
  1639. )";
  1640. CompileSuccessfully(GenerateCoopVecCode("", body).c_str());
  1641. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1642. }
  1643. TEST_F(ValidateArithmetics, CoopVecFMulPass) {
  1644. const std::string body = R"(
  1645. %val1 = OpFMul %f16vec %f16vec_1 %f16vec_1
  1646. )";
  1647. CompileSuccessfully(GenerateCoopVecCode("", body).c_str());
  1648. ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
  1649. }
  1650. TEST_F(ValidateArithmetics, CoopVecVectorTimesScalarMismatchFail) {
  1651. const std::string body = R"(
  1652. %val1 = OpVectorTimesScalar %f16vec %f16vec_1 %f32_1
  1653. )";
  1654. CompileSuccessfully(GenerateCoopVecCode("", body).c_str());
  1655. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1656. EXPECT_THAT(
  1657. getDiagnosticString(),
  1658. HasSubstr("Expected scalar operand type to be equal to the component "
  1659. "type of the vector operand: VectorTimesScalar"));
  1660. }
  1661. TEST_F(ValidateArithmetics, CoopVecDimFail) {
  1662. const std::string body = R"(
  1663. %val1 = OpFMul %f16vec %f16vec_1 %f16vec4_1
  1664. )";
  1665. CompileSuccessfully(GenerateCoopVecCode("", body).c_str());
  1666. ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  1667. EXPECT_THAT(getDiagnosticString(),
  1668. HasSubstr("Expected number of components to be identical"));
  1669. }
  1670. TEST_F(ValidateArithmetics, CoopVecComponentTypeNotScalarNumeric) {
  1671. const std::string types = R"(
  1672. %bad = OpTypeCooperativeVectorNV %bool %u32_8
  1673. )";
  1674. CompileSuccessfully(GenerateCoopVecCode(types, "").c_str());
  1675. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1676. EXPECT_THAT(getDiagnosticString(),
  1677. HasSubstr("OpTypeCooperativeVectorNV Component Type <id> "
  1678. "'5[%bool]' is not a scalar numerical type."));
  1679. }
  1680. TEST_F(ValidateArithmetics, CoopVecDimNotConstantInt) {
  1681. const std::string types = R"(
  1682. %bad = OpTypeCooperativeVectorNV %f16 %f32_1
  1683. )";
  1684. CompileSuccessfully(GenerateCoopVecCode(types, "").c_str());
  1685. EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  1686. EXPECT_THAT(getDiagnosticString(),
  1687. HasSubstr("OpTypeCooperativeVectorNV component count <id> "
  1688. "'19[%float_1]' is not a constant integer type"));
  1689. }
  1690. } // namespace
  1691. } // namespace val
  1692. } // namespace spvtools