2
0
Browse source

Fix broken ternary on matrix return type (#4434) (#4460)

HLSL ternary operators that result in vector or matrix types need
special handling even if the condition is not a vector. This change
allows vector and matrix result types for ternary operators even if
matrix and vector conditions are not allowed.

The change works by generating an alloca before ternary blocks, and
terminating ternary blocks by writing the resulting matrices to the
alloca. The result of the ternary is then a load of the alloca'd matrix.

Fixes #4434
Chris B 3 years ago
parent
commit
81ef064a47

+ 31 - 1
tools/clang/lib/CodeGen/CGExprScalar.cpp

@@ -20,10 +20,12 @@
 #include "TargetInfo.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/HlslTypes.h"
 #include "clang/AST/RecordLayout.h"
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Frontend/CodeGenOptions.h"
+#include "dxc/DXIL/DxilUtil.h" // HLSL Change
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -3701,7 +3703,7 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
   }
   // HLSL Change Starts
   if (CGF.getLangOpts().HLSL && !CGF.getLangOpts().EnableShortCircuit) {
-    // HLSL does not short circuit by default.
+    // HLSL does not short circuit by default before HLSL 2021
     if (hlsl::IsHLSLVecType(E->getType()) || E->getType()->isArithmeticType()) {
       llvm::Value *CondV = CGF.EmitScalarExpr(condExpr);
       llvm::Value *LHS = Visit(lhsExpr);
@@ -3729,6 +3731,7 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
           CGF, E, LHS->getType(), {Cond, LHS, RHS});
     }
   }
+
   // HLSL Change Ends
 
   // If this is a really simple expression (like x ? 4 : 5), emit this as a
@@ -3749,6 +3752,17 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
     return Builder.CreateSelect(CondV, LHS, RHS, "cond");
   }
 
+  // HLSL Change Begins
+  llvm::Instruction *ResultAlloca = nullptr;
+  if (CGF.getLangOpts().HLSL && CGF.getLangOpts().EnableShortCircuit &&
+      hlsl::IsHLSLMatType(E->getType())) {
+    llvm::Type *MatTy = CGF.ConvertTypeForMem(E->getType());
+    ResultAlloca = CGF.CreateTempAlloca(MatTy);
+    ResultAlloca->moveBefore(hlsl::dxilutil::FindAllocaInsertionPt(
+        Builder.GetInsertBlock()->getParent()));
+  }
+  // HLSL Change Ends
+
   llvm::BasicBlock *LHSBlock = CGF.createBasicBlock("cond.true");
   llvm::BasicBlock *RHSBlock = CGF.createBasicBlock("cond.false");
   llvm::BasicBlock *ContBlock = CGF.createBasicBlock("cond.end");
@@ -3761,6 +3775,11 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
   CGF.incrementProfileCounter(E);
   eval.begin(CGF);
   Value *LHS = Visit(lhsExpr);
+  // HLSL Change Begin - Handle matrix ternary
+  if (ResultAlloca)
+    CGF.CGM.getHLSLRuntime().EmitHLSLMatrixStore(CGF, LHS, ResultAlloca,
+                                                 E->getType());
+  // HLSL Change End
   eval.end(CGF);
 
   LHSBlock = Builder.GetInsertBlock();
@@ -3769,11 +3788,22 @@ VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
   CGF.EmitBlock(RHSBlock);
   eval.begin(CGF);
   Value *RHS = Visit(rhsExpr);
+  // HLSL Change Begin - Handle matrix ternary
+  if (ResultAlloca)
+    CGF.CGM.getHLSLRuntime().EmitHLSLMatrixStore(CGF, RHS, ResultAlloca,
+                                                 E->getType());
+  // HLSL Change End
   eval.end(CGF);
 
   RHSBlock = Builder.GetInsertBlock();
   CGF.EmitBlock(ContBlock);
 
+  // HLSL Change Begin - Handle matrix ternary
+  if (ResultAlloca)
+    return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, ResultAlloca,
+                                                       E->getType());
+  // HLSL Change End
+
   // If the LHS or RHS is a throw expression, it will be legitimately null.
   if (!LHS)
     return RHS;

+ 22 - 0
tools/clang/test/HLSLFileCheck/hlsl/operators/ternary-return-matrix.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -T cs_6_0 -E CSMain -HV 2021 %s -fcgl | FileCheck %s
+
+float2x2 crashingFunction(bool b) {
+  float2x2 x = {0.0, 0.0, 0.0, 0.0};
+  float2x2 y = {0.0, 0.0, 0.0, 0.0};
+  return b ? x : y; // Ternary with a float2x2 (matrix) result: the case from issue #4434.
+}
+
+[numthreads(1, 1, 1)] void CSMain() {
+  if (crashingFunction(true)[0][0] > 0) // Reads element [0][0] of the returned matrix.
+    return;
+}
+
+// CHECK: define internal %class.matrix.float.2.2 @"\01?crashingFunction@@YA?AV?$matrix@M$01$01@@_N@Z"
+// CHECK: [[ALLOCA:%[0-9a-z]+]] = alloca %class.matrix.float.2.2
+// CHECK: preds = {{%[0-9a-z]+}}
+// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colStore.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*, %class.matrix.float.2.2)"(i32 1, %class.matrix.float.2.2* [[ALLOCA]], %class.matrix.float.2.2 %{{[0-9]+}})
+// CHECK: preds = {{%[0-9a-z]+}}
+// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colStore.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*, %class.matrix.float.2.2)"(i32 1, %class.matrix.float.2.2* [[ALLOCA]], %class.matrix.float.2.2 %{{[0-9]+}})
+// CHECK: preds = {{%[0-9a-z.]+}}, {{%[0-9a-z.]+}}
+// CHECK: call %class.matrix.float.2.2 @"dx.hl.matldst.colLoad.%class.matrix.float.2.2 (i32, %class.matrix.float.2.2*)"(i32 0, %class.matrix.float.2.2* [[ALLOCA]])
+