Explorar o código

[spirv] Fix the return type of OpSelect. (#947)

Ehsan %!s(int64=7) %!d(string=hai) anos
pai
achega
6887b565f4

+ 13 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2070,10 +2070,22 @@ SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   // According to HLSL doc, all sides of the ?: expression are always
   // According to HLSL doc, all sides of the ?: expression are always
   // evaluated.
   // evaluated.
   const uint32_t type = typeTranslator.translateType(expr->getType());
   const uint32_t type = typeTranslator.translateType(expr->getType());
-  const uint32_t condition = doExpr(expr->getCond());
+  uint32_t condition = doExpr(expr->getCond());
   const uint32_t trueBranch = doExpr(expr->getTrueExpr());
   const uint32_t trueBranch = doExpr(expr->getTrueExpr());
   const uint32_t falseBranch = doExpr(expr->getFalseExpr());
   const uint32_t falseBranch = doExpr(expr->getFalseExpr());
 
 
+  // The SPIR-V OpSelect instruction must have a selection argument that is the
+  // same size as the return type. If the return type is a vector, the selection
+  // must be a vector of booleans (one per output component).
+  uint32_t count = 0;
+  if (TypeTranslator::isVectorType(expr->getType(), nullptr, &count) &&
+      !TypeTranslator::isVectorType(expr->getCond()->getType())) {
+    const uint32_t condVecType =
+        theBuilder.getVecType(theBuilder.getBoolType(), count);
+    const llvm::SmallVector<uint32_t, 4> components(size_t(count), condition);
+    condition = theBuilder.createCompositeConstruct(condVecType, components);
+  }
+
   auto valueId =
   auto valueId =
       theBuilder.createSelect(type, condition, trueBranch, falseBranch);
       theBuilder.createSelect(type, condition, trueBranch, falseBranch);
   return SpirvEvalInfo(valueId).setRValue();
   return SpirvEvalInfo(valueId).setRValue();

+ 38 - 0
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -1,5 +1,7 @@
 // Run: %dxc -T ps_6_0 -E main
 // Run: %dxc -T ps_6_0 -E main
 
 
+// CHECK: [[v3i0:%\d+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
+
 void main() {
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 // CHECK-LABEL: %bb_entry = OpLabel
 
 
@@ -31,4 +33,40 @@ void main() {
 // CHECK-NEXT: [[s2:%\d+]] = OpSelect %v3float [[b3]] [[x0]] [[y0]]
 // CHECK-NEXT: [[s2:%\d+]] = OpSelect %v3float [[b3]] [[x0]] [[y0]]
 // CHECK-NEXT: OpStore %z [[s2]]
 // CHECK-NEXT: OpStore %z [[s2]]
     z = b3 ? x : y;
     z = b3 ? x : y;
+
+    // Try condition with various type.
+    // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
+    int3 u, v, w;
+    bool  cond;
+    bool3 cond3;
+    float floatCond;
+    int3 int3Cond;
+ 
+// CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
+// CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+// CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+// CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[cond3]] [[u]] [[v]]
+    w = cond3 ? u : v;
+
+// CHECK:       [[cond:%\d+]] = OpLoad %bool %cond
+// CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+// CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+// CHECK-NEXT: [[splat:%\d+]] = OpCompositeConstruct %v3bool [[cond]] [[cond]] [[cond]]
+// CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[splat]] [[u]] [[v]]
+    w = cond  ? u : v;
+
+// CHECK:      [[floatCond:%\d+]] = OpLoad %float %floatCond
+// CHECK-NEXT:  [[boolCond:%\d+]] = OpFOrdNotEqual %bool [[floatCond]] %float_0
+// CHECK-NEXT: [[bool3Cond:%\d+]] = OpCompositeConstruct %v3bool [[boolCond]] [[boolCond]] [[boolCond]]
+// CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+// CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+// CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+    w = floatCond ? u : v;
+
+// CHECK:       [[int3Cond:%\d+]] = OpLoad %v3int %int3Cond
+// CHECK-NEXT: [[bool3Cond:%\d+]] = OpINotEqual %v3bool [[int3Cond]] [[v3i0]]
+// CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+// CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+// CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+    w = int3Cond ? u : v;
 }
 }