소스 검색

Constant evaluation for HLSLVectorElementExpr (#840)

Ehsan 7 년 전
부모
커밋
84abf7d86c

+ 65 - 0
tools/clang/lib/AST/ExprConstant.cpp

@@ -5645,6 +5645,7 @@ namespace {
       { return Visit(E->getSubExpr()); }
     bool VisitCastExpr(const CastExpr* E);
     bool VisitInitListExpr(const InitListExpr *E);
+    bool VisitHLSLVectorElementExpr(const HLSLVectorElementExpr *E);
     bool VisitUnaryImag(const UnaryOperator *E);
     // FIXME: Missing: unary -, unary ~, binary add/sub/mul/div,
     //                 binary comparisons, binary and/or/xor,
@@ -5657,6 +5658,26 @@ static bool EvaluateVector(const Expr* E, APValue& Result, EvalInfo &Info) {
   return VectorExprEvaluator(Info, Result).Visit(E);
 }
 
+// Constant-evaluates an HLSL vector element access (swizzle) whose base is
+// itself a constant vector, for example: float4 a = (0.0).xxxx;
+//
+// The base expression is evaluated first; each position encoded in the
+// element access then selects one element of the base vector, and the
+// selected elements form the resulting vector value.
+//
+// Returns the result of Success(...) when the base evaluates, without side
+// effects, to an APValue of kind Vector. Returns false otherwise.
+bool VectorExprEvaluator::VisitHLSLVectorElementExpr(
+    const HLSLVectorElementExpr *E) {
+  // Evaluate the base into a local value rather than into Result, so that
+  // Result is only written once the whole access has been folded and is left
+  // untouched when this function returns false.
+  APValue Base;
+  if (!Evaluate(Base, Info, E->getBase()) || Info.EvalStatus.HasSideEffects ||
+      Base.getKind() != APValue::ValueKind::Vector) {
+    // TODO: Other cases may be added for other APValue::ValueKind.
+    return false;
+  }
+
+  hlsl::VectorMemberAccessPositions accessor = E->getEncodedElementAccess();
+  SmallVector<APValue, 4> Elts;
+  Elts.reserve(accessor.Count);
+  for (uint32_t i = 0; i < accessor.Count; ++i) {
+    // Initialized so the index is well-defined even if GetPosition does not
+    // write to it.
+    uint32_t selector = 0;
+    accessor.GetPosition(i, &selector);
+    Elts.push_back(Base.getVectorElt(selector));
+  }
+  return Success(Elts, E);
+}
+
 bool VectorExprEvaluator::VisitCastExpr(const CastExpr* E) {
   // HLSL Change Begins.
   const VectorType *VTy;
@@ -5672,6 +5693,50 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr* E) {
   QualType SETy = SE->getType();
 
   switch (E->getCastKind()) {
+  // HLSL Change Begins.
+  case CK_HLSLCC_FloatingCast: {
+    if (!Visit(SE))
+      return Error(E);
+    SmallVector<APValue, 4> Elts;
+    for (uint32_t i = 0; i < Result.getVectorLength(); ++i) {
+      APValue Elem = Result.getVectorElt(i);
+      HandleFloatToFloatCast(
+          Info, E, hlsl::GetHLSLVecElementType(SE->getType()),
+          hlsl::GetHLSLVecElementType(E->getType()), Elem.getFloat());
+      Elts.push_back(Elem);
+    }
+    const auto EltsSize = Elts.size();
+    return Success(Elts, E);
+  }
+  case CK_HLSLCC_IntegralToFloating: {
+    if (!Visit(SE))
+      return Error(E);
+    SmallVector<APValue, 4> Elts;
+    for (uint32_t i = 0; i < Result.getVectorLength(); ++i) {
+      APFloat ElemFloat(0.0);
+      HandleIntToFloatCast(Info, E, hlsl::GetHLSLVecElementType(SE->getType()),
+                           Result.getVectorElt(i).getInt(),
+                           hlsl::GetHLSLVecElementType(E->getType()),
+                           ElemFloat);
+      Elts.push_back(APValue(ElemFloat));
+    }
+    const auto EltsSize = Elts.size();
+    return Success(Elts, E);
+  }
+  case CK_HLSLCC_FloatingToIntegral: {
+    if (!Visit(SE))
+      return Error(E);
+    SmallVector<APValue, 4> Elts;
+    for (uint32_t i = 0; i < Result.getVectorLength(); ++i) {
+      APSInt ElemInt;
+      HandleFloatToIntCast(Info, E, hlsl::GetHLSLVecElementType(SE->getType()),
+                           Result.getVectorElt(i).getFloat(),
+                           hlsl::GetHLSLVecElementType(E->getType()), ElemInt);
+      Elts.push_back(APValue(ElemInt));
+    }
+    return Success(Elts, E);
+  }
+  // HLSL Change Ends.
   case CK_HLSLVectorSplat: // HLSL Change
   case CK_VectorSplat: {
     APValue Val = APValue();

+ 6 - 0
tools/clang/lib/CodeGen/CGExprConstant.cpp

@@ -743,6 +743,12 @@ public:
     case CK_ZeroToOCLEvent:
       return nullptr;
     // HLSL Change Begins.
+    case CK_HLSLCC_FloatingCast:
+    case CK_HLSLCC_IntegralToFloating:
+    case CK_HLSLCC_FloatingToIntegral:
+      // Since these cast kinds have already been handled in ExprConstant.cpp,
+      // we can reuse the logic there.
+      return CGM.EmitConstantExpr(E, E->getType(), CGF);
     case CK_FlatConversion:
       return nullptr;
     case CK_HLSLVectorSplat: {

+ 16 - 0
tools/clang/test/CodeGenHLSL/vec_elem_const_eval.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -E main -T ps_6_0 -fcgl %s | FileCheck %s
+
+// Note: since we want to check that HLSLVectorElementExpr nodes are constant-evaluated,
+// we require -fcgl here to turn off further constant folding and other optimizations.
+
+// CHECK: constant <4 x float> <float 1.500000e+00, float 1.500000e+00, float 1.500000e+00, float 1.500000e+00>
+// CHECK: constant <4 x float> <float 2.000000e+00, float 2.000000e+00, float 2.000000e+00, float 2.000000e+00>
+// CHECK: constant <4 x i32> <i32 3, i32 3, i32 3, i32 3>
+
+static const float4 a = (1.5).xxxx; // float literal swizzled to 4 elements; FileCheck expects <4 x float> of 1.5
+static const float4 b = (2).xxxx; // integer literal swizzled into a float4; FileCheck expects <4 x float> of 2.0
+static const int4   c = (3.5).xxxx; // float literal swizzled into an int4; FileCheck expects <4 x i32> of 3
+
+float4 main() : SV_Target {
+  return a + b + c; // references all three constants declared above
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -1123,6 +1123,7 @@ public:
   TEST_METHOD(DxilGen_StoreOutput)
   TEST_METHOD(ConstantFolding)
   TEST_METHOD(HoistConstantArray)
+  TEST_METHOD(VecElemConstEval)
   TEST_METHOD(ViewID)
   TEST_METHOD(ShaderCompatSuite)
   TEST_METHOD(QuickTest)
@@ -5727,6 +5728,10 @@ TEST_F(CompilerTest, HoistConstantArray) {
   CodeGenTestCheck(L"hca\\14.hlsl");
 }
 
+TEST_F(CompilerTest, VecElemConstEval) { // Runs the vec_elem_const_eval.hlsl codegen check test.
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\vec_elem_const_eval.hlsl"); // presumably resolved relative to the helper's base test directory — TODO confirm
+}
+
 TEST_F(CompilerTest, PreprocessWhenValidThenOK) {
   CComPtr<IDxcCompiler> pCompiler;
   CComPtr<IDxcOperationResult> pResult;