Explorar el Código

Allow const evaluating vector initializer lists in the frontend (#522)

Initializer lists in HLSL are very flexible; the front end right
now doesn't handle their const evaluation to the full. This change
only turns on one specific case: intializing vectors.
Lei Zhang hace 8 años
padre
commit
c5832094a6

+ 37 - 9
tools/clang/lib/AST/ExprConstant.cpp

@@ -76,6 +76,31 @@ static const FunctionDecl *GetCallExprFunction(const CallExpr *CE) {
 
   return FDecl;
 }
+
+// Returns true if the given InitListExpr is for constructing a HLSL vector
+// with the matching number of initializers and each initializer has the
+// matching element type.
+static bool IsHLSLVecInitList(const Expr* expr) {
+  if (const auto* initExpr = dyn_cast<InitListExpr>(expr)) {
+    const QualType vecType = initExpr->getType();
+    if (!hlsl::IsHLSLVecType(vecType))
+      return false;
+
+    const uint32_t size = hlsl::GetHLSLVecSize(vecType);
+    const QualType elemType = hlsl::GetHLSLVecElementType(vecType).getCanonicalType();
+
+    if (initExpr->getNumInits() != size)
+      return false;
+
+    for (uint32_t i = 0; i < size; ++i)
+      if (initExpr->getInit(i)->getType().getCanonicalType() != elemType)
+        return false;
+
+    return true;
+  }
+
+  return false;
+}
 // HLSL Change Ends
 
 
@@ -4254,7 +4279,7 @@ public:
   bool VisitInitListExpr(const InitListExpr *E) {
     if (E->getNumInits() == 0)
       return DerivedZeroInitialization(E);
-    if (Info.getLangOpts().HLSL) return Error(E); // HLSL Change
+    if (Info.getLangOpts().HLSL && !IsHLSLVecInitList(E)) return Error(E); // HLSL Change
     if (E->getNumInits() == 1)
       return StmtVisitorTy::Visit(E->getInit(0));
     return Error(E);
@@ -4295,9 +4320,13 @@ public:
   }
 
   bool VisitCastExpr(const CastExpr *E) {
-    if (Info.getLangOpts().HLSL && E->getSubExpr()->getStmtClass() == Stmt::InitListExprClass) { // HLSL Change
-      return Error(E);
+    // HLSL Change Begins
+    if (Info.getLangOpts().HLSL) {
+      const auto* subExpr = E->getSubExpr();
+      if (subExpr->getStmtClass() == Stmt::InitListExprClass && !IsHLSLVecInitList(subExpr))
+        return Error(E);
     }
+    // HLSL Change Ends
     switch (E->getCastKind()) {
     default:
       break;
@@ -5557,7 +5586,7 @@ public:
     }
   }
   bool VisitInitListExpr(const InitListExpr *E) {
-    if (Info.getLangOpts().HLSL) return Error(E); // HLSL Change
+    if (Info.getLangOpts().HLSL && !IsHLSLVecInitList(E)) return Error(E); // HLSL Change
     return VisitConstructExpr(E);
   }
   bool VisitCXXConstructExpr(const CXXConstructExpr *E) {
@@ -5643,6 +5672,7 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr* E) {
   QualType SETy = SE->getType();
 
   switch (E->getCastKind()) {
+  case CK_HLSLVectorSplat: // HLSL Change
   case CK_VectorSplat: {
     APValue Val = APValue();
     if (SETy->isIntegerType()) {
@@ -8476,12 +8506,10 @@ static bool Evaluate(APValue &Result, EvalInfo &Info, const Expr *E) {
   // In C, function designators are not lvalues, but we evaluate them as if they
   // are.
   // HLSL Change Begins.
-  if (Info.getLangOpts().HLSL && E->getStmtClass() == Stmt::InitListExprClass) { // HLSL Change
-    if (hlsl::IsHLSLVecType(E->getType())) {
-      if (EvaluateVector(E, Result, Info))
+  if (Info.getLangOpts().HLSL) {
+    if (E->isRValue() && hlsl::IsHLSLVecType(E->getType()) && EvaluateVector(E, Result, Info))
         return true;
-    }
-    if (!E->getType()->isScalarType())
+    if (E->getStmtClass() == Stmt::InitListExprClass && !E->getType()->isScalarType())
       return false;
   }
   // HLSL Change Ends.

+ 3 - 2
tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl

@@ -6,9 +6,10 @@ void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 
     // From constant
-// CHECK: OpStore %vf4 [[v4f32c]]
+// CHECK: %vf4 = OpVariable %_ptr_Function_v4float Function [[v4f32c]]
     float4 vf4 = 1;
-// CHECK-NEXT: [[v3f32c:%\d+]] = OpCompositeConstruct %v3float %float_2 %float_2 %float_2
+
+// CHECK: [[v3f32c:%\d+]] = OpCompositeConstruct %v3float %float_2 %float_2 %float_2
 // CHECK-NEXT: OpStore %vf3 [[v3f32c]]
     float3 vf3;
     vf3 = float1(2);

+ 6 - 17
tools/clang/test/CodeGenSPIRV/var.init.hlsl

@@ -16,20 +16,11 @@
 float4 main(float component: COLOR) : SV_TARGET {
 // CHECK-LABEL: %bb_entry = OpLabel
 
-// CHECK-NEXT: %a = OpVariable %_ptr_Function_int Function %int_0
-// CHECK-NEXT: %b = OpVariable %_ptr_Function_int Function
-
-// CHECK-NEXT: %i = OpVariable %_ptr_Function_float Function %float_3
-// CHECK-NEXT: %j = OpVariable %_ptr_Function_float Function
-
-// CHECK-NEXT: %m = OpVariable %_ptr_Function_v4float Function
-// CHECK-NEXT: %n = OpVariable %_ptr_Function_v4float Function
-// CHECK-NEXT: %o = OpVariable %_ptr_Function_v4float Function
-
-// CHECK-NEXT: %p = OpVariable %_ptr_Function_v2int Function [[int2constant]]
-// CHECK-NEXT: %q = OpVariable %_ptr_Function_v3int Function
-
-// CHECK-NEXT: %x = OpVariable %_ptr_Function_uint Function
+// CHECK: %a = OpVariable %_ptr_Function_int Function %int_0
+// CHECK: %i = OpVariable %_ptr_Function_float Function %float_3
+// CHECK: %m = OpVariable %_ptr_Function_v4float Function [[float4constant]]
+// CHECK: %p = OpVariable %_ptr_Function_v2int Function [[int2constant]]
+// CHECK: %x = OpVariable %_ptr_Function_uint Function %uint_1
 
 // Initializer already attached to the var definition
     int a = 0; // From constant
@@ -43,9 +34,8 @@ float4 main(float component: COLOR) : SV_TARGET {
 // CHECK-NEXT: OpStore %j [[component0]]
     float j = component; // From stage variable
 
-// CHECK-NEXT: OpStore %m [[float4constant]]
     float4 m = float4(1.0, 2.0, 3.0, 4.0);  // All components are constants
-// CHECK-NEXT: [[j0:%\d+]] = OpLoad %float %j
+// CHECK: [[j0:%\d+]] = OpLoad %float %j
 // CHECK-NEXT: [[j1:%\d+]] = OpLoad %float %j
 // CHECK-NEXT: [[j2:%\d+]] = OpLoad %float %j
 // CHECK-NEXT: [[j3:%\d+]] = OpLoad %float %j
@@ -65,7 +55,6 @@ float4 main(float component: COLOR) : SV_TARGET {
 // CHECK-NEXT: OpStore %q [[qinit]]
     int3 q = {4, b, a}; // Mixed cases
 
-// CHECK-NEXT: OpStore %x %uint_1
     uint1 x = uint1(1); // Special case: vector of size 1
 
     return o;