Przeglądaj źródła

Fix invalid CmpEQ(%bool,0) for converting bools (#1756)

The matrix lowering code special-cased bools and converted from i32s using CmpEQ(%bool,0) instead of CmpNE(%bool,0) in a few places. This change fixes it, where possible by leveraging MemToReg to remove the special cases.
Tristan Labelle 6 lat temu
rodzic
commit
95984400db

+ 4 - 16
lib/HLSL/HLMatrixLowerPass.cpp

@@ -412,26 +412,14 @@ INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lowe
 static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
                                    IRBuilder<> Builder) {
   // Cast to bool.
-  if (toTy->getScalarType()->isIntegerTy() &&
-      toTy->getScalarType()->getIntegerBitWidth() == 1) {
+  if (toTy->getScalarType()->isIntegerTy(1)) {
     Type *fromTy = src->getType();
+    Constant *zero = llvm::Constant::getNullValue(src->getType());
     bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
-    Constant *zero;
-    if (isFloat)
-      zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
-    else
-      zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
-
-    if (toTy->getScalarType() != toTy) {
-      // Create constant vector.
-      unsigned size = toTy->getVectorNumElements();
-      std::vector<Constant *> zeros(size, zero);
-      zero = llvm::ConstantVector::get(zeros);
-    }
     if (isFloat)
-      return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
+      return cast<Instruction>(Builder.CreateFCmpONE(src, zero));
     else
-      return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
+      return cast<Instruction>(Builder.CreateICmpNE(src, zero));
   }
 
   Type *eltToTy = toTy->getScalarType();

+ 25 - 46
lib/HLSL/HLOperationLower.cpp

@@ -4970,18 +4970,13 @@ unsigned GetEltTypeByteSizeForConstBuf(Type *EltType, const DataLayout &DL) {
 Value *GenerateCBLoad(Value *handle, Value *offset, Type *EltTy, OP *hlslOP,
                       IRBuilder<> &Builder) {
   Constant *OpArg = hlslOP->GetU32Const((unsigned)OP::OpCode::CBufferLoad);
+
+  DXASSERT(!EltTy->isIntegerTy(1), "Bools should not be loaded as their register representation.");
+
   // Align to 8 bytes for now.
   Constant *align = hlslOP->GetU32Const(8);
-  Type *i1Ty = Type::getInt1Ty(EltTy->getContext());
-  if (EltTy != i1Ty) {
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoad, EltTy);
-    return Builder.CreateCall(CBLoad, {OpArg, handle, offset, align});
-  } else {
-    Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoad, i32Ty);
-    Value *Result = Builder.CreateCall(CBLoad, {OpArg, handle, offset, align});
-    return Builder.CreateICmpEQ(Result, hlslOP->GetU32Const(0));
-  }
+  Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoad, EltTy);
+  return Builder.CreateCall(CBLoad, {OpArg, handle, offset, align});
 }
 
 Value *TranslateConstBufMatLd(Type *matType, Value *handle, Value *offset,
@@ -5313,22 +5308,18 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
                             IRBuilder<> &Builder) {
   Constant *OpArg = hlslOP->GetU32Const((unsigned)OP::OpCode::CBufferLoadLegacy);
 
-  Type *i1Ty = Type::getInt1Ty(EltTy->getContext());
+  DXASSERT(!EltTy->isIntegerTy(1), "Bools should not be loaded as their register representation.");
+
   Type *doubleTy = Type::getDoubleTy(EltTy->getContext());
   Type *halfTy = Type::getHalfTy(EltTy->getContext());
   Type *i64Ty = Type::getInt64Ty(EltTy->getContext());
   Type *i16Ty = Type::getInt16Ty(EltTy->getContext());
-  bool isBool = EltTy == i1Ty;
+
   bool is64 = (EltTy == doubleTy) | (EltTy == i64Ty);
   bool is16 = (EltTy == halfTy || EltTy == i16Ty) && !hlslOP->UseMinPrecision();
-  bool isNormal = !isBool && !is64;
   DXASSERT_LOCALVAR(is16, (is16 && channelOffset < 8) || channelOffset < 4,
            "legacy cbuffer don't across 16 bytes register.");
-  if (isNormal) {
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
-    Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
-    return Builder.CreateExtractValue(loadLegacy, channelOffset);
-  } else if (is64) {
+  if (is64) {
     Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
     Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
     DXASSERT((channelOffset&1)==0,"channel offset must be even for double");
@@ -5336,12 +5327,9 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
     Value *Result = Builder.CreateExtractValue(loadLegacy, eltIdx);
     return Result;
   } else {
-    DXASSERT(isBool, "bool should be i1");
-    Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, i32Ty);
-    Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
-    Value *Result = Builder.CreateExtractValue(loadLegacy, channelOffset);
-    return Builder.CreateICmpEQ(Result, hlslOP->GetU32Const(0));
+    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
+    Value *loadLegacy = Builder.CreateCall(CBLoad, { OpArg, handle, legacyIdx });
+    return Builder.CreateExtractValue(loadLegacy, channelOffset);
   }
 }
 
