Przeglądaj źródła

Fix bool casting logic when using initializer syntax (#1577)

Vishal Sharma 7 lat temu
rodzic
commit
8948fa423c

+ 28 - 3
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -5163,8 +5163,29 @@ void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Val
   }  
 }
 
+static bool IsBooleanType(llvm::Type *ty) { // true only when ty is an integer type exactly 1 bit wide
+  return (ty->isIntegerTy() && ty->getIntegerBitWidth() == 1);
+}
+
+static Value *CreateCastforBoolDestType(CGBuilderTy &Builder, Value *srcVal) { // builds a "srcVal != 0" compare and returns it
+  llvm::Type *srcTy = srcVal->getType(); // the source type selects FCmp vs. ICmp below
+  if (srcTy->isFloatingPointTy()) {
+    return Builder.CreateFCmp(FCmpInst::FCMP_UNE, srcVal, // compare against a floating-point zero of the source type
+                              ConstantFP::get(srcTy, 0));
+  } else {
+    // Not floating point: the source is expected to be an integer wider than one bit (asserted below).
+    DXASSERT(srcTy->isIntegerTy() && srcTy->getIntegerBitWidth() > 1,
+             "must be a non-boolean integer type.");
+    return Builder.CreateICmp(ICmpInst::ICMP_NE, srcVal, // compare against an integer zero of the source type
+                              ConstantInt::get(srcTy, 0));
+  }
+}
+
 // Cast elements in initlist if not match the target type.
 // idx is current element index in initlist, Ty is target type.
+
+// TODO: Stop handling missing casts here. Handle the casting of non-scalar values
+// to their destination type in init list expressions at the AST level.
 static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVector<QualType, 4> &eltTys, unsigned &idx, QualType Ty, CodeGenFunction &CGF) {
   if (Ty->isArrayType()) {
     const clang::ArrayType *AT = Ty->getAsArrayTypeUnsafe();
@@ -5218,10 +5239,14 @@ static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVect
     llvm::Type *srcTy = val->getType();
     llvm::Type *dstTy = CGF.ConvertType(Ty);
     if (srcTy != dstTy) {
-      Instruction::CastOps castOp =
+      if (IsBooleanType(dstTy)) {
+        elts[idx] = CreateCastforBoolDestType(CGF.Builder, val);
+      } else {
+        Instruction::CastOps castOp =
           static_cast<Instruction::CastOps>(HLModule::FindCastOp(
-              IsUnsigned(eltTys[idx]), IsUnsigned(Ty), srcTy, dstTy));
-      elts[idx] = CGF.Builder.CreateCast(castOp, val, dstTy);
+            IsUnsigned(eltTys[idx]), IsUnsigned(Ty), srcTy, dstTy));
+        elts[idx] = CGF.Builder.CreateCast(castOp, val, dstTy);
+      }
     }
     idx++;
   }

+ 88 - 0
tools/clang/test/CodeGenHLSL/quick-test/bool_cast_initializer_syntax_test.hlsl

@@ -0,0 +1,88 @@
+// RUN: %dxc /T ps_6_0 /E main %s | FileCheck %s
+
+// Ensure that casts to bool never use the fptoui or fptosi instructions.
+// CHECK-NOT: fptoui
+// CHECK-NOT: fptosi
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+// CHECK: icmp ne
+// CHECK: icmp ne
+// CHECK: icmp ne
+// CHECK: icmp ne
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+  
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+// CHECK: icmp ne
+// CHECK: icmp ne
+// CHECK: fcmp fast une
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+// CHECK: icmp ne
+// CHECK: fcmp fast une
+
+// CHECK: fcmp fast une
+// CHECK: fcmp fast une
+
+bool4 main (float f1 : F1, // float inputs, scalar through 4-component vector
+            float2 f2 : F2,
+            float3 f3 : F3,
+            float4 f4 : F4,
+
+            int i1 : I1, // int inputs, scalar through 4-component vector
+            int2 i2 : I2,
+            int3 i3 : I3,
+            int4 i4 : I4,
+
+            min10float mtf1 : M1, // min10float inputs
+            min10float2 mtf2 : M2,
+            min10float3 mtf3 : M3,
+            min10float4 mtf4 : M4,
+
+            min16float msf1 : M5, // min16float inputs
+            min16float2 msf2 : M6,
+            min16float3 msf3 : M7,
+            min16float4 msf4 : M8,
+
+            half h1 : H1, // half inputs
+            half2 h2 : H2,
+            half3 h3 : H3,
+            half4 h4 : H4) : SV_Target
+{ // each operand below constructs a bool4 with initializer syntax; the results are AND-ed together
+    return
+    bool4(f4) && // whole float4 vector
+    bool4(i4) && // whole int4 vector
+    bool4(mtf4) && // whole min10float4 vector
+    bool4(msf4) && // whole min16float4 vector
+    bool4(h4) &&                // whole half4 vector
+    bool4(bool3(f3), f1) && // nested bool3 constructor plus a float scalar
+    bool4(bool2(i2), f1, h1) && // nested bool2 from ints plus float and half scalars
+    bool4(f3, f1) && // float3 plus float scalar, no nested constructor
+    bool4(bool2(i2), bool2(msf2)) && // two nested bool2 constructors (int2 and min16float2)
+    bool4(i1, f1, mtf1, h1) && // four scalars of different source types
+    bool4(i4) && // same operand as the second one above
+    bool4(mtf2, msf2); // two 2-component vectors of different float types
+}