Browse Source

Add RayQuery object, TraceRayInline method + template arg annotations

Tex Riddell 6 years ago
parent
commit
2209844cda

+ 3 - 1
docs/DXIL.rst

@@ -2245,7 +2245,7 @@ ID  Name                          Description
 154 RayTCurrent                   float representing the current parametric ending point for the ray
 155 IgnoreHit                     Used in an any hit shader to reject an intersection and terminate the shader
 156 AcceptHitAndEndSearch         Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
-157 TraceRay                      returns the view index
+157 TraceRay                      initiates raytrace
 158 ReportHit                     returns true if hit was accepted
 159 CallShader                    Call a shader in the callable shader table supplied through the DispatchRays() API
 160 CreateHandleForLib            create resource handle from resource struct for library
@@ -2256,6 +2256,8 @@ ID  Name                          Description
 165 WaveMatch                     returns the bitmask of active lanes that have the same value
 166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
 167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
+168 AllocateRayQuery              allocate space for RayQuery and return handle
+169 TraceRayInline                initialize RayQuery for raytrace
 === ============================= =======================================================================================================================================================================================================================
 
 

+ 20 - 5
include/dxc/DXIL/DxilConstants.h

@@ -299,6 +299,9 @@ namespace DXIL {
   // OPCODE-ENUM:BEGIN
   // Enumeration for operations specified by DXIL
   enum class OpCode : unsigned {
+    // 
+    AllocateRayQuery = 168, // allocate space for RayQuery and return handle
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch = 156, // Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
     IgnoreHit = 155, // Used in an any hit shader to reject an intersection and terminate the shader
@@ -383,7 +386,10 @@ namespace DXIL {
     // Indirect Shader Invocation
     CallShader = 159, // Call a shader in the callable shader table supplied through the DispatchRays() API
     ReportHit = 158, // returns true if hit was accepted
-    TraceRay = 157, // returns the view index
+    TraceRay = 157, // initiates raytrace
+  
+    // Inline Ray Query
+    TraceRayInline = 169, // initialize RayQuery for raytrace
   
     // Legacy floating-point
     LegacyF16ToF32 = 131, // legacy fuction to convert half (f16) to float (f32) (this is not related to min-precision)
@@ -562,9 +568,9 @@ namespace DXIL {
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_4 = 165,
-    NumOpCodes_Dxil_1_5 = 168,
+    NumOpCodes_Dxil_1_5 = 170,
   
-    NumOpCodes = 168 // exclusive last value of enumeration
+    NumOpCodes = 170 // exclusive last value of enumeration
   };
   // OPCODE-ENUM:END
 
@@ -572,6 +578,9 @@ namespace DXIL {
   // OPCODECLASS-ENUM:BEGIN
   // Groups for DXIL operations with equivalent function templates
   enum class OpCodeClass : unsigned {
+    // 
+    AllocateRayQuery,
+  
     // AnyHit Terminals
     AcceptHitAndEndSearch,
     IgnoreHit,
@@ -643,6 +652,9 @@ namespace DXIL {
     ReportHit,
     TraceRay,
   
+    // Inline Ray Query
+    TraceRayInline,
+  
     // LLVM Instructions
     LlvmInst,
   
@@ -778,9 +790,9 @@ namespace DXIL {
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_4 = 120,
-    NumOpClasses_Dxil_1_5 = 123,
+    NumOpClasses_Dxil_1_5 = 125,
   
-    NumOpClasses = 123 // exclusive last value of enumeration
+    NumOpClasses = 125 // exclusive last value of enumeration
   };
   // OPCODECLASS-ENUM:END
 
@@ -910,6 +922,9 @@ namespace DXIL {
     const unsigned kTraceRayPayloadOpIdx = 15;
     const unsigned kTraceRayNumOp = 16;
 
+    // TraceRayInline
+    const unsigned kTraceRayInlineRayDescOpIdx = 5;
+    const unsigned kTraceRayInlineNumOp = 13;
 
     // Emit/Cut
     const unsigned kStreamEmitCutIDOpIdx = 1;

+ 86 - 1
include/dxc/DXIL/DxilInstructions.h

@@ -5173,7 +5173,7 @@ struct DxilInst_AcceptHitAndEndSearch {
   bool requiresUniformInputs() const { return false; }
 };
 
-/// This instruction returns the view index
+/// This instruction initiates raytrace
 struct DxilInst_TraceRay {
   llvm::Instruction *Instr;
   // Construction and identification
@@ -5549,5 +5549,90 @@ struct DxilInst_WaveMultiPrefixBitCount {
   llvm::Value *get_mask3() const { return Instr->getOperand(5); }
   void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
 };
+
+/// This instruction allocate space for RayQuery and return handle
+struct DxilInst_AllocateRayQuery {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_AllocateRayQuery(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::AllocateRayQuery);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_constRayFlags = 1,
+  };
+  // Accessors
+  llvm::Value *get_constRayFlags() const { return Instr->getOperand(1); }
+  void set_constRayFlags(llvm::Value *val) { Instr->setOperand(1, val); }
+  uint32_t get_constRayFlags_val() const { return (uint32_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(1))->getZExtValue()); }
+  void set_constRayFlags_val(uint32_t val) { Instr->setOperand(1, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 32), llvm::APInt(32, (uint64_t)val))); }
+};
+
+/// This instruction initialize RayQuery for raytrace
+struct DxilInst_TraceRayInline {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_TraceRayInline(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::TraceRayInline);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (13 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_accelerationStructure = 2,
+    arg_rayFlags = 3,
+    arg_instanceInclusionMask = 4,
+    arg_origin_X = 5,
+    arg_origin_Y = 6,
+    arg_origin_Z = 7,
+    arg_tMin = 8,
+    arg_direction_X = 9,
+    arg_direction_Y = 10,
+    arg_direction_Z = 11,
+    arg_tMax = 12,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_accelerationStructure() const { return Instr->getOperand(2); }
+  void set_accelerationStructure(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_rayFlags() const { return Instr->getOperand(3); }
+  void set_rayFlags(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_instanceInclusionMask() const { return Instr->getOperand(4); }
+  void set_instanceInclusionMask(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_origin_X() const { return Instr->getOperand(5); }
+  void set_origin_X(llvm::Value *val) { Instr->setOperand(5, val); }
+  llvm::Value *get_origin_Y() const { return Instr->getOperand(6); }
+  void set_origin_Y(llvm::Value *val) { Instr->setOperand(6, val); }
+  llvm::Value *get_origin_Z() const { return Instr->getOperand(7); }
+  void set_origin_Z(llvm::Value *val) { Instr->setOperand(7, val); }
+  llvm::Value *get_tMin() const { return Instr->getOperand(8); }
+  void set_tMin(llvm::Value *val) { Instr->setOperand(8, val); }
+  llvm::Value *get_direction_X() const { return Instr->getOperand(9); }
+  void set_direction_X(llvm::Value *val) { Instr->setOperand(9, val); }
+  llvm::Value *get_direction_Y() const { return Instr->getOperand(10); }
+  void set_direction_Y(llvm::Value *val) { Instr->setOperand(10, val); }
+  llvm::Value *get_direction_Z() const { return Instr->getOperand(11); }
+  void set_direction_Z(llvm::Value *val) { Instr->setOperand(11, val); }
+  llvm::Value *get_tMax() const { return Instr->getOperand(12); }
+  void set_tMax(llvm::Value *val) { Instr->setOperand(12, val); }
+};
 // INSTR-HELPER:END
 } // namespace hlsl

+ 13 - 0
include/dxc/DXIL/DxilMetadataHelper.h

@@ -46,6 +46,7 @@ class DxilSampler;
 class DxilTypeSystem;
 class DxilStructAnnotation;
 class DxilFieldAnnotation;
+class DxilTemplateArgAnnotation;
 class DxilFunctionAnnotation;
 class DxilParameterAnnotation;
 class RootSignatureHandle;
@@ -191,6 +192,16 @@ public:
   static const unsigned kDxilFieldAnnotationCompTypeTag           = 7;
   static const unsigned kDxilFieldAnnotationPreciseTag            = 8;
 
+  // StructAnnotation extended property tags (DXIL 1.5+ only, appended)
+  static const unsigned kDxilTemplateArgumentsTag                 = 0;  // Name for name-value list of extended struct properties
+  // TemplateArgument tags
+  static const unsigned kDxilTemplateArgTypeTag                   = 0;  // Type template argument, followed by undef of type
+  static const unsigned kDxilTemplateArgIntegralTag               = 1;  // Integral template argument, followed by i64 value
+  // TemplateArgType
+  static const unsigned kDxilTemplateArgType                      = 1;  // Position of type for template arg that is type
+  static const unsigned kDxilTemplateArgIntegral                  = 1;  // Position of i64 for template arg that is integral
+
+
   // Control flow hint.
   static const char kDxilControlFlowHintMDName[];
 
@@ -351,6 +362,8 @@ public:
   void LoadDxilParamAnnotation(const llvm::MDOperand &MDO, DxilParameterAnnotation &PA);
   llvm::Metadata *EmitDxilParamAnnotations(const DxilFunctionAnnotation &FA);
   void LoadDxilParamAnnotations(const llvm::MDOperand &MDO, DxilFunctionAnnotation &FA);
+  llvm::Metadata *EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation);
+  void LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation);
 
   // Function props.
   llvm::MDTuple *EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,

+ 25 - 1
include/dxc/DXIL/DxilTypeSystem.h

@@ -90,6 +90,22 @@ private:
   std::string m_FieldName;
 };
 
+class DxilTemplateArgAnnotation : DxilFieldAnnotation {
+public:
+  DxilTemplateArgAnnotation();
+
+  bool IsType() const;
+  const llvm::Type *GetType() const;
+  void SetType(const llvm::Type *pType);
+
+  bool IsIntegral() const;
+  int64_t GetIntegral() const;
+  void SetIntegral(int64_t i64);
+
+private:
+  const llvm::Type *m_Type;
+  int64_t m_Integral;
+};
 
 /// Use this class to represent LLVM structure annotation.
 class DxilStructAnnotation {
@@ -104,10 +120,18 @@ public:
   void SetCBufferSize(unsigned size);
   void MarkEmptyStruct();
   bool IsEmptyStruct();
+
+  // For template args, GetNumTemplateArgs() will return 0 if not a template
+  unsigned GetNumTemplateArgs() const;
+  void SetNumTemplateArgs(unsigned count);
+  DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx);
+  const DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx) const;
+
 private:
   const llvm::StructType *m_pStructType;
   std::vector<DxilFieldAnnotation> m_FieldAnnotations;
   unsigned m_CBufferSize;  // The size of struct if inside constant buffer.
+  std::vector<DxilTemplateArgAnnotation> m_TemplateAnnotations;
 };
 
 
@@ -163,7 +187,7 @@ public:
 
   DxilTypeSystem(llvm::Module *pModule);
 
-  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType);
+  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType, unsigned numTemplateArgs = 0);
   DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType);
   const DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType) const;
   void EraseStructAnnotation(const llvm::StructType *pStructType);

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -109,6 +109,7 @@ namespace dxilutil {
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLRayQueryType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 

+ 3 - 0
include/dxc/HLSL/HLOperations.h

@@ -333,6 +333,9 @@ const unsigned kCreateHandleIndexOpIdx = 2; // Only for array of cbuffer.
 const unsigned kTraceRayRayDescOpIdx = 7;
 const unsigned kTraceRayPayLoadOpIdx = 8;
 
+// TraceRayInline.
+const unsigned kTraceRayInlineRayDescOpIdx = 5;
+
 // ReportIntersection.
 const unsigned kReportIntersectionAttributeOpIdx = 3;
 

+ 1 - 0
include/dxc/HlslIntrinsicOp.h

@@ -260,6 +260,7 @@ import hctdb_instrhelp
   MOP_DecrementCounter,
   MOP_IncrementCounter,
   MOP_Consume,
+  MOP_TraceRayInline,
 #ifdef ENABLE_SPIRV_CODEGEN
   MOP_SubpassLoad,
 #endif // ENABLE_SPIRV_CODEGEN

+ 71 - 2
lib/DXIL/DxilMetadataHelper.cpp

@@ -774,13 +774,62 @@ void DxilMDHelper::LoadDxilTypeSystem(DxilTypeSystem &TypeSystem) {
   }
 }
 
