Browse Source

Handle matrix type when expanding store intrinsics (#3418)

Vishal Sharma 4 years ago
parent
commit
23a94db422

+ 37 - 7
lib/HLSL/HLExpandStoreIntrinsics.cpp

@@ -10,6 +10,8 @@
 #include "dxc/Support/Global.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/HLSL/HLModule.h"
+#include "dxc/DXIL/DxilTypeSystem.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instruction.h"
@@ -40,10 +42,11 @@ public:
   bool runOnFunction(Function& Func) override;
 
 private:
+  DxilTypeSystem *m_typeSys;
   bool expand(CallInst *StoreCall);
   void emitElementStores(CallInst &OriginalCall,
     SmallVectorImpl<Value*>& GEPIndicesStack, Type *StackTopTy,
-    unsigned OffsetFromBase);
+    unsigned OffsetFromBase, DxilFieldAnnotation *fieldAnnotation);
 };
 
 char HLExpandStoreIntrinsics::ID = 0;
@@ -51,6 +54,7 @@ char HLExpandStoreIntrinsics::ID = 0;
 bool HLExpandStoreIntrinsics::runOnFunction(Function& Func) {
   bool changed = false;
 
+  m_typeSys = &(Func.getParent()->GetHLModule().GetTypeSystem());
   for (auto InstIt = inst_begin(Func), InstEnd = inst_end(Func); InstIt != InstEnd;) {
     CallInst *Call = dyn_cast<CallInst>(&*(InstIt++));
     if (Call == nullptr
@@ -73,7 +77,7 @@ bool HLExpandStoreIntrinsics::expand(CallInst* StoreCall) {
   IRBuilder<> Builder(StoreCall);
   SmallVector<Value*, 4> GEPIndicesStack;
   GEPIndicesStack.emplace_back(Builder.getInt32(0));
-  emitElementStores(*StoreCall, GEPIndicesStack, OldStoreValueArgTy->getPointerElementType(), /* OffsetFromBase */ 0);
+  emitElementStores(*StoreCall, GEPIndicesStack, OldStoreValueArgTy->getPointerElementType(), /* OffsetFromBase */ 0, nullptr);
   DXASSERT(StoreCall->getType()->isVoidTy() && StoreCall->use_empty(),
     "Buffer store intrinsic is expected to return void and hence not have uses.");
   StoreCall->eraseFromParent();
@@ -82,18 +86,20 @@ bool HLExpandStoreIntrinsics::expand(CallInst* StoreCall) {
 
 void HLExpandStoreIntrinsics::emitElementStores(CallInst &OriginalCall,
     SmallVectorImpl<Value*>& GEPIndicesStack, Type *StackTopTy,
-    unsigned OffsetFromBase) {
+    unsigned OffsetFromBase, DxilFieldAnnotation* fieldAnnotation) {
   llvm::Module &Module = *OriginalCall.getModule();
   IRBuilder<> Builder(&OriginalCall);
 
   StructType* StructTy = dyn_cast<StructType>(StackTopTy);
   if (StructTy != nullptr && !HLMatrixType::isa(StructTy)) {
     const StructLayout* Layout = Module.getDataLayout().getStructLayout(StructTy);
+    DxilStructAnnotation *SA = m_typeSys->GetStructAnnotation(StructTy);
     for (unsigned i = 0; i < StructTy->getNumElements(); ++i) {
       Type *ElemTy = StructTy->getElementType(i);
       unsigned ElemOffsetFromBase = OffsetFromBase + Layout->getElementOffset(i);
       GEPIndicesStack.emplace_back(Builder.getInt32(i));
-      emitElementStores(OriginalCall, GEPIndicesStack, ElemTy, ElemOffsetFromBase);
+      DxilFieldAnnotation* FA = SA != nullptr ? &(SA->GetFieldAnnotation(i)) : nullptr;
+      emitElementStores(OriginalCall, GEPIndicesStack, ElemTy, ElemOffsetFromBase, FA);
       GEPIndicesStack.pop_back();
     }
   }
@@ -102,7 +108,7 @@ void HLExpandStoreIntrinsics::emitElementStores(CallInst &OriginalCall,
     for (int i = 0; i < (int)ArrayTy->getNumElements(); ++i) {
       unsigned ElemOffsetFromBase = OffsetFromBase + ElemSize * i;
       GEPIndicesStack.emplace_back(Builder.getInt32(i));
-      emitElementStores(OriginalCall, GEPIndicesStack, ArrayTy->getElementType(), ElemOffsetFromBase);
+      emitElementStores(OriginalCall, GEPIndicesStack, ArrayTy->getElementType(), ElemOffsetFromBase, fieldAnnotation);
       GEPIndicesStack.pop_back();
     }
   }
@@ -118,8 +124,32 @@ void HLExpandStoreIntrinsics::emitElementStores(CallInst &OriginalCall,
 
     Value* AggPtr = OriginalCall.getArgOperand(HLOperandIndex::kStoreValOpIdx);
     Value *ElemPtr = Builder.CreateGEP(AggPtr, GEPIndicesStack);
-    Value *ElemVal = Builder.CreateLoad(ElemPtr); // We go from memory to memory so no special bool handling needed
-    
+    Value* ElemVal = nullptr;
+
+    if (HLMatrixType::isa(StackTopTy) && fieldAnnotation &&
+        fieldAnnotation->HasMatrixAnnotation()) {
+
+      // For matrix load, we generate HL intrinsic matldst.colLoad/matldst.rowLoad
+      // instead of LLVM LoadInst to ensure that it gets lowered properly later
+      // in HLMatrixLowerPass
+      bool isRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
+                        hlsl::MatrixOrientation::RowMajor;
+      unsigned matLdOpcode =
+          isRowMajor ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad)
+                     : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad);
+      // Generate matrix load
+      FunctionType *MatLdFnType = FunctionType::get(
+          StackTopTy, {Builder.getInt32Ty(), ElemPtr->getType()},
+          /* isVarArg */ false);
+
+      Function *MatLdFn = GetOrCreateHLFunction(
+          Module, MatLdFnType, HLOpcodeGroup::HLMatLoadStore, matLdOpcode);
+      Value *MatLdOpCode = ConstantInt::get(Builder.getInt32Ty(), matLdOpcode);
+      ElemVal = Builder.CreateCall(MatLdFn, {MatLdOpCode, ElemPtr});
+    } else {
+      ElemVal = Builder.CreateLoad(ElemPtr); // We go from memory to memory so no special bool handling needed
+    }
+
     FunctionType *NewCalleeType = FunctionType::get(Builder.getVoidTy(),
       { OpcodeVal->getType(), BufHandle->getType(), OffsetVal->getType(), ElemVal->getType() },
       /* isVarArg */ false);

+ 46 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/ByteAddressBuffer/rwbab_store_struct_with_mat_const_init_zpc.hlsl

@@ -0,0 +1,46 @@
+// RUN: %dxc -E main -Zpc -T vs_6_5 /DTEST1=1 %s | FileCheck %s -check-prefix=CHK_TEST1
+// RUN: %dxc -E main -Zpc -T vs_6_5 %s | FileCheck %s -check-prefix=CHK_TEST2
+
+// Regression test for GitHub issue #3226
+
+
+#ifdef TEST1
+struct S
+{
+  float1x1 mat;    
+};
+#else
+struct S
+{
+  float2x2 mat;
+  float f1; 
+  int i1; 
+  struct S1
+  {
+      float1x1 mat1; // nested struct
+  } s1;
+};
+#endif
+
+RWByteAddressBuffer buf;
+
+void main()
+{
+
+#ifdef TEST1
+    S t = {{1}};
+#else
+    S t = {{1, 2, 3, 4}, 5.f, 6, {7}};
+#endif
+    // CHK_TEST1: dx.op.rawBufferStore.f32
+    // CHK_TEST1: float 1.000000e+00
+    
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 1.000000e+00, float 3.000000e+00, float 2.000000e+00, float 4.000000e+00
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 5.000000e+00
+    // CHK_TEST2: dx.op.rawBufferStore.i32
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 7.000000e+00
+    buf.Store(0, t);
+}

+ 46 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/ByteAddressBuffer/rwbab_store_struct_with_mat_const_init_zpr.hlsl

@@ -0,0 +1,46 @@
+// RUN: %dxc -E main -Zpr -T vs_6_5 /DTEST1=1 %s | FileCheck %s -check-prefix=CHK_TEST1
+// RUN: %dxc -E main -Zpr -T vs_6_5 %s | FileCheck %s -check-prefix=CHK_TEST2
+
+// Regression test for GitHub issue #3226
+
+
+#ifdef TEST1
+struct S
+{
+  float1x1 mat;    
+};
+#else
+struct S
+{
+  float2x2 mat;
+  float f1; 
+  int i1; 
+  struct S1
+  {
+      float1x1 mat1; // nested struct
+  } s1;
+};
+#endif
+
+RWByteAddressBuffer buf;
+
+void main()
+{
+
+#ifdef TEST1
+    S t = {{1}};
+#else
+    S t = {{1, 2, 3, 4}, 5.f, 6, {7}};
+#endif
+    // CHK_TEST1: dx.op.rawBufferStore.f32
+    // CHK_TEST1: float 1.000000e+00
+    
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 5.000000e+00
+    // CHK_TEST2: dx.op.rawBufferStore.i32
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: float 7.000000e+00
+    buf.Store(0, t);
+}

+ 37 - 0
tools/clang/test/HLSLFileCheck/hlsl/objects/ByteAddressBuffer/rwbab_store_struct_with_mat_zpr.hlsl

@@ -0,0 +1,37 @@
+// RUN: %dxc -E main -Zpr -T vs_6_5 /DTEST1=1 %s | FileCheck %s -check-prefix=CHK_TEST1
+// RUN: %dxc -E main -Zpr -T vs_6_5 %s | FileCheck %s -check-prefix=CHK_TEST2
+
+// Regression test for GitHub issue #3226
+
+
+#ifdef TEST1
+struct S
+{
+  float1x1 mat;    
+};
+#else
+struct S
+{
+  float2x2 mat;
+  float f1; 
+  int i1; 
+  struct S1 {
+      float1x1 mat1; // nested struct
+  } s1;
+};
+#endif
+
+RWByteAddressBuffer buf;
+
+void main(S a : IN0)
+{
+    S t = a;
+    
+    // CHK_TEST1: dx.op.rawBufferStore.f32
+    
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    // CHK_TEST2: dx.op.rawBufferStore.i32
+    // CHK_TEST2: dx.op.rawBufferStore.f32
+    buf.Store(0, t);
+}