Selaa lähdekoodia

Support matrix cast then subscript. (#2335)

* Support matrix cast then subscript.
Xiang Li 6 vuotta sitten
vanhempi
commit
8c9bcb99b6

+ 42 - 4
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -6303,6 +6303,33 @@ Value *CGMSHLSLRuntime::EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src,
   }
 }
 
+// For case like ((float3xfloat3)mat4x4).m21 or ((float3xfloat3)mat4x4)[1], just
+// treat it like mat4x4.m21 or mat4x4[1].
+static Value *GetOriginMatrixOperandAndUpdateMatSize(Value *Ptr, unsigned &row,
+                                                     unsigned &col) {
+  if (CallInst *Mat = dyn_cast<CallInst>(Ptr)) {
+    HLOpcodeGroup OpcodeGroup =
+        GetHLOpcodeGroupByName(Mat->getCalledFunction());
+    if (OpcodeGroup == HLOpcodeGroup::HLCast) {
+      HLCastOpcode castOpcode = static_cast<HLCastOpcode>(GetHLOpcode(Mat));
+      if (castOpcode == HLCastOpcode::DefaultCast) {
+        Ptr = Mat->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
+        // Remove the cast which is useless now.
+        Mat->eraseFromParent();
+        // Update row and col.
+        HLMatrixType matTy =
+            HLMatrixType::cast(Ptr->getType()->getPointerElementType());
+        row = matTy.getNumRows();
+        col = matTy.getNumColumns();
+        // Don't update RetTy and DxilGeneration pass will do the right thing.
+        return Ptr;
+      }
+    }
+  }
+  return nullptr;
+}
+
+
 Value *CGMSHLSLRuntime::EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
                                                 llvm::Type *RetType,
                                                 llvm::Value *Ptr,
@@ -6321,11 +6348,16 @@ Value *CGMSHLSLRuntime::EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
       llvm::PointerType::get(RetType->getPointerElementType(),
                              matBase->getType()->getPointerAddressSpace());
 
+  unsigned row, col;
+  hlsl::GetHLSLMatRowColCount(Ty, row, col);
+  if (Value *OriginPtr = GetOriginMatrixOperandAndUpdateMatSize(Ptr, row, col)) {
+    Ptr = OriginPtr;
+  }
+
   // Lower mat[Idx] into real idx.
   SmallVector<Value *, 8> args;
   args.emplace_back(Ptr);
-  unsigned row, col;
-  hlsl::GetHLSLMatRowColCount(Ty, row, col);
+
   if (isRowMajor) {
     Value *cCol = ConstantInt::get(Idx->getType(), col);
     Value *Base = CGF.Builder.CreateMul(cCol, Idx);
@@ -6375,6 +6407,14 @@ Value *CGMSHLSLRuntime::EmitHLSLMatrixElement(CodeGenFunction &CGF,
   // -1 to avoid opcode param which is added in EmitHLSLMatrixOperationCallImp.
   Value *args[] = {paramList[HLOperandIndex::kMatSubscriptMatOpIdx - 1],
                    paramList[HLOperandIndex::kMatSubscriptSubOpIdx - 1]};
+
+  unsigned row, col;
+  hlsl::GetHLSLMatRowColCount(Ty, row, col);
+  Value *Ptr = paramList[0];
+  if (Value *OriginPtr = GetOriginMatrixOperandAndUpdateMatSize(Ptr, row, col)) {
+    args[0] = OriginPtr;
+  }
+
   // For all zero idx. Still all zero idx.
   if (ConstantAggregateZero *zeros = dyn_cast<ConstantAggregateZero>(idx)) {
     Constant *zero = zeros->getAggregateElement((unsigned)0);
@@ -6383,8 +6423,6 @@ Value *CGMSHLSLRuntime::EmitHLSLMatrixElement(CodeGenFunction &CGF,
   } else {
     ConstantDataSequential *elts = cast<ConstantDataSequential>(idx);
     unsigned count = elts->getNumElements();
-    unsigned row, col;
-    hlsl::GetHLSLMatRowColCount(Ty, row, col);
     std::vector<Constant *> idxs(count >> 1);
     for (unsigned i = 0; i < count; i += 2) {
       unsigned rowIdx = elts->getElementAsInteger(i);

+ 62 - 0
tools/clang/test/CodeGenHLSL/batch/expressions/operators/matrices/mat_cast_sub_write.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc /T ps_6_0 /E main %s | FileCheck %s
+
+// Make sure cast then subscript works.
+
+float4x4 a;
+
+struct B3 {
+  uint3 ui;
+};
+
+struct B {
+   uint4  ui;
+};
+
+struct M {
+  struct B base;
+  float4x3 a;
+  float4 b;
+  float4 c;
+  float  d[6];
+};
+
+struct B3 b;
+
+RWStructuredBuffer<M> buf;
+
+float3 main(uint i:I, float3 x:X) :SV_Target {
+  // Make sure match no cast version.
+
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 64, float 3.000000e+00, float undef, float undef, float undef, i8 1)
+
+  ((float)buf[i].b) = 3;
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 80, float 1.000000e+01, float 9.000000e+00, float 8.000000e+00, float undef, i8 7)
+  ((float3)buf[i].c) = float3(10,9,8);
+
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 96, float 5.000000e+00, float undef, float undef, float undef, i8 1)
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 100, float 6.000000e+00, float undef, float undef, float undef, i8 1)
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 104, float 7.000000e+00, float undef, float undef, float undef, i8 1)
+  float td[3] = {5,6,7};
+  ((float[3])buf[i].d) = td;
+
+  //CHECK:call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 0, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 undef, i8 7)
+  ((B3)buf[i].base) = b;
+
+  // buf[i].a[1].xyz = x;
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 20
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 36
+  // CHECK:call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %buf_UAV_structbuf, i32 %{{[0-9]+}}, i32 52
+
+  ((float3x3)buf[i].a)[1] = x;
+  // a[1].xyz + a._m21_m20_m02;
+  // CHECK:call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 0)  ; CBufferLoadLegacy(handle,regIndex)
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 1
+  // CHECK:call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 1)  ; CBufferLoadLegacy(handle,regIndex)
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 1
+  // CHECK:call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32(i32 59, %dx.types.Handle %"$Globals_cbuffer", i32 2)  ; CBufferLoadLegacy(handle,regIndex)
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 1
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 2
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 2
+  // CHECK:extractvalue %dx.types.CBufRet.f32 %{{[0-9]+}}, 0
+  return ((float3x3)a)[1] + ((float3x3)a)._m21_m20_m02;
+}