+Metadata *DxilMDHelper::EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation) {
+  SmallVector<Metadata *, 2> MDVals;
+  if (annotation.IsType()) {
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgTypeTag));
+    MDVals.emplace_back(ValueAsMetadata::get(UndefValue::get(const_cast<Type*>(annotation.GetType()))));
+  } else if (annotation.IsIntegral()) {
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgIntegralTag));
+    MDVals.emplace_back(Uint64ToConstMD((uint64_t)annotation.GetIntegral()));
+  }
+  return MDNode::get(m_Ctx, MDVals);
+}
+void DxilMDHelper::LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() >= 1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned Tag = ConstMDToUint32(pTupleMD->getOperand(0));
+  switch (Tag) {
+  case kDxilTemplateArgTypeTag:
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetType(MetadataAsValue::get(m_Ctx,
+      pTupleMD->getOperand(kDxilTemplateArgType))->getType());
+    break;
+  case kDxilTemplateArgIntegralTag:
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetIntegral((int64_t)ConstMDToUint64(pTupleMD->getOperand(kDxilTemplateArgType)));
+    break;
+  }
+}
+
 Metadata *DxilMDHelper::EmitDxilStructAnnotation(const DxilStructAnnotation &SA) {
-  vector<Metadata *> MDVals(SA.GetNumFields() + 1);
+  unsigned valMajor, valMinor;
+  m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  bool bSupportExtended = !(valMajor == 1 && valMinor < 5);
+
+  vector<Metadata *> MDVals;
+  MDVals.reserve(SA.GetNumFields() + 2);  // In case of extended 1.5 property list
+  MDVals.resize(SA.GetNumFields() + 1);
+
   MDVals[0] = Uint32ToConstMD(SA.GetCBufferSize());
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
     MDVals[i+1] = EmitDxilFieldAnnotation(SA.GetFieldAnnotation(i));
   }
 
