2
0
Эх сурвалжийг харах

Report error on vector if condition. (#113)

Xiang Li 8 жил өмнө
parent
commit
cc4387b7ae

+ 2 - 0
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7554,6 +7554,8 @@ def err_hlsl_vla : Error< // Patterened after err_opencl_vla
   "variable length arrays are not supported in HLSL">;
 def err_hlsl_type_empty_init : Error<
   "%0 cannot have an explicit empty initializer">;
+def err_hlsl_control_flow_cond_not_scalar : Error<
+  "%0 statement conditional expressions must evaluate to a scalar">;
 def err_hlsl_unsupportedvectortype : Error<
   "%0 is declared with type %1, but only primitive scalar values are supported">;
 def err_hlsl_unsupportedvectorsize : Error<

+ 4 - 0
tools/clang/include/clang/Sema/SemaHLSL.h

@@ -73,6 +73,10 @@ void DiagnoseAssignmentResultForHLSL(
   clang::Sema::AssignmentAction Action,
   bool *Complained);
 
+void DiagnoseControlFlowConditionForHLSL(clang::Sema *self,
+                                         clang::Expr *condExpr,
+                                         llvm::StringRef StmtName);
+
 void DiagnosePackingOffset(
   clang::Sema* self,
   clang::SourceLocation loc,

+ 13 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -8222,6 +8222,19 @@ void hlsl::DiagnoseAssignmentResultForHLSL(Sema* self,
     ->DiagnoseAssignmentResultForHLSL(ConvTy, Loc, DstType, SrcType, SrcExpr, Action, Complained);
 }
 
+void hlsl::DiagnoseControlFlowConditionForHLSL(Sema *self, Expr *condExpr, StringRef StmtName) {
+  while (ImplicitCastExpr *IC = dyn_cast<ImplicitCastExpr>(condExpr)) {
+    if (IC->getCastKind() == CastKind::CK_HLSLMatrixTruncationCast ||
+        IC->getCastKind() == CastKind::CK_HLSLVectorTruncationCast) {
+      self->Diag(condExpr->getLocStart(),
+                 diag::err_hlsl_control_flow_cond_not_scalar)
+          << StmtName;
+      return;
+    }
+    condExpr = IC->getSubExpr();
+  }
+}
+
 static bool ShaderModelsMatch(const StringRef& left, const StringRef& right)
 {
   // TODO: handle shorthand cases.

+ 21 - 3
tools/clang/lib/Sema/SemaStmt.cpp

@@ -37,6 +37,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "clang/Sema/SemaHLSL.h" // HLSL Change
 using namespace clang;
 using namespace sema;
 
@@ -523,7 +524,9 @@ Sema::ActOnIfStmt(SourceLocation IfLoc, FullExprArg CondVal, Decl *CondVar,
   Expr *ConditionExpr = CondResult.getAs<Expr>();
   if (!ConditionExpr)
     return StmtError();
-
+  // HLSL Change Begin.
+  hlsl::DiagnoseControlFlowConditionForHLSL(this, ConditionExpr, "if");
+  // HLSL Change End.
   DiagnoseUnusedExprResult(thenStmt);
 
   if (!elseStmt) {
@@ -1267,6 +1270,11 @@ Sema::ActOnWhileStmt(SourceLocation WhileLoc, FullExprArg Cond,
   Expr *ConditionExpr = CondResult.get();
   if (!ConditionExpr)
     return StmtError();
+
+  // HLSL Change Begin.
+  hlsl::DiagnoseControlFlowConditionForHLSL(this, ConditionExpr, "while");
+  // HLSL Change End.
+
   CheckBreakContinueBinding(ConditionExpr);
 
   DiagnoseUnusedExprResult(Body);
@@ -1294,6 +1302,11 @@ Sema::ActOnDoStmt(SourceLocation DoLoc, Stmt *Body,
   if (CondResult.isInvalid())
     return StmtError();
   Cond = CondResult.get();
+  // HLSL Change Begin.
+  if (Cond) {
+    hlsl::DiagnoseControlFlowConditionForHLSL(this, Cond, "do-while");
+  }
+  // HLSL Change End.
 
   DiagnoseUnusedExprResult(Body);
 
@@ -1670,7 +1683,12 @@ Sema::ActOnForStmt(SourceLocation ForLoc, SourceLocation LParenLoc,
     if (SecondResult.isInvalid())
       return StmtError();
   }
-
+  // HLSL Change Begin.
+  Expr *Cond = SecondResult.get();
+  if (Cond) {
+    hlsl::DiagnoseControlFlowConditionForHLSL(this, Cond, "for");
+  }
+  // HLSL Change End.
   Expr *Third  = third.release().getAs<Expr>();
 
   DiagnoseUnusedExprResult(First);
@@ -1680,7 +1698,7 @@ Sema::ActOnForStmt(SourceLocation ForLoc, SourceLocation LParenLoc,
   if (isa<NullStmt>(Body))
     getCurCompoundScope().setHasEmptyLoopBodies();
 
-  return new (Context) ForStmt(Context, First, SecondResult.get(), ConditionVar,
+  return new (Context) ForStmt(Context, First, Cond, ConditionVar,
                                Third, Body, ForLoc, LParenLoc, RParenLoc);
 }
 

+ 1 - 1
tools/clang/test/CodeGenHLSL/reducible.hlsl

@@ -2,7 +2,7 @@
 
 // CHECK: !"llvm.loop.unroll.disable"
 uint u;
-float main(float2 a : A, int3 b : B) : SV_Target
+float main(float a : A, int3 b : B) : SV_Target
 {
   float s = 0;
   /*

+ 1 - 1
tools/clang/test/CodeGenHLSL/share_mem2Dim.hlsl

@@ -22,7 +22,7 @@ StructuredBuffer<row_major float2x2> mats2;
 [numthreads(8,8,1)]
 void main( uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID, uint2 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex )
 {
-    if (gtid==0)
+    if (gtid.x==0)
        x = tid.x;
     dataC[tid.x%8][tid.y%8] = mats.Load(gid.x).f2x2 + mats2.Load(gtid.y);
     GroupMemoryBarrierWithGroupSync();

+ 31 - 0
tools/clang/test/CodeGenHLSL/vecCmpCond.hlsl

@@ -0,0 +1,31 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: while statement conditional expressions must evaluate to a scalar
+// CHECK: do-while statement conditional expressions must evaluate to a scalar
+// CHECK: for statement conditional expressions must evaluate to a scalar
+// CHECK: if statement conditional expressions must evaluate to a scalar
+
+float4 m;
+float4 main(float4 a:A) : SV_Target
+{
+    int x=0;
+    while (a>0) {
+      a -= 2;
+      x++;
+    }
+
+    do {
+      a -= 2;
+      x++;
+    } while (a>0);
+
+
+    for (uint i=0; a > 0; a--) {
+       x++;
+    }
+
+    if (m)
+        return x;
+    else
+        return 1;
+}

+ 12 - 0
tools/clang/test/CodeGenHLSL/vecCmpIf.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: if statement conditional expressions must evaluate to a scalar
+
+float4 m;
+float4 main() : SV_Target
+{
+    if (m)
+        return 0;
+    else
+        return 1;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -534,6 +534,7 @@ public:
   TEST_METHOD(CodeGenUpperCaseRegister1);
   TEST_METHOD(CodeGenVcmp)
   TEST_METHOD(CodeGenVec_Comp_Arg)
+  TEST_METHOD(CodeGenVecCmpCond)
   TEST_METHOD(CodeGenWave)
   TEST_METHOD(CodeGenWriteToInput)
   TEST_METHOD(CodeGenWriteToInput2)
@@ -2797,6 +2798,10 @@ TEST_F(CompilerTest, CodeGenVec_Comp_Arg){
   CodeGenTest(L"..\\CodeGenHLSL\\vec_comp_arg.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenVecCmpCond) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\vecCmpCond.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenWave) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\wave.hlsl");
 }

+ 2 - 2
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -787,9 +787,9 @@ TEST_F(ValidationTest, ReducibleFail) {
        "  br label %if.end"
       },
       {"%conv\n"
-      "  br i1 %cmp.i0, label %if.else, label %if.end",
+      "  br i1 %cmp, label %if.else, label %if.end",
        "to float\n"
-       "  br i1 %cmp.i0, label %if.then, label %if.end"
+       "  br i1 %cmp, label %if.then, label %if.end"
       },
       "Execution flow must be reducible");
 }