Ver código fonte

Choose i32 as result type for select used by binary operator. (#1080)

* Choose i32 as result type for case like ((P.x < -P.w) ? 3 : 9) | ((P.y < -P.z) ? 5 : 3).

* Use lowest precision for literal integer for binary operator.
Xiang Li 7 anos atrás
pai
commit
21c251a391

+ 102 - 0
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -5314,6 +5314,83 @@ void CGMSHLSLRuntime::EmitHLSLDiscard(CodeGenFunction &CGF) {
       TheModule);
 }
 
+static llvm::Type *MergeIntType(llvm::IntegerType *T0, llvm::IntegerType *T1) {
+  if (T0->getBitWidth() > T1->getBitWidth())
+    return T0;
+  else
+    return T1;
+}
+
+static Value *CreateExt(CGBuilderTy &Builder, Value *Src, llvm::Type *DstTy,
+                        bool bSigned) {
+  if (bSigned)
+    return Builder.CreateSExt(Src, DstTy);
+  else
+    return Builder.CreateZExt(Src, DstTy);
+}
+// For integer literal, try to get lowest precision.
+static Value *CalcHLSLLiteralToLowestPrecision(CGBuilderTy &Builder, Value *Src,
+                                               bool bSigned) {
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Src)) {
+    APInt v = CI->getValue();
+    switch (v.getActiveWords()) {
+    case 4:
+      return Builder.getInt32(v.getLimitedValue());
+    case 8:
+      return Builder.getInt64(v.getLimitedValue());
+    case 2:
+      // TODO: use low precision type when support it in dxil.
+      // return Builder.getInt16(v.getLimitedValue());
+      return Builder.getInt32(v.getLimitedValue());
+    case 1:
+      // TODO: use precision type when support it in dxil.
+      // return Builder.getInt8(v.getLimitedValue());
+      return Builder.getInt32(v.getLimitedValue());
+    default:
+      return nullptr;
+    }
+  } else if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
+    if (SI->getType()->isIntegerTy()) {
+      Value *T = SI->getTrueValue();
+      Value *F = SI->getFalseValue();
+      Value *lowT = CalcHLSLLiteralToLowestPrecision(Builder, T, bSigned);
+      Value *lowF = CalcHLSLLiteralToLowestPrecision(Builder, F, bSigned);
+      if (lowT && lowF && lowT != T && lowF != F) {
+        llvm::IntegerType *TTy = cast<llvm::IntegerType>(lowT->getType());
+        llvm::IntegerType *FTy = cast<llvm::IntegerType>(lowF->getType());
+        llvm::Type *Ty = MergeIntType(TTy, FTy);
+        if (TTy != Ty) {
+          lowT = CreateExt(Builder, lowT, Ty, bSigned);
+        }
+        if (FTy != Ty) {
+          lowF = CreateExt(Builder, lowF, Ty, bSigned);
+        }
+        Value *Cond = SI->getCondition();
+        return Builder.CreateSelect(Cond, lowT, lowF);
+      }
+    }
+  } else if (llvm::BinaryOperator *BO = dyn_cast<llvm::BinaryOperator>(Src)) {
+    Value *Src0 = BO->getOperand(0);
+    Value *Src1 = BO->getOperand(1);
+    Value *CastSrc0 = CalcHLSLLiteralToLowestPrecision(Builder, Src0, bSigned);
+    Value *CastSrc1 = CalcHLSLLiteralToLowestPrecision(Builder, Src1, bSigned);
+    if (Src0 != CastSrc0 && Src1 != CastSrc1 && CastSrc0 && CastSrc1 &&
+        CastSrc0->getType() == CastSrc1->getType()) {
+      llvm::IntegerType *Ty0 = cast<llvm::IntegerType>(CastSrc0->getType());
+      llvm::IntegerType *Ty1 = cast<llvm::IntegerType>(CastSrc0->getType());
+      llvm::Type *Ty = MergeIntType(Ty0, Ty1);
+      if (Ty0 != Ty) {
+        CastSrc0 = CreateExt(Builder, CastSrc0, Ty, bSigned);
+      }
+      if (Ty1 != Ty) {
+        CastSrc1 = CreateExt(Builder, CastSrc1, Ty, bSigned);
+      }
+      return Builder.CreateBinOp(BO->getOpcode(), CastSrc0, CastSrc1);
+    }
+  }
+  return nullptr;
+}
+
 Value *CGMSHLSLRuntime::EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src,
                                             QualType SrcType,
                                             QualType DstType) {
@@ -5423,6 +5500,31 @@ Value *CGMSHLSLRuntime::EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src,
           return Sel;
         }
       }
+    } else if (llvm::BinaryOperator *BO = dyn_cast<llvm::BinaryOperator>(I)) {
+      // For integer binary operator, do the calc on lowest precision, then cast
+      // to dstTy.
+      if (I->getType()->isIntegerTy()) {
+        bool bSigned = DstType->isSignedIntegerType();
+        Value *CastResult =
+            CalcHLSLLiteralToLowestPrecision(Builder, BO, bSigned);
+        if (!CastResult)
+          return nullptr;
+        if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(DstTy)) {
+          if (DstTy == CastResult->getType()) {
+            return CastResult;
+          } else {
+            if (bSigned)
+              return Builder.CreateSExtOrTrunc(CastResult, DstTy);
+            else
+              return Builder.CreateZExtOrTrunc(CastResult, DstTy);
+          }
+        } else {
+          if (bDstSigned)
+            return Builder.CreateSIToFP(CastResult, DstTy);
+          else
+            return Builder.CreateUIToFP(CastResult, DstTy);
+        }
+      }
     }
     // TODO: support other opcode if need.
     return nullptr;

+ 11 - 0
tools/clang/test/CodeGenHLSL/quick-test/immSel.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -T ps_6_0 -O0 -E main %s | FileCheck %s
+
+// Make sure use literal int is selected into i32 for value fit in i32.
+
+// CHECK-NOT: or i64
+// CHECK: or i32
+
+uint main(const float4 P: A ) : SV_Target
+{
+    return ((P.x < -P.w) ? 3 : 9) | ((P.y < -P.z) ? 5 : 3);
+}