+  // Only add template args if shader target requires validator version that supports them.
+  if (bSupportExtended && SA.GetNumTemplateArgs()) {
+    vector<Metadata *> MDTemplateArgs(SA.GetNumTemplateArgs());
+    for (unsigned i = 0; i < SA.GetNumTemplateArgs(); ++i) {
+      MDTemplateArgs[i] = EmitDxilTemplateArgAnnotation(SA.GetTemplateArgAnnotation(i));
+    }
+    SmallVector<Metadata *, 2> MDExtraVals;
+    MDExtraVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgumentsTag));
+    MDExtraVals.emplace_back(MDNode::get(m_Ctx, MDTemplateArgs));
+    MDVals.emplace_back(MDNode::get(m_Ctx, MDExtraVals));
+  }
+
   return MDNode::get(m_Ctx, MDVals);
 }
 
@@ -791,7 +840,27 @@ void DxilMDHelper::LoadDxilStructAnnotation(const MDOperand &MDO, DxilStructAnno
   if (pTupleMD->getNumOperands() == 1) {
     SA.MarkEmptyStruct();
   }
-  IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned valMajor, valMinor;
+  m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  if (!(valMajor == 1 && valMinor < 5) &&
+      (pTupleMD->getNumOperands() == SA.GetNumFields()+2)) {
+    // Load template args from extended operand
+    const MDOperand &MDOExtra = pTupleMD->getOperand(SA.GetNumFields()+1);
+    const MDTuple *pTupleMDExtra = dyn_cast_or_null<MDTuple>(MDOExtra.get());
+    if(pTupleMDExtra) {
+      IFTBOOL(pTupleMDExtra->getNumOperands() % 2 == 0, DXC_E_INCORRECT_DXIL_METADATA);
+      unsigned Tag = ConstMDToUint32(pTupleMDExtra->getOperand(0));
+      IFTBOOL(Tag == kDxilTemplateArgumentsTag, DXC_E_INCORRECT_DXIL_METADATA); // Only one allowed at this point
+      const MDTuple *pTupleTemplateArgs = dyn_cast_or_null<MDTuple>(pTupleMDExtra->getOperand(1).get());
+      IFTBOOL(pTupleTemplateArgs, DXC_E_INCORRECT_DXIL_METADATA);
+      SA.SetNumTemplateArgs(pTupleTemplateArgs->getNumOperands());
+      for (unsigned i = 0; i < pTupleTemplateArgs->getNumOperands(); ++i) {
+        LoadDxilTemplateArgAnnotation(pTupleTemplateArgs->getOperand(i), SA.GetTemplateArgAnnotation(i));
+      }
+    }
+  } else {
+    IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  }
 
   SA.SetCBufferSize(ConstMDToUint32(pTupleMD->getOperand(0)));
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {

+ 16 - 2
lib/DXIL/DxilOperations.cpp

@@ -321,6 +321,12 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  //                                                                                                                         void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::AllocateRayQuery,        "AllocateRayQuery",         OCC::AllocateRayQuery,         "allocateRayQuery",          {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  // Inline Ray Query                                                                                                        void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::TraceRayInline,          "TraceRayInline",           OCC::TraceRayInline,           "traceRayInline",            {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
 };
 // OPCODE-OLOADS:END
 
@@ -657,8 +663,8 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
     return;
   }
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167)) {
+  // WaveMultiPrefixBitCount=167, TraceRayInline=169
+  if ((165 <= op && op <= 167) || op == 169) {
     major = 6;  minor = 5;
     return;
   }
@@ -1062,6 +1068,12 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
   case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
   case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
+
+    // 
+  case OpCode::AllocateRayQuery:       A(pI32);     A(pI32); A(pI32); break;
+
+    // Inline Ray Query
+  case OpCode::TraceRayInline:         A(pV);       A(pI32); A(pI32); A(pRes); A(pI32); A(pI32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); break;
   // OPCODE-OLOAD-FUNCS:END
   default: DXASSERT(false, "otherwise unhandled case"); break;
   }
@@ -1215,6 +1227,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::IgnoreHit:
   case OpCode::AcceptHitAndEndSearch:
   case OpCode::WaveMultiPrefixBitCount:
+  case OpCode::AllocateRayQuery:
+  case OpCode::TraceRayInline:
     return Type::getVoidTy(m_Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:

+ 34 - 1
lib/DXIL/DxilTypeSystem.cpp

@@ -78,6 +78,22 @@ const std::string &DxilFieldAnnotation::GetFieldName() const { return m_FieldNam
 void DxilFieldAnnotation::SetFieldName(const std::string &FieldName) { m_FieldName = FieldName; }
 
 
+//------------------------------------------------------------------------------
+//
+// DxilStructAnnotation class methods.
+//
+DxilTemplateArgAnnotation::DxilTemplateArgAnnotation()
+    : DxilFieldAnnotation(), m_Type(nullptr), m_Integral(0)
+{}
+
+bool DxilTemplateArgAnnotation::IsType() const { return m_Type != nullptr; }
+const llvm::Type *DxilTemplateArgAnnotation::GetType() const { return m_Type; }
+void DxilTemplateArgAnnotation::SetType(const llvm::Type *pType) { m_Type = pType; }
+
+bool DxilTemplateArgAnnotation::IsIntegral() const { return m_Type == nullptr; }
+int64_t DxilTemplateArgAnnotation::GetIntegral() const { return m_Integral; }
+void DxilTemplateArgAnnotation::SetIntegral(int64_t i64) { m_Type = nullptr; m_Integral = i64; }
+
 //------------------------------------------------------------------------------
 //
 // DxilStructAnnotation class methods.
@@ -103,6 +119,22 @@ void DxilStructAnnotation::SetCBufferSize(unsigned size) { m_CBufferSize = size;
 void DxilStructAnnotation::MarkEmptyStruct() { m_FieldAnnotations.clear(); }
 bool DxilStructAnnotation::IsEmptyStruct() { return m_FieldAnnotations.empty(); }
 
+// For template args, GetNumTemplateArgs() will return 0 if not a template
+unsigned DxilStructAnnotation::GetNumTemplateArgs() const {
+  return (unsigned)m_TemplateAnnotations.size();
+}
+void DxilStructAnnotation::SetNumTemplateArgs(unsigned count) {
+  DXASSERT(m_TemplateAnnotations.empty(), "template args already initialized");
+  m_TemplateAnnotations.resize(count);
+}
+DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) {
+  return m_TemplateAnnotations[argIdx];
+}
+const DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) const {
+  return m_TemplateAnnotations[argIdx];
+}
+
+
 //------------------------------------------------------------------------------
 //
 // DxilParameterAnnotation class methods.
@@ -166,12 +198,13 @@ DxilTypeSystem::DxilTypeSystem(Module *pModule)
     : m_pModule(pModule),
       m_LowPrecisionMode(DXIL::LowPrecisionMode::Undefined) {}
 
-DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType) {
+DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType, unsigned numTemplateArgs) {
   DXASSERT_NOMSG(m_StructAnnotations.find(pStructType) == m_StructAnnotations.end());
   DxilStructAnnotation *pA = new DxilStructAnnotation();
   m_StructAnnotations[pStructType] = unique_ptr<DxilStructAnnotation>(pA);
   pA->m_pStructType = pStructType;
   pA->m_FieldAnnotations.resize(pStructType->getNumElements());
+  pA->SetNumTemplateArgs(numTemplateArgs);
   return pA;
 }
 

