Browse Source

Support constant expression truncation (#2870)

This adds implementations for HLSL vector and matrix truncation. Using
Constant classes to construct a constant vector or structure.

Incidental changes to vector and struct splatting to take advantage of
an existing operation for that purpose.

Adds a test for truncations and splats in constant expressions.

Fixes #2832
Greg Roth 5 years ago
parent
commit
b4298af6f5

+ 46 - 4
tools/clang/lib/CodeGen/CGExprConstant.cpp

@@ -753,8 +753,7 @@ public:
       return nullptr;
     case CK_HLSLVectorSplat: {
       unsigned vecSize = hlsl::GetHLSLVecSize(E->getType());
-      std::vector<llvm::Constant*> Elts(vecSize, C);
-      return llvm::ConstantVector::get(Elts);
+      return llvm::ConstantVector::getSplat(vecSize, C);
     }
     case CK_HLSLMatrixSplat: {
       llvm::StructType *ST =
@@ -762,13 +761,56 @@ public:
       unsigned row,col;
       hlsl::GetHLSLMatRowColCount(E->getType(), row, col);
 
-      std::vector<llvm::Constant *> Cols(col, C);
-      llvm::Constant *Row = llvm::ConstantVector::get(Cols);
+      llvm::Constant *Row = llvm::ConstantVector::getSplat(col, C);
       std::vector<llvm::Constant *> Rows(row, Row);
       llvm::Constant *Mat = llvm::ConstantArray::get(
           cast<llvm::ArrayType>(ST->getElementType(0)), Rows);
       return llvm::ConstantStruct::get(ST, Mat);
     }
+    case CK_HLSLVectorTruncationCast: {
+      unsigned vecSize = hlsl::GetHLSLVecSize(E->getType());
+      SmallVector<llvm::Constant*, 4> Elts(vecSize);
+      if (llvm::ConstantDataVector *CDV = dyn_cast<llvm::ConstantDataVector>(C)) {
+        for (unsigned i = 0; i < vecSize; i++)
+          Elts[i] = CDV->getElementAsConstant(i);
+      } else {
+        llvm::ConstantVector *CV = dyn_cast<llvm::ConstantVector>(C);
+        for (unsigned i = 0; i < vecSize; i++)
+          Elts[i] = CV->getOperand(i);
+      }
+      return llvm::ConstantVector::get(Elts);
+    }
+    case CK_HLSLVectorToScalarCast: {
+      if (llvm::ConstantDataVector *CDV = cast<llvm::ConstantDataVector>(C))
+        return CDV->getElementAsConstant(0);
+      llvm::ConstantVector *CV = cast<llvm::ConstantVector>(C);
+      return CV->getOperand(0);
+    }
+    case CK_HLSLMatrixTruncationCast: {
+      llvm::StructType *ST =
+          cast<llvm::StructType>(CGM.getTypes().ConvertType(E->getType()));
+      unsigned rowCt,colCt;
+      hlsl::GetHLSLMatRowColCount(E->getType(), rowCt, colCt);
+      if (llvm::ConstantStruct *CS = dyn_cast<llvm::ConstantStruct>(C)) {
+        llvm::ConstantArray *CA = dyn_cast<llvm::ConstantArray>(CS->getOperand(0));
+        SmallVector<llvm::Constant *, 4> Rows(rowCt);
+        for (unsigned i = 0; i < rowCt; i++) {
+          SmallVector<llvm::Constant*, 4> Elts(colCt);
+          if (llvm::ConstantDataVector *CDV = dyn_cast<llvm::ConstantDataVector>(CA->getOperand(i))) {
+            for (unsigned j = 0; j < colCt; j++)
+              Elts[j] = CDV->getElementAsConstant(j);
+          } else {
+            llvm::ConstantVector *CV = cast<llvm::ConstantVector>(CA->getOperand(i));
+            for (unsigned j = 0; j < colCt; j++)
+              Elts[j] = CV->getOperand(j);
+          }
+          Rows[i] = llvm::ConstantVector::get(Elts);
+        }
+        llvm::Constant *Mat = llvm::ConstantArray::get(
+            cast<llvm::ArrayType>(ST->getElementType(0)), Rows);
+        return llvm::ConstantStruct::get(ST, Mat);
+      }
+    }
     // HLSL Change Ends.
     }
     llvm_unreachable("Invalid CastKind");

+ 19 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/vector/constVecTrunc.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T ps_6_0 %s  | FileCheck %s
+
+// Test vector/matrix truncation and splats in constant expressions
+// If they remain constant, it should simplify down to just storeOutputs
+
+// CHECK: define void @main
+// CHECK-NEXT: call void @dx.op.storeOutput
+// CHECK-NEXT: call void @dx.op.storeOutput
+// CHECK-NEXT: call void @dx.op.storeOutput
+// CHECK-NEXT: call void @dx.op.storeOutput
+// CHECK: ret void
+float4 main() : SV_Target
+{
+  const float val = float4(0.1F, 1, 0, 1);
+  const float val2 = 0.2F;
+  const float2x2 mat = float3x2(1.1,1.2,2.1,2.2,3.1,3.2);
+  const float2x3 mat2 = 5.0;
+  return float4(val, val2, mat[0][0], mat2[1][1]);
+}