@@ -5351,29 +5339,19 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
                             IRBuilder<> &Builder) {
   Constant *OpArg = hlslOP->GetU32Const((unsigned)OP::OpCode::CBufferLoadLegacy);
 
-  Type *i1Ty = Type::getInt1Ty(EltTy->getContext());
+  DXASSERT(!EltTy->isIntegerTy(1), "Bools should not be loaded as their register representation.");
+
   Type *doubleTy = Type::getDoubleTy(EltTy->getContext());
   Type *i64Ty = Type::getInt64Ty(EltTy->getContext());
   Type *halfTy = Type::getHalfTy(EltTy->getContext());
   Type *shortTy = Type::getInt16Ty(EltTy->getContext());
 
-  bool isBool = EltTy == i1Ty;
   bool is64 = (EltTy == doubleTy) | (EltTy == i64Ty);
   bool is16 = (EltTy == shortTy || EltTy == halfTy) && !hlslOP->UseMinPrecision();
-  bool isNormal = !isBool && !is64 && !is16;
   DXASSERT((is16 && channelOffset + vecSize <= 8) ||
                (channelOffset + vecSize) <= 4,
            "legacy cbuffer don't across 16 bytes register.");
-  if (isNormal) {
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
-    Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
-    Value *Result = UndefValue::get(VectorType::get(EltTy, vecSize));
-    for (unsigned i = 0; i < vecSize; ++i) {
-      Value *NewElt = Builder.CreateExtractValue(loadLegacy, channelOffset+i);
-      Result = Builder.CreateInsertElement(Result, NewElt, i);
-    }
-    return Result;
-  } else if (is16) {
+  if (is16) {
     Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
     Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
     Value *Result = UndefValue::get(VectorType::get(EltTy, vecSize));
@@ -5405,16 +5383,14 @@ Value *GenerateCBLoadLegacy(Value *handle, Value *legacyIdx,
     }
     return Result;
   } else {
-    DXASSERT(isBool, "bool should be i1");
-    Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
-    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, i32Ty);
-    Value *loadLegacy = Builder.CreateCall(CBLoad, {OpArg, handle, legacyIdx});
-    Value *Result = UndefValue::get(VectorType::get(i32Ty, vecSize));
+    Function *CBLoad = hlslOP->GetOpFunc(OP::OpCode::CBufferLoadLegacy, EltTy);
+    Value *loadLegacy = Builder.CreateCall(CBLoad, { OpArg, handle, legacyIdx });
+    Value *Result = UndefValue::get(VectorType::get(EltTy, vecSize));
     for (unsigned i = 0; i < vecSize; ++i) {
-      Value *NewElt = Builder.CreateExtractValue(loadLegacy, channelOffset+i);
+      Value *NewElt = Builder.CreateExtractValue(loadLegacy, channelOffset + i);
       Result = Builder.CreateInsertElement(Result, NewElt, i);
     }
-    return Builder.CreateICmpEQ(Result, ConstantAggregateZero::get(Result->getType()));
+    return Result;
   }
 }
 
@@ -5424,7 +5400,7 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
                                     IRBuilder<> &Builder) {
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(matType, col, row);
-  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/memElemRepr)->getVectorElementType();
+  Type *EltTy = HLMatrixLower::LowerMatrixType(matType, /*forMem*/true)->getVectorElementType();
 
   unsigned matSize = col * row;
   std::vector<Value *> elts(matSize);
@@ -5458,7 +5434,10 @@ Value *TranslateConstBufMatLdLegacy(Type *matType, Value *handle,
     }
   }
 
-  return HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
+  Value *Vec = HLMatrixLower::BuildVector(EltTy, col * row, elts, Builder);
+  if (!memElemRepr)
+    Vec = HLMatrixLower::VecMatrixMemToReg(Vec, matType, Builder);
+  return Vec;
 }
 
 void TranslateCBGepLegacy(GetElementPtrInst *GEP, Value *handle,

+ 10 - 0
tools/clang/test/CodeGenHLSL/quick-test/bool_matrix_cbuflegacy_conversion.hlsl

@@ -0,0 +1,10 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+
+struct Struct { bool2x2 mat; };
+ConstantBuffer<Struct> cb;
+bool2x2 main() : B { return cb.mat; }

+ 30 - 0
tools/clang/test/CodeGenHLSL/quick-test/bool_matrix_conversion.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: icmp ne i32 {{.*}}, 0
+// CHECK: fcmp fast one float {{.*}}, 0.000000e+00
+// CHECK: fcmp fast one float {{.*}}, 0.000000e+00
+// CHECK: fcmp fast one float {{.*}}, 0.000000e+00
+// CHECK: fcmp fast one float {{.*}}, 0.000000e+00
+
+struct Input
+{
+    int2x2 i : I;
+    float2x2 f : F;
+};
+
+struct Output
+{
+    bool2x2 i : I;
+    bool2x2 f : F;
+};
+
+Output main(Input input)
+{
+    Output output;
+    output.i = (bool2x2)input.i;
+    output.f = (bool2x2)input.f;
+    return output;
+}