+ 14 - 0
lib/DXIL/DxilUtil.cpp

@@ -546,6 +546,20 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
       return true;
     if (name.startswith("LineStream<"))
       return true;
+
+    if (name.startswith("RayQuery<"))
+      return true;
+  }
+  return false;
+}
+
+bool IsHLSLRayQueryType(llvm::Type *Ty) {
+  if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
+    StringRef name = ST->getName();
+    // TODO: don't check names.
+    name = name.ltrim("class.");
+    if (name.startswith("RayQuery<"))
+      return true;
   }
   return false;
 }

+ 1 - 1
lib/DxilPIXPasses/DxilShaderAccessTracking.cpp

@@ -453,7 +453,7 @@ bool DxilShaderAccessTracking::runOnModule(Module &M)
 
 
     // todo: should "GetDimensions" mean a resource access?
-    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(168), "Please update PIX passes if any resource access opcodes are added");
+    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(170), "Please update PIX passes if any resource access opcodes are added");
     ResourceAccessFunction raFunctions[] = {
       { DXIL::OpCode::CBufferLoadLegacy     , ShaderAccessFlags::Read   , false, f32i32f64 },
       { DXIL::OpCode::CBufferLoad           , ShaderAccessFlags::Read   , false, f16f32f64i16i32i64 },

+ 2 - 2
lib/HLSL/DxilValidation.cpp

@@ -856,8 +856,8 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   if ((162 <= op && op <= 164))
     return (major > 6 || (major == 6 && minor >= 4));
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167))
+  // WaveMultiPrefixBitCount=167, TraceRayInline=169
+  if ((165 <= op && op <= 167) || op == 169)
     return (major > 6 || (major == 6 && minor >= 5));
   return true;
   // VALOPCODESM-TEXT:END

+ 111 - 0
lib/HLSL/HLOperationLower.cpp

@@ -4637,6 +4637,108 @@ Value *TranslateTraceRay(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   return Builder.CreateCall(F, Args);
 }
 
+void AllocateRayQueryObjects(llvm::Module *M, HLOperationLowerHelper &helper) {
+  // Iterate functions and insert AllocateRayQuery intrinsic to initialize
+  // handle value for every alloca of ray query type
+  hlsl::OP &hlslOP = helper.hlslOP;
+  Constant *i32Zero = hlslOP.GetI32Const(0);
+  DXIL::OpCode opcode = DXIL::OpCode::AllocateRayQuery;
+  llvm::Value *opcodeVal = hlslOP.GetU32Const(static_cast<unsigned>(opcode));
+  for (Function &f : M->functions()) {
+    if (f.isDeclaration() || f.isIntrinsic() ||
+      GetHLOpcodeGroup(&f) != HLOpcodeGroup::NotHL)
+      continue;
+    // Iterate allocas
+    BasicBlock &BB = f.getEntryBlock();
+    IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(&BB));
+    for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) {
+      // Avoid invalidating the iterator.
+      Instruction *I = BI++;
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
+        llvm::Type *allocaTy = AI->getAllocatedType();
+        llvm::Type *elementTy = allocaTy;
+        while (elementTy->isArrayTy())
+          elementTy = elementTy->getArrayElementType();
+        if (dxilutil::IsHLSLRayQueryType(elementTy)) {
+          DxilStructAnnotation *SA = helper.dxilTypeSys.GetStructAnnotation(cast<StructType>(elementTy));
+          DXASSERT(SA, "otherwise, could not find type annoation for RayQuery specialization");
+          DXASSERT(SA->GetNumTemplateArgs() == 1 && SA->GetTemplateArgAnnotation(0).IsIntegral(),
+                   "otherwise, RayQuery has changed, or lacks template args");
+          Builder.SetInsertPoint(AI->getNextNode());
+          DXASSERT(!allocaTy->isArrayTy(), "Array not handled yet");
+          llvm::Function *AllocFn = hlslOP.GetOpFunc(DXIL::OpCode::AllocateRayQuery, Builder.getVoidTy());
+          llvm::Value *rayFlags = ConstantInt::get(helper.i32Ty,
+            APInt(32, SA->GetTemplateArgAnnotation(0).GetIntegral()));
+          llvm::CallInst *CI = Builder.CreateCall(AllocFn, {opcodeVal, rayFlags}, "hRayQuery");
+          llvm::Value *GEP = Builder.CreateGEP(AI, {i32Zero, i32Zero});
+          Builder.CreateStore(CI, GEP);
+        }
+      }
+    }
+  }
+}
+
+Value *TranslateTraceRayInline(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *hlslOP = &helper.hlslOP;
+
+  Value *rayDesc = CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx);
+
+  Value *opArg = hlslOP->GetU32Const(static_cast<unsigned>(opcode));
+
+  Value *Args[DXIL::OperandIndex::kTraceRayInlineNumOp];
+  Args[0] = opArg;
+
+  // Translate RayQuery `this` pointer to i32 handle by-value
+  IRBuilder<> Builder(CI);
+  Value *thisArg = CI->getArgOperand(1);
+  Constant *i32Zero = hlslOP->GetI32Const(0);
+  Value *handleGEP = Builder.CreateGEP(thisArg, {i32Zero, i32Zero});
+  Value *handleValue = Builder.CreateLoad(handleGEP);
+  Args[1] = handleValue;
+
+  for (unsigned i = 2; i < HLOperandIndex::kTraceRayInlineRayDescOpIdx; i++) {
+    Args[i] = CI->getArgOperand(i);
+  }
+  // struct RayDesc
+  //{
+  //    float3 Origin;
+  //    float  TMin;
+  //    float3 Direction;
+  //    float  TMax;
+  //};
+  Value *zeroIdx = hlslOP->GetU32Const(0);
+  Value *origin = Builder.CreateGEP(rayDesc, {zeroIdx, zeroIdx});
+  origin = Builder.CreateLoad(origin);
+  unsigned index = DXIL::OperandIndex::kTraceRayInlineRayDescOpIdx;
+  Args[index++] = Builder.CreateExtractElement(origin, (uint64_t)0);
+  Args[index++] = Builder.CreateExtractElement(origin, 1);
+  Args[index++] = Builder.CreateExtractElement(origin, 2);
+
+  Value *tmin = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(1)});
+  tmin = Builder.CreateLoad(tmin);
+  Args[index++] = tmin;
+
+  Value *direction = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(2)});
+  direction = Builder.CreateLoad(direction);
+
+  Args[index++] = Builder.CreateExtractElement(direction, (uint64_t)0);
+  Args[index++] = Builder.CreateExtractElement(direction, 1);
+  Args[index++] = Builder.CreateExtractElement(direction, 2);
+
+  Value *tmax = Builder.CreateGEP(rayDesc, {zeroIdx, hlslOP->GetU32Const(3)});
+  tmax = Builder.CreateLoad(tmax);
+  Args[index++] = tmax;
+
+  DXASSERT_NOMSG(index == DXIL::OperandIndex::kTraceRayInlineNumOp);
+
+  Function *F = hlslOP->GetOpFunc(opcode, Builder.getVoidTy());
+
+  return Builder.CreateCall(F, Args);
+}
+
 Value *TranslateNoArgVectorOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                          HLOperationLowerHelper &helper,
                          HLObjectOperationLowerHelper *pObjHelper,
@@ -5029,6 +5131,8 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::MOP_IncrementCounter, GenerateUpdateCounter, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_Consume, EmptyLower, DXIL::OpCode::NumOpCodes},
 
+    {IntrinsicOp::MOP_TraceRayInline, TranslateTraceRayInline, DXIL::OpCode::TraceRayInline},
+
     // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
     {IntrinsicOp::MOP_SubpassLoad, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes},
@@ -5769,6 +5873,11 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       }
 
       CI->eraseFromParent();
+    } else if (group == HLOpcodeGroup::HLIntrinsic) {
+      // FIXME: This case is hit when using built-in structures in constant
+      //        buffers passed directly to an intrinsic, such as:
+      //        RayDesc from cbuffer passed to TraceRay.
+      DXASSERT(0, "not implemented yet");
     } else {
       DXASSERT(0, "not implemented yet");
     }
@@ -7318,6 +7427,8 @@ void TranslateBuiltinOperations(
 
   Module *M = HLM.GetModule();
 
+  AllocateRayQueryObjects(M, helper);
+
   SmallVector<Function *, 4> NonUniformResourceIndexIntrinsics;
 
   // generate dxil operation

+ 17 - 4
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1586,9 +1586,9 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       SimpleCopy(Dest, Src, idxList, Builder);
       return;
     }
+    // Built-in structs have no type annotation
     DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
-    DXASSERT(STA, "require annotation here");
-    if (STA->IsEmptyStruct())
+    if (STA && STA->IsEmptyStruct())
       return;
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       llvm::Type *ET = ST->getElementType(i);
@@ -1598,8 +1598,8 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
         EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
       } else {
-        DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
-        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, &EltAnnotation,
+        DxilFieldAnnotation *EltAnnotation = STA ? &STA->GetFieldAnnotation(i) : nullptr;
+        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, EltAnnotation,
                  bEltMemCpy);
       }
 
@@ -2412,6 +2412,12 @@ void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
 }
 
 void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
+  // Unused bitcast may be leftover from temporary memcpy
+  if (BCI->use_empty()) {
+    BCI->eraseFromParent();
+    return;
+  }
+
   Type *DstTy = BCI->getType();
   Value *Val = BCI->getOperand(0);
   Type *SrcTy = Val->getType();
@@ -2565,6 +2571,13 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
         RewriteCallArg(CI, HLOperandIndex::kBinaryOpSrc1Idx,
                        /*bIn*/ true, /*bOut*/ true);
       } break;
+      case IntrinsicOp::MOP_TraceRayInline: {
+        if (OldVal ==
+            CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
+          RewriteCallArg(CI, HLOperandIndex::kTraceRayInlineRayDescOpIdx,
+                         /*bIn*/ true, /*bOut*/ false);
+        }
+      } break;
       default:
         DXASSERT(0, "cannot flatten hlsl intrinsic.");
       }

+ 5 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -326,6 +326,11 @@ void AddTemplateTypeWithHandle(
             uint8_t templateArgCount,
   _In_opt_  clang::TypeSourceInfo* defaultTypeArgValue);
 
+void AddRayQueryTemplate(
+           clang::ASTContext& context,
+  _Outptr_ clang::ClassTemplateDecl** typeDecl,
+  _Outptr_ clang::CXXRecordDecl** recordDecl);
+
 /// <summary>Create a function template declaration for the specified method.</summary>
 /// <param name="context">AST context in which to work.</param>
 /// <param name="recordDecl">Class in which the function template is declared.</param>

+ 102 - 0
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -38,6 +38,7 @@ static const bool DelayTypeCreationTrue = true;   // delay type creation for a d
 static const SourceLocation NoLoc;                // no source location attribution available
 static const bool InlineFalse = false;            // namespace is not an inline namespace
 static const bool InlineSpecifiedFalse = false;   // function was not specified as inline
+static const bool ExplicitFalse = false;          // constructor was not specified as explicit
 static const bool IsConstexprFalse = false;       // function is not constexpr
 static const bool VirtualFalse = false;           // whether the base class is declares 'virtual'
 static const bool BaseClassFalse = false;         // whether the base class is declared as 'class' (vs. 'struct')
@@ -894,6 +895,28 @@ void AssociateParametersToFunctionPrototype(
   }
 }
 
+static void CreateConstructorDeclaration(
+  ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
+  ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
+  _Out_ CXXConstructorDecl **constructorDecl, _Out_ TypeSourceInfo **tinfo) {
+  DXASSERT_NOMSG(recordDecl != nullptr);
+  DXASSERT_NOMSG(constructorDecl != nullptr);
+
+  FunctionProtoType::ExtProtoInfo functionExtInfo;
+  functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
+  QualType functionQT = context.getFunctionType(
+    resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
+  DeclarationNameInfo declNameInfo(declarationName, NoLoc);
+  *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
+  DXASSERT_NOMSG(*tinfo != nullptr);
+  *constructorDecl = CXXConstructorDecl::Create(
+    context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
+    StorageClass::SC_None, ExplicitFalse, InlineSpecifiedFalse, IsConstexprFalse);
+  DXASSERT_NOMSG(*constructorDecl != nullptr);
+  (*constructorDecl)->setLexicalDeclContext(recordDecl);
+  (*constructorDecl)->setAccess(AccessSpecifier::AS_public);
+}
+
 static void CreateObjectFunctionDeclaration(
     ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
     ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
@@ -959,6 +982,85 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
   return functionDecl;
 }
 
+void hlsl::AddRayQueryTemplate(
+  ASTContext& context,
+  _Outptr_ ClassTemplateDecl** typeDecl,
+  _Outptr_ CXXRecordDecl** recordDecl
+)
+{
+  DXASSERT_NOMSG(typeDecl != nullptr);
+  DXASSERT_NOMSG(recordDecl != nullptr);
+
+  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
+
+  // Create a RayQuery template declaration in translation unit scope.
+  // template<uint flags> RayQuery { ... }
+  QualType uintType = context.UnsignedIntTy;
+
+  NonTypeTemplateParmDecl* flagsTemplateParamDecl = nullptr;
+  IdentifierInfo& countParamId = context.Idents.get(StringRef("flags"), tok::TokenKind::identifier);
+  flagsTemplateParamDecl = NonTypeTemplateParmDecl::Create(
+    context, currentDeclContext, NoLoc, NoLoc,
+    FirstTemplateDepth, FirstParamPosition, &countParamId, uintType, ParameterPackFalse, nullptr);
+
+  // Should flags default to zero?
+  Expr *literalIntZero = IntegerLiteral::Create(
+    context, llvm::APInt(context.getIntWidth(uintType), 0), uintType, NoLoc);
+  flagsTemplateParamDecl->setDefaultArgument(literalIntZero);
+
+  NamedDecl* templateParameters[] =
+  {
+    flagsTemplateParamDecl
+  };
+  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
+    context, NoLoc, NoLoc, templateParameters, 1, NoLoc);
+
+  IdentifierInfo& typeId = context.Idents.get(StringRef("RayQuery"), tok::TokenKind::identifier);
+  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
+    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
+    nullptr, DelayTypeCreationTrue);
+  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
+    context, currentDeclContext, NoLoc, DeclarationName(&typeId),
+    templateParameterList, templateRecordDecl, nullptr);
+  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
+  templateRecordDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
+
+  // Requesting the class name specialization will fault in required types.
+  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
+  T = context.getInjectedClassNameType(templateRecordDecl, T);
+  assert(T->isDependentType() && "Class template type is not dependent?");
+  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
+  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
+  templateRecordDecl->startDefinition();
+
+  // TODO: Add constructor that will be lowered to the intrinsic that produces
+  // the RayQuery handle for this object.
+  CanQualType canQualType = templateRecordDecl->getTypeForDecl()->getCanonicalTypeUnqualified();
+  CXXConstructorDecl *pConstructorDecl = nullptr;
+  TypeSourceInfo *pTypeSourceInfo = nullptr;
+  CreateConstructorDeclaration(context, templateRecordDecl, context.VoidTy, {}, context.DeclarationNames.getCXXConstructorName(canQualType), false, &pConstructorDecl, &pTypeSourceInfo);
+  templateRecordDecl->addDecl(pConstructorDecl);
+
+  // Add an 'h' field to hold the handle.
+  AddHLSLHandleField(context, templateRecordDecl, uintType);
+
+  templateRecordDecl->completeDefinition();
+
+  // Both declarations need to be present for correct handling.
+  currentDeclContext->addDecl(classTemplateDecl);
+  currentDeclContext->addDecl(templateRecordDecl);
+
+#ifdef DBG
+  // Verify that we can read the field member from the template record.
+  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
+    DeclarationName(&context.Idents.get(StringRef("h"))));
+  DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
+#endif
+
+  *typeDecl = classTemplateDecl;
+  *recordDecl = templateRecordDecl;
+}
+
 bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
   return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
 }

+ 43 - 12
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -238,6 +238,9 @@ private:
 
   std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
 
+  // Insert AllocateRayQuery to initialize each RayQuery alloca
+  void AllocateRayQueryObjects();
+
 public:
   CGMSHLSLRuntime(CodeGenModule &CGM);
 
@@ -857,6 +860,27 @@ unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annota
   unsigned offset = 0;
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+
+    // If template, save template args
+    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
+      const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
+      for (unsigned i = 0; i < args.size(); ++i) {
+        DxilTemplateArgAnnotation &argAnnotation = annotation->GetTemplateArgAnnotation(i);
+        const clang::TemplateArgument &arg = args[i];
+        switch (arg.getKind()) {
+        case clang::TemplateArgument::ArgKind::Type:
+          argAnnotation.SetType(CGM.getTypes().ConvertType(arg.getAsType()));
+        break;
+        case clang::TemplateArgument::ArgKind::Integral:
+          argAnnotation.SetIntegral(arg.getAsIntegral().getExtValue());
+          break;
+        default:
+          break;
+        }
+      }
+    }
+
     if (CXXRD->getNumBases()) {
       // Add base as field.
       for (const auto &I : CXXRD->bases()) {
@@ -963,6 +987,17 @@ static bool IsElementInputOutputType(QualType Ty) {
   return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
 }
 
+static unsigned GetNumTemplateArgsForRecordDecl(const RecordDecl *RD) {
+  if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
+      const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
+      return args.size();
+    }
+  }
+  return 0;
+}
+
 // Return the size for constant buffer of each decl.
 unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
                                             DxilTypeSystem &dxilTypeSys,
@@ -1001,7 +1036,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
@@ -1013,7 +1049,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (IsHLSLResourceType(Ty)) {
@@ -3612,10 +3649,8 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     llvm::Type *Ty = paramTyList[i];
     if (Ty->isPointerTy()) {
       Ty = Ty->getPointerElementType();
-      if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
-        // Use handle type for object type.
+      if (dxilutil::IsHLSLResourceType(Ty)) {
+        // Use handle type for resource type.
         // This will make sure temp object variable only used by createHandle.
         paramTyList[i] = HandleTy;
       }
@@ -3673,7 +3708,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
     llvm::Type *resTy = nullptr;
     while (GEPIt != E) {
-      if (dxilutil::IsHLSLObjectType(*GEPIt)) {
+      if (dxilutil::IsHLSLResourceType(*GEPIt)) {
         resTy = *GEPIt;
         break;
       }
@@ -3756,9 +3791,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
       llvm::Type *Ty = arg->getType();
       if (Ty->isPointerTy()) {
         Ty = Ty->getPointerElementType();
-        if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
+        if (dxilutil::IsHLSLResourceType(Ty)) {
           // Use object type directly, not by pointer.
           // This will make sure temp object variable only used by ld/st.
           if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
@@ -4770,8 +4803,6 @@ static void CreateWriteEnabledStaticGlobals(llvm::Module *M,
   }
 }
 
-
-
 void CGMSHLSLRuntime::FinishCodeGen() {
   // Library don't have entry.
   if (!m_bIsLib) {

+ 42 - 4
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -199,6 +199,9 @@ enum ArBasicKind {
   AR_OBJECT_TRIANGLE_HIT_GROUP,
   AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  // RayQuery
+  AR_OBJECT_RAY_QUERY,
+
   AR_BASIC_MAXIMUM_COUNT
 };
 
@@ -476,6 +479,8 @@ const UINT g_uBasicKindProps[] =
   0,      //AR_OBJECT_TRIANGLE_HIT_GROUP,
   0,      //AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  0,      //AR_OBJECT_RAY_QUERY,
+
   // AR_BASIC_MAXIMUM_COUNT
 };
 
@@ -1286,7 +1291,9 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   AR_OBJECT_RAYTRACING_SHADER_CONFIG,
   AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   AR_OBJECT_TRIANGLE_HIT_GROUP,
-  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP
+  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  AR_OBJECT_RAY_QUERY
 };
 
 // Count of template arguments for basic kind of objects that look like templates (one or more type arguments).
@@ -1366,6 +1373,8 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   0, // AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   0, // AR_OBJECT_TRIANGLE_HIT_GROUP,
   0, // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  1, // AR_OBJECT_RAY_QUERY,
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsTemplateCount));
@@ -1456,6 +1465,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_HIT_GROUP,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_QUERY,
 };
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts));
@@ -1568,7 +1578,9 @@ const char* g_ArBasicTypeNames[] =
   "RaytracingShaderConfig",
   "RaytracingPipelineConfig",
   "TriangleHitGroup",
-  "ProceduralPrimitiveHitGroup"
+  "ProceduralPrimitiveHitGroup",
+
+  "RayQuery"
 };
 
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
@@ -2120,7 +2132,11 @@ void GetIntrinsicMethods(ArBasicKind kind, _Outptr_result_buffer_(*intrinsicCoun
     *intrinsics = g_ConsumeStructuredBufferMethods;
     *intrinsicCount = _countof(g_ConsumeStructuredBufferMethods);
     break;
-  // SPIRV change starts
+  case AR_OBJECT_RAY_QUERY:
+    *intrinsics = g_RayQueryMethods;
+    *intrinsicCount = _countof(g_RayQueryMethods);
+    break;
+    // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
   case AR_OBJECT_VK_SUBPASS_INPUT:
     *intrinsics = g_VkSubpassInputMethods;
@@ -3219,6 +3235,12 @@ private:
           recordDecl = CreateSubobjectProceduralPrimitiveHitGroup(*m_context);
           break;
         }
+      } else if (kind == AR_OBJECT_RAY_QUERY) {
+        ClassTemplateDecl* typeDecl = nullptr;
+        AddRayQueryTemplate(*m_context, &typeDecl, &recordDecl);
+        DXASSERT(typeDecl != nullptr, "AddRayQueryTemplate failed to return the object declaration");
+        typeDecl->setImplicit(true);
+        recordDecl->setImplicit(true);
       }
       else if (templateArgCount == 0)
       {
@@ -3419,6 +3441,13 @@ public:
     return IsSubobjectBasicKind(GetTypeElementKind(type));
   }
 
+  bool IsRayQueryBasicKind(ArBasicKind kind) {
+    return kind == AR_OBJECT_RAY_QUERY;
+  }
+  bool IsRayQueryType(QualType type) {
+    return IsRayQueryBasicKind(GetTypeElementKind(type));
+  }
+
   void WarnMinPrecision(HLSLScalarType type, SourceLocation loc) {
     // TODO: enalbe this once we introduce precise master option
     bool UseMinPrecision = m_context->getLangOpts().UseMinPrecision;
@@ -4882,7 +4911,8 @@ QualType GetFirstElementTypeFromDecl(const Decl* decl)
   if (specialization) {
     const TemplateArgumentList& list = specialization->getTemplateArgs();
     if (list.size()) {
-      return list[0].getAsType();
+      if (list[0].getKind() == TemplateArgument::ArgKind::Type)
+        return list[0].getAsType();
     }
   }
 
@@ -6744,6 +6774,14 @@ void HLSLExternalSource::InitializeInitSequenceForHLSL(
 
   // In HLSL there are no default initializers, eg float4x4 m();
   if (Kind.getKind() == InitializationKind::IK_Default) {
+    // Except for RayQuery.
+    if (GetTypeElementKind(Entity.getType()) == AR_OBJECT_RAY_QUERY) {
+      // RayQuery handle initialization
+      // TODO: Try generating an intrinsic method call for AllocateRayQuery
+      // - pass the flags from the intrinsic argument.
+      // Lower to intrinsic that takes flags and returns i32 value,
+      // which is then stored in the RayQuery alloca as the handle.
+    }
     return;
   }
 

+ 19 - 0
tools/clang/lib/Sema/gen_intrin_main_tables_15.h

@@ -5787,6 +5787,24 @@ static const HLSL_INTRINSIC g_ConsumeStructuredBufferMethods[] =
     {(UINT)hlsl::IntrinsicOp::MOP_GetDimensions, false, false, -1, 3, g_ConsumeStructuredBufferMethods_Args1},
 };
 
+//
+// Start of RayQueryMethods
+//
+
+static const HLSL_INTRINSIC_ARGUMENT g_RayQueryMethods_Args0[] =
+{
+    {"TraceRayInline", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0},
+    {"AccelerationStructure", AR_QUAL_IN, 1, LITEMPLATE_OBJECT, 1, LICOMPTYPE_ACCELERATION_STRUCT, 1, 1},
+    {"RayFlags", AR_QUAL_IN, 2, LITEMPLATE_SCALAR, 2, LICOMPTYPE_UINT, 1, 1},
+    {"InstanceInclusionMask", AR_QUAL_IN, 3, LITEMPLATE_SCALAR, 3, LICOMPTYPE_UINT, 1, 1},
+    {"Ray", AR_QUAL_IN, 4, LITEMPLATE_OBJECT, 4, LICOMPTYPE_RAYDESC, 1, 1},
+};
+
+static const HLSL_INTRINSIC g_RayQueryMethods[] =
+{
+    {(UINT)hlsl::IntrinsicOp::MOP_TraceRayInline, false, false, -1, 5, g_RayQueryMethods_Args0},
+};
+
 //
 // Start of VkSubpassInputMethods
 //
@@ -5840,6 +5858,7 @@ static const UINT g_uRWTexture1DMethodsCount = 4;
 static const UINT g_uRWTexture2DArrayMethodsCount = 4;
 static const UINT g_uRWTexture2DMethodsCount = 4;
 static const UINT g_uRWTexture3DMethodsCount = 4;
+static const UINT g_uRayQueryMethodsCount = 1;
 static const UINT g_uStreamMethodsCount = 2;
 static const UINT g_uStructuredBufferMethodsCount = 3;
 static const UINT g_uTexture1DArrayMethodsCount = 31;

+ 19 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tracerayinline.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T vs_6_5 -E main %s | FileCheck %s
+
+// CHECK: %[[RTAS:[^ ]+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 0, i1 false)
+// CHECK: %[[RQ:[^ ]+]] = call i32 @dx.op.allocateRayQuery(i32 168, i32 1)
+// CHECK: call void @dx.op.traceRayInline(i32 169, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 0, i32 1,
+// CHECK: call void @dx.op.traceRayInline(i32 169, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 1, i32 2,
+
+RaytracingAccelerationStructure RTAS;
+
+void DoTrace(RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery, RayDesc rayDesc) {
+  rayQuery.TraceRayInline(RTAS, 0, 1, rayDesc);
+}
+
+float main(RayDesc rayDesc : RAYDESC) : OUT {
+  RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery;
+  DoTrace(rayQuery, rayDesc);
+  rayQuery.TraceRayInline(RTAS, 1, 2, rayDesc);
+  return 0;
+}

+ 3 - 1
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -1182,7 +1182,9 @@ static const char *OpCodeSignatures[] = {
   "(acc,a,b)",  // Dot4AddU8Packed
   "(value)",  // WaveMatch
   "(value,mask0,mask1,mask2,mask3,op,sop)",  // WaveMultiPrefixOp
-  "(value,mask0,mask1,mask2,mask3)"  // WaveMultiPrefixBitCount
+  "(value,mask0,mask1,mask2,mask3)",  // WaveMultiPrefixBitCount
+  "(constRayFlags)",  // AllocateRayQuery
+  "(rayQueryHandle,accelerationStructure,rayFlags,instanceInclusionMask,origin_X,origin_Y,origin_Z,tMin,direction_X,direction_Y,direction_Z,tMax)"  // TraceRayInline
 };
 // OPCODE-SIGS:END
 

+ 6 - 0
utils/hct/gen_intrin_main.txt

@@ -834,6 +834,12 @@ $classT [[]] Consume() : structuredbuffer_consume;
 
 } namespace
 
+namespace RayQueryMethods {
+
+void [[]]  TraceRayInline(in acceleration_struct AccelerationStructure, in uint RayFlags, in uint InstanceInclusionMask, in ray_desc Ray);
+
+} namespace
+
 // SPIRV Change Starts
 
 namespace VkSubpassInputMethods {

+ 26 - 2
utils/hct/hctdb.py

@@ -380,6 +380,9 @@ class db_dxil(object):
         for i in "WaveMatch,WaveMultiPrefixOp,WaveMultiPrefixBitCount".split(","):
             self.name_idx[i].category = "Wave"
             self.name_idx[i].shader_model = 6,5
+        for i in "AllocateRayQuery,TraceRayInline".split(","):
+            self.name_idx[i].category = "Inline Ray Query"
+            self.name_idx[i].shader_model = 6,5
 
     def populate_llvm_instructions(self):
         # Add instructions that map to LLVM instructions.
@@ -1269,7 +1272,7 @@ class db_dxil(object):
             db_dxil_param(0, "v", "", "")])
         next_op_idx += 1
 
-        self.add_dxil_op("TraceRay", next_op_idx, "TraceRay", "returns the view index", "u", "", [
+        self.add_dxil_op("TraceRay", next_op_idx, "TraceRay", "initiates raytrace", "u", "", [
             db_dxil_param(0, "v", "", ""),
             db_dxil_param(2, "res", "AccelerationStructure", "Top-level acceleration structure to use"),
             db_dxil_param(3, "i32", "RayFlags", "Valid combination of Ray_flags"),
@@ -1373,9 +1376,30 @@ class db_dxil(object):
             db_dxil_param(6, "i32", "mask3", "mask 3")])
         next_op_idx += 1
 
+        self.add_dxil_op("AllocateRayQuery", next_op_idx, "AllocateRayQuery", "allocate space for RayQuery and return handle", "v", "", [
+            db_dxil_param(0, "i32", "", "handle to RayQuery state"),
+            db_dxil_param(2, "u32", "constRayFlags", "Valid combination of RAY_FLAGS", is_const=True)])
+        next_op_idx += 1
+
+        self.add_dxil_op("TraceRayInline", next_op_idx, "TraceRayInline", "initialize RayQuery for raytrace", "v", "", [
+            db_dxil_param(0, "v", "", ""),
+            db_dxil_param(2, "i32", "rayQueryHandle", "RayQuery handle"),
+            db_dxil_param(3, "res", "accelerationStructure", "Top-level acceleration structure to use"),
+            db_dxil_param(4, "i32", "rayFlags", "Valid combination of RAY_FLAGS, combined with constRayFlags provided to AllocateRayQuery"),
+            db_dxil_param(5, "i32", "instanceInclusionMask", "Bottom 8 bits of InstanceInclusionMask are used to include/rejectgeometry instances based on the InstanceMask in each instance: if(!((InstanceInclusionMask & InstanceMask) & 0xff)) { ignore intersection }"),
+            db_dxil_param(6, "f", "origin_X", "Origin x of the ray"),
+            db_dxil_param(7, "f", "origin_Y", "Origin y of the ray"),
+            db_dxil_param(8, "f", "origin_Z", "Origin z of the ray"),
+            db_dxil_param(9, "f", "tMin", "Tmin of the ray"),
+            db_dxil_param(10, "f", "direction_X", "Direction x of the ray"),
+            db_dxil_param(11, "f", "direction_Y", "Direction y of the ray"),
+            db_dxil_param(12, "f", "direction_Z", "Direction z of the ray"),
+            db_dxil_param(13, "f", "tMax", "Tmax of the ray")])
+        next_op_idx += 1
+
         # End of DXIL 1.5 opcodes.
         self.set_op_count_for_version(1, 5, next_op_idx)
-        assert next_op_idx == 168, "next operation index is %d rather than 168 and thus opcodes are broken" % next_op_idx
+        assert next_op_idx == 170, "next operation index is %d rather than 169 and thus opcodes are broken" % next_op_idx
 
         # Set interesting properties.
         self.build_indices()