2
0
Эх сурвалжийг харах

Matrix lowering for functions with UDT params preserved.

- Keep track of patch constant functions for later identification
- functions that require input/output signature processing identified
  with IsEntryThatUsesSignatures
- update lib_rt.hlsl intrinsics and naming
Tex Riddell 7 жил өмнө
parent
commit
15cd5f16e6

+ 12 - 0
include/dxc/HLSL/DxilModule.h

@@ -26,6 +26,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include <unordered_set>
 
 namespace llvm {
 class LLVMContext;
@@ -132,6 +133,14 @@ public:
   DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F);
   // Move DxilFunctionProps of F to NewF.
   void ReplaceDxilFunctionProps(llvm::Function *F, llvm::Function *NewF);
+  void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc);
+  bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps
+  bool IsPatchConstantShader(llvm::Function *F);
+  bool IsComputeShader(llvm::Function *F);
+
+  // Is an entry function that uses input/output signature conventions?
+  // Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function.
+  bool IsEntryThatUsesSignatures(llvm::Function *F);
 
   // Remove Root Signature from module metadata
   void StripRootSignatureFromMetadata();
@@ -436,6 +445,9 @@ private:
   std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>>
       m_DxilEntrySignatureMap;
 
+  // Keeps track of patch constant functions used by hull shaders
+  std::unordered_set<llvm::Function *>  m_PatchConstantFunctions;
+
   // ViewId state.
   std::unique_ptr<DxilViewIdState> m_pViewIdState;
 

+ 10 - 0
include/dxc/HLSL/HLModule.h

@@ -24,6 +24,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include <unordered_set>
 
 namespace llvm {
 class LLVMContext;
@@ -127,6 +128,14 @@ public:
   bool HasDxilFunctionProps(llvm::Function *F);
   DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F);
   void AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info);
+  void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc);
+  bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps
+  bool IsPatchConstantShader(llvm::Function *F);
+  bool IsComputeShader(llvm::Function *F);
+
+  // Is an entry function that uses input/output signature conventions?
+  // Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function.
+  bool IsEntryThatUsesSignatures(llvm::Function *F);
 
   DxilFunctionAnnotation *GetFunctionAnnotation(llvm::Function *F);
   DxilFunctionAnnotation *AddFunctionAnnotation(llvm::Function *F);
@@ -238,6 +247,7 @@ private:
 
   // High level function info.
   std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>>  m_DxilFunctionPropsMap;
+  std::unordered_set<llvm::Function *>  m_PatchConstantFunctions;
 
   // Resource type annotation.
   std::unordered_map<llvm::Type *, std::pair<DXIL::ResourceClass, DXIL::ResourceKind>> m_ResTypeAnnotation;

+ 1 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -238,7 +238,7 @@ public:
       for (auto It = M.begin(); It != M.end();) {
         Function &F = *(It++);
         // Lower signature for each entry function.
-        if (m_pHLModule->HasDxilFunctionProps(&F)) {
+        if (m_pHLModule->IsEntryThatUsesSignatures(&F)) {
           DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&F);
           std::unique_ptr<DxilEntrySignature> pSig =
               llvm::make_unique<DxilEntrySignature>(props.shaderKind, m_pHLModule->GetHLOptions().bUseMinPrecision);

+ 2 - 1
lib/HLSL/DxilLinker.cpp

@@ -607,7 +607,8 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
     Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc;
     Function *newPatchConstantFunc =
         m_newFunctions[patchConstantFunc->getName()];
-    props.ShaderProps.HS.patchConstantFunc = newPatchConstantFunc;
+    DM.SetPatchConstantFunctionForHS(entryFunc, nullptr);
+    DM.SetPatchConstantFunctionForHS(NewEntryFunc, newPatchConstantFunc);
 
     if (newPatchConstantFunc->hasFnAttribute(llvm::Attribute::AlwaysInline))
       newPatchConstantFunc->removeFnAttr(llvm::Attribute::AlwaysInline);

+ 34 - 0
lib/HLSL/DxilModule.cpp

@@ -1102,6 +1102,35 @@ void DxilModule::ReplaceDxilFunctionProps(llvm::Function *F,
   m_DxilFunctionPropsMap.erase(F);
   m_DxilFunctionPropsMap[NewF] = std::move(props);
 }
+void DxilModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) {
+  auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc);
+  DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "Hull shader must already have function props!");
+  DxilFunctionProps &props = *(propIter->second);
+  DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader");
+  if (props.ShaderProps.HS.patchConstantFunc)
+    m_PatchConstantFunctions.erase(props.ShaderProps.HS.patchConstantFunc);
+  props.ShaderProps.HS.patchConstantFunc = patchConstantFunc;
+  if (patchConstantFunc)
+    m_PatchConstantFunctions.insert(patchConstantFunc);
+}
+bool DxilModule::IsGraphicsShader(llvm::Function *F) {
+  return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics();
+}
+bool DxilModule::IsPatchConstantShader(llvm::Function *F) {
+  return m_PatchConstantFunctions.count(F) != 0;
+}
+bool DxilModule::IsComputeShader(llvm::Function *F) {
+  return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsCS();
+}
+bool DxilModule::IsEntryThatUsesSignatures(llvm::Function *F) {
+  auto propIter = m_DxilFunctionPropsMap.find(F);
+  if (propIter != m_DxilFunctionPropsMap.end()) {
+    DxilFunctionProps &props = *(propIter->second);
+    return props.IsGraphics() || props.IsCS();
+  }
+  // Otherwise, return true if patch constant function
+  return IsPatchConstantShader(F);
+}
 
 void DxilModule::StripRootSignatureFromMetadata() {
   NamedMDNode *pRootSignatureNamedMD = GetModule()->getNamedMetadata(DxilMDHelper::kDxilRootSignatureMDName);
@@ -1319,6 +1348,11 @@ void DxilModule::LoadDxilMetadata() {
 
       Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get());
 
+      if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) {
+        // Add patch constant function to m_PatchConstantFunctions
+        m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc);
+      }
+
       m_DxilFunctionPropsMap[F] = std::move(props);
     }
 

+ 2 - 2
lib/HLSL/DxilPreparePasses.cpp

@@ -374,7 +374,7 @@ private:
     } else {
       std::vector<Function *> entries;
       for (iplist<Function>::iterator F : M.getFunctionList()) {
-        if (DM.HasDxilFunctionProps(F)) {
+        if (DM.IsEntryThatUsesSignatures(F)) {
           entries.emplace_back(F);
         }
       }
@@ -384,7 +384,7 @@ private:
           // Strip patch constant function first.
           Function *patchConstFunc = StripFunctionParameter(
               props.ShaderProps.HS.patchConstantFunc, DM, FunctionDIs);
-          props.ShaderProps.HS.patchConstantFunc = patchConstFunc;
+          DM.SetPatchConstantFunctionForHS(entry, patchConstFunc);
         }
         StripFunctionParameter(entry, DM, FunctionDIs);
       }

+ 108 - 24
lib/HLSL/HLMatrixLowerPass.cpp

@@ -272,6 +272,9 @@ private:
   // Get new matrix value corresponding to vecVal
   Value *GetMatrixForVec(Value *vecVal, Type *matTy);
 
+  // Translate library function input/output to preserve function signatures
+  void TranslateLibraryArgs(Function &F);
+
   // Replace matVal with vecVal on matUseInst.
   void TrivialMatReplace(Value *matVal, Value *vecVal,
                         CallInst *matUseInst);
@@ -1269,6 +1272,16 @@ void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
     }
 }
 
+static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned row, unsigned col) {
+  SmallVector<int, 16> castMask(col * row);
+  unsigned idx = 0;
+  for (unsigned c = 0; c < col; c++)
+    for (unsigned r = 0; r < row; r++)
+      castMask[idx++] = r * col + c;
+  return cast<Instruction>(
+    Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+}
+
 void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
                                               Value *vecVal,
                                               CallInst *castInst,
@@ -1291,25 +1304,9 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
 
   IRBuilder<> Builder(castInst);
 
-  // shuf to change major.
-  SmallVector<int, 16> castMask(col * row);
-  unsigned idx = 0;
-  if (bRowToCol) {
-    for (unsigned c = 0; c < col; c++)
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
-        castMask[idx++] = matIdx;
-      }
-  } else {
-    for (unsigned r = 0; r < row; r++)
-      for (unsigned c = 0; c < col; c++) {
-        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
-        castMask[idx++] = matIdx;
-      }
-  }
-
-  Instruction *vecCast = cast<Instruction>(
-      Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+  if (bRowToCol)
+    std::swap(row, col);
+  Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
 
   // Replace vec cast function call with vecCast.
   DXASSERT(matToVecMap.count(castInst), "must has vec version");
@@ -2109,12 +2106,10 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
 
 void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
                                           Value *vecVal) {
+  Type *matTy = matVal->getType();
   for (Value::user_iterator user = matVal->user_begin();
        user != matVal->user_end();) {
     Instruction *useInst = cast<Instruction>(*(user++));
-    // Skip return here.
-    if (isa<ReturnInst>(useInst))
-      continue;
     // User must be function call.
     if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
       hlsl::HLOpcodeGroup group =
@@ -2183,7 +2178,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
         for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
           if (useCall->getArgOperand(i) == matVal) {
             // update the user call with the correct matrix value in new code sequence
-            Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
+            Value *newMatVal = GetMatrixForVec(vecVal, matTy);
             if (matVal != newMatVal)
               useCall->setArgOperand(i, newMatVal);
           }
@@ -2194,8 +2189,10 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       // Just replace the src with vec version.
       useInst->setOperand(0, vecVal);
     } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
-      Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
+      Value *newMatVal = GetMatrixForVec(vecVal, matTy);
       RI->setOperand(0, newMatVal);
+    } else if (isa<StoreInst>(useInst)) {
+      DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
     } else {
       // Must be GEP on mat array alloca.
       GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
@@ -2467,6 +2464,85 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   }
 }
 
+void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
+  // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
+  // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
+  for (auto &arg : F.args()) {
+    SmallVector<CallInst *, 4> Candidates;
+    for (User *U : arg.users()) {
+      if (CallInst *CI = dyn_cast<CallInst>(U)) {
+        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+        switch (group) {
+        case HLOpcodeGroup::HLCast:
+        case HLOpcodeGroup::HLMatLoadStore:
+          Candidates.push_back(CI);
+          break;
+        }
+      }
+    }
+    for (CallInst *CI : Candidates) {
+      IRBuilder<> Builder(CI);
+      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+      switch (group) {
+      case HLOpcodeGroup::HLCast: {
+        HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
+        if (opcode == HLCastOpcode::RowMatrixToVecCast ||
+            opcode == HLCastOpcode::ColMatrixToVecCast) {
+          Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
+          Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
+                                            /*bOrigAllocaTy*/false,
+                                            matVal->getName());
+          if (opcode == HLCastOpcode::ColMatrixToVecCast) {
+            unsigned row, col;
+            HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+            vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+          }
+          CI->replaceAllUsesWith(vecVal);
+          CI->eraseFromParent();
+        }
+      } break;
+      case HLOpcodeGroup::HLMatLoadStore: {
+        HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
+        bool bTranspose = false;
+        switch (opcode) {
+        case HLMatLoadStoreOpcode::ColMatStore:
+          bTranspose = true;
+        case HLMatLoadStoreOpcode::RowMatStore: {
+          // shuffle if transposed, bitcast, and store
+          Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+          Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
+          if (bTranspose) {
+            unsigned row, col;
+            HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+            vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+          }
+          Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
+          Builder.CreateStore(vecVal, castPtr);
+          CI->eraseFromParent();
+        } break;
+        case HLMatLoadStoreOpcode::ColMatLoad:
+          bTranspose = true;
+        case HLMatLoadStoreOpcode::RowMatLoad: {
+          // bitcast, load, and shuffle if transposed
+          Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
+          Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
+          Value *vecVal = Builder.CreateLoad(castPtr);
+          if (bTranspose) {
+            unsigned row, col;
+            HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+            // row/col swapped for col major source
+            vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
+          }
+          CI->replaceAllUsesWith(vecVal);
+          CI->eraseFromParent();
+        } break;
+        }
+      } break;
+      }
+    }
+  }
+}
+
 void HLMatrixLowerPass::runOnFunction(Function &F) {
   // Create vector version of matrix instructions first.
   // The matrix operands will be undefval for these instructions.
@@ -2531,4 +2607,12 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
   DeleteDeadInsts();
   
   matToVecMap.clear();
+  vecToMatMap.clear();
+
+  // If this is a library function, now fix input/output matrix params
+  // TODO: What about Patch Constant Shaders?
+  if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
+    TranslateLibraryArgs(F);
+  }
+  return;
 }

+ 34 - 0
lib/HLSL/HLModule.cpp

@@ -350,6 +350,35 @@ void HLModule::AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunct
   DXASSERT_NOMSG(info->shaderKind != DXIL::ShaderKind::Invalid);
   m_DxilFunctionPropsMap[F] = std::move(info);
 }
+void HLModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc) {
+  auto propIter = m_DxilFunctionPropsMap.find(hullShaderFunc);
+  DXASSERT(propIter != m_DxilFunctionPropsMap.end(), "else Hull Shader missing function props");
+  DxilFunctionProps &props = *(propIter->second);
+  DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader");
+  if (props.ShaderProps.HS.patchConstantFunc)
+    m_PatchConstantFunctions.erase(props.ShaderProps.HS.patchConstantFunc);
+  props.ShaderProps.HS.patchConstantFunc = patchConstantFunc;
+  if (patchConstantFunc)
+    m_PatchConstantFunctions.insert(patchConstantFunc);
+}
+bool HLModule::IsGraphicsShader(llvm::Function *F) {
+  return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics();
+}
+bool HLModule::IsPatchConstantShader(llvm::Function *F) {
+  return m_PatchConstantFunctions.count(F) != 0;
+}
+bool HLModule::IsComputeShader(llvm::Function *F) {
+  return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsCS();
+}
+bool HLModule::IsEntryThatUsesSignatures(llvm::Function *F) {
+  auto propIter = m_DxilFunctionPropsMap.find(F);
+  if (propIter != m_DxilFunctionPropsMap.end()) {
+    DxilFunctionProps &props = *(propIter->second);
+    return props.IsGraphics() || props.IsCS();
+  }
+  // Otherwise, return true if patch constant function
+  return IsPatchConstantShader(F);
+}
 
 DxilFunctionAnnotation *HLModule::GetFunctionAnnotation(llvm::Function *F) {
   return m_pTypeSystem->GetFunctionAnnotation(F);
@@ -475,6 +504,11 @@ void HLModule::LoadHLMetadata() {
 
       Function *F = m_pMDHelper->LoadDxilFunctionProps(pProps, props.get());
 
+      if (props->IsHS() && props->ShaderProps.HS.patchConstantFunc) {
+        // Add patch constant function to m_PatchConstantFunctions
+        m_PatchConstantFunctions.insert(props->ShaderProps.HS.patchConstantFunc);
+      }
+
       m_DxilFunctionPropsMap[F] = std::move(props);
     }
 

+ 8 - 17
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -5173,6 +5173,9 @@ void SROA_Parameter_HLSL::flattenArgument(
     Type *Ty = V->getType();
     if (Ty->isPointerTy())
       Ty = Ty->getPointerElementType();
+
+    // Stop doing this when preserving resource types and using new
+    // createHandleFrom??? whatever it's going to be called...
     V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder);
 
     // Cannot SROA, save it to final parameter list.
@@ -5829,20 +5832,8 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
     IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get());
     RetBuilder.CreateRetVoid();
   } else {
-    Function *Entry = m_pHLModule->GetEntryFunction();
-    hasShaderInputOutput = F == Entry;
-    if (m_pHLModule->HasDxilFunctionProps(F)) {
-      DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
-      if (!funcProps.IsRay())
-        hasShaderInputOutput = true;
-    }
-    if (m_pHLModule->HasDxilFunctionProps(Entry)) {
-      DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(Entry);
-      if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
-        Function *patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc;
-        hasShaderInputOutput |= F == patchConstantFunc;
-      }
-    }
+    hasShaderInputOutput = F == m_pHLModule->GetEntryFunction() ||
+                           m_pHLModule->IsEntryThatUsesSignatures(F);
   }
 
   std::vector<Value *> FlatParamList;
@@ -6361,9 +6352,9 @@ void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
     if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
       Function *oldPatchConstantFunc =
           funcProps.ShaderProps.HS.patchConstantFunc;
-      if (funcMap.count(oldPatchConstantFunc))
-        funcProps.ShaderProps.HS.patchConstantFunc =
-            funcMap[oldPatchConstantFunc];
+      if (funcMap.count(oldPatchConstantFunc)) {
+        m_pHLModule->SetPatchConstantFunctionForHS(flatF, funcMap[oldPatchConstantFunc]);
+      }
     }
   }
   // TODO: flatten vector argument and lower resource argument when flatten

+ 3 - 3
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4317,11 +4317,11 @@ void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr(
   }
 
   Function *patchConstFunc = Entry->second.Func;
-  DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func);
-  DXASSERT(HSProps != nullptr,
+  DXASSERT(m_pHLModule->HasDxilFunctionProps(EntryFunc.Func),
     " else AddHLSLFunctionInfo did not save the dxil function props for the "
     "HS entry.");
-  HSProps->ShaderProps.HS.patchConstantFunc = patchConstFunc;
+  DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func);
+  m_pHLModule->SetPatchConstantFunctionForHS(EntryFunc.Func, patchConstFunc);
   DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc));
   // Check no inout parameter for patch constant function.
   DxilFunctionAnnotation *patchConstFuncAnnotation =

+ 38 - 30
tools/clang/test/CodeGenHLSL/quick-test/lib_rt.hlsl

@@ -2,19 +2,19 @@
 
 ////////////////////////////////////////////////////////////////////////////
 // Prototype header contents to be removed on implementation of features:
-#define HIT_KIND_TRIANGLE_FRONT_FACE    0xFE
-#define HIT_KIND_TRIANGLE_BACK_FACE     0xFF
+#define HIT_KIND_TRIANGLE_FRONT_FACE              0xFE
+#define HIT_KIND_TRIANGLE_BACK_FACE               0xFF
 
 typedef uint RAY_FLAG;
-#define RAY_FLAG_NONE                         0x00
-#define RAY_FLAG_FORCE_OPAQUE                 0x01
-#define RAY_FLAG_FORCE_NON_OPAQUE             0x02
-#define RAY_FLAG_TERMINATE_ON_FIRST_HIT       0x04
-#define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER      0x08
-#define RAY_FLAG_CULL_BACK_FACING_TRIANGLES   0x10
-#define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES  0x20
-#define RAY_FLAG_CULL_OPAQUE                  0x40
-#define RAY_FLAG_CULL_NON_OPAQUE              0x80
+#define RAY_FLAG_NONE                             0x00
+#define RAY_FLAG_FORCE_OPAQUE                     0x01
+#define RAY_FLAG_FORCE_NON_OPAQUE                 0x02
+#define RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH  0x04
+#define RAY_FLAG_SKIP_CLOSEST_HIT_SHADER          0x08
+#define RAY_FLAG_CULL_BACK_FACING_TRIANGLES       0x10
+#define RAY_FLAG_CULL_FRONT_FACING_TRIANGLES      0x20
+#define RAY_FLAG_CULL_OPAQUE                      0x40
+#define RAY_FLAG_CULL_NON_OPAQUE                  0x80
 
 struct RayDesc
 {
@@ -29,38 +29,46 @@ struct BuiltInTriangleIntersectionAttributes
     float2 barycentrics;
 };
 
-typedef ByteAddressBuffer RayTracingAccelerationStructure;
+typedef ByteAddressBuffer RaytracingAccelerationStructure;
 
+// group: Indirect Shader Invocation
 // Declare TraceRay overload for given payload structure
 #define Declare_TraceRay(payload_t) \
-    void TraceRay(RayTracingAccelerationStructure, uint RayFlags, uint InstanceCullMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t);
+    void TraceRay(RaytracingAccelerationStructure, uint RayFlags, uint InstanceInclusionMask, uint RayContributionToHitGroupIndex, uint MultiplierForGeometryContributionToHitGroupIndex, uint MissShaderIndex, RayDesc, inout payload_t);
 
-// Declare ReportIntersection overload for given attribute structure
-#define Declare_ReportIntersection(attr_t) \
-    bool ReportIntersection(float HitT, uint HitKind, attr_t);
+// Declare ReportHit overload for given attribute structure
+#define Declare_ReportHit(attr_t) \
+    bool ReportHit(float HitT, uint HitKind, attr_t);
 
 // Declare CallShader overload for given param structure
 #define Declare_CallShader(param_t) \
     void CallShader(uint ShaderIndex, inout param_t);
 
-void IgnoreIntersection();
-void TerminateRay();
+// group: AnyHit Terminals
+void IgnoreHit();
+void AcceptHitAndEndSearch();
 
 // System Value retrieval functions
+// group: Ray Dispatch Arguments
 uint2 RayDispatchIndex();
 uint2 RayDispatchDimension();
+// group: Ray Vectors
 float3 WorldRayOrigin();
 float3 WorldRayDirection();
+float3 ObjectRayOrigin();
+float3 ObjectRayDirection();
+// group: RayT
 float RayTMin();
 float CurrentRayT();
-uint PrimitiveID();
+// group: Raytracing uint System Values
+uint PrimitiveID(); // watch for existing
 uint InstanceID();
 uint InstanceIndex();
-float3 ObjectRayOrigin();
-float3 ObjectRayDirection();
+uint HitKind();
+uint RayFlag();
+// group: Ray Transforms
 row_major float3x4 ObjectToWorld();
 row_major float3x4 WorldToObject();
-uint HitKind();
 ////////////////////////////////////////////////////////////////////////////
 
 struct MyPayload {
@@ -79,7 +87,7 @@ struct MyParam {
 };
 
 Declare_TraceRay(MyPayload);
-Declare_ReportIntersection(MyAttributes);
+Declare_ReportHit(MyAttributes);
 Declare_CallShader(MyParam);
 
 // CHECK: ; S                                 sampler      NA          NA      S0             s1     1
@@ -90,7 +98,7 @@ Declare_CallShader(MyParam);
 // CHECK: @T_rangeID = external constant i32
 // CHECK: @S_rangeID = external constant i32
 
-RayTracingAccelerationStructure RTAS : register(t5);
+RaytracingAccelerationStructure RTAS : register(t5);
 
 // CHECK: define void [[raygen1:@"\\01\?raygen1@[^\"]+"]]() {
 // CHECK:   [[RAWBUF_ID:[^ ]+]] = load i32, i32* @RTAS_rangeID
@@ -114,7 +122,7 @@ void raygen1()
 // CHECK: define void [[intersection1:@"\\01\?intersection1@[^\"]+"]]() {
 // CHECK:   call void {{.*}}CurrentRayT{{.*}}(float* nonnull [[pCurrentRayT:%[^)]+]])
 // CHECK:   [[CurrentRayT:%[^ ]+]] = load float, float* [[pCurrentRayT]], align 4
-// CHECK:   call void {{.*}}ReportIntersection{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}})
+// CHECK:   call void {{.*}}ReportHit{{.*}}(float [[CurrentRayT]], i32 0, float 0.000000e+00, float 0.000000e+00, i32 0, i1* nonnull {{.*}})
 // CHECK:   ret void
 
 [shader("intersection")]
@@ -122,15 +130,15 @@ void intersection1()
 {
   float hitT = CurrentRayT();
   MyAttributes attr = (MyAttributes)0;
-  bool bReported = ReportIntersection(hitT, 0, attr);
+  bool bReported = ReportHit(hitT, 0, attr);
 }
 
 // CHECK: define void [[anyhit1:@"\\01\?anyhit1@[^\"]+"]](float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, float* noalias nocapture, i32* noalias nocapture, i32* noalias nocapture, float, float, i32)
 // CHECK:   call void {{.*}}ObjectRayOrigin{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}})
 // CHECK:   call void {{.*}}ObjectRayDirection{{.*}}(float* nonnull {{.*}}, float* nonnull {{.*}}, float* nonnull {{.*}})
 // CHECK:   call void {{.*}}CurrentRayT{{.*}}(float* nonnull {{.*}})
-// CHECK:   call void {{.*}}TerminateRay{{.*}}()
-// CHECK:   call void {{.*}}IgnoreIntersection{{.*}}()
+// CHECK:   call void {{.*}}AcceptHitAndEndSearch{{.*}}()
+// CHECK:   call void {{.*}}IgnoreHit{{.*}}()
 // CHECK:   store float {{.*}}, float* %0, align 4
 // CHECK:   store float {{.*}}, float* %1, align 4
 // CHECK:   store float {{.*}}, float* %2, align 4
@@ -145,9 +153,9 @@ void anyhit1( inout MyPayload payload : SV_RayPayload,
 {
   float3 hitLocation = ObjectRayOrigin() + ObjectRayDirection() * CurrentRayT();
   if (hitLocation.z < attr.bary.x)
-    TerminateRay();         // aborts function
+    AcceptHitAndEndSearch();         // aborts function
   if (hitLocation.z < attr.bary.y)
-    IgnoreIntersection();   // aborts function
+    IgnoreHit();   // aborts function
   payload.color += float4(0.125, 0.25, 0.5, 1.0);
 }