Przeglądaj źródła

Add RayQuery object, TraceRayInline method + template arg annotations

Tex Riddell 6 lat temu
rodzic
commit
2209844cda

+ 3 - 1
docs/DXIL.rst

@@ -2245,7 +2245,7 @@ ID  Name                          Description
 154 RayTCurrent                   float representing the current parametric ending point for the ray
 154 RayTCurrent                   float representing the current parametric ending point for the ray
 155 IgnoreHit                     Used in an any hit shader to reject an intersection and terminate the shader
 155 IgnoreHit                     Used in an any hit shader to reject an intersection and terminate the shader
 156 AcceptHitAndEndSearch         Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
 156 AcceptHitAndEndSearch         Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
-157 TraceRay                      returns the view index
+157 TraceRay                      initiates raytrace
 158 ReportHit                     returns true if hit was accepted
 158 ReportHit                     returns true if hit was accepted
 159 CallShader                    Call a shader in the callable shader table supplied through the DispatchRays() API
 159 CallShader                    Call a shader in the callable shader table supplied through the DispatchRays() API
 160 CreateHandleForLib            create resource handle from resource struct for library
 160 CreateHandleForLib            create resource handle from resource struct for library
@@ -2256,6 +2256,8 @@ ID  Name                          Description
 165 WaveMatch                     returns the bitmask of active lanes that have the same value
 165 WaveMatch                     returns the bitmask of active lanes that have the same value
 166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
 166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
 167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
 167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
+168 AllocateRayQuery              allocate space for RayQuery and return handle
+169 TraceRayInline                initialize RayQuery for raytrace
 === ============================= =======================================================================================================================================================================================================================
 === ============================= =======================================================================================================================================================================================================================
 
 
 
 

+ 20 - 5
include/dxc/DXIL/DxilConstants.h

@@ -299,6 +299,9 @@ namespace DXIL {
   // OPCODE-ENUM:BEGIN
   // OPCODE-ENUM:BEGIN
   // Enumeration for operations specified by DXIL
   // Enumeration for operations specified by DXIL
   enum class OpCode : unsigned {
   enum class OpCode : unsigned {
+    // 
+    AllocateRayQuery = 168, // allocate space for RayQuery and return handle
+  
     // AnyHit Terminals
     // AnyHit Terminals
     AcceptHitAndEndSearch = 156, // Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
     AcceptHitAndEndSearch = 156, // Used in an any hit shader to abort the ray query and the intersection shader (if any). The current hit is committed and execution passes to the closest hit shader with the closest hit recorded so far
     IgnoreHit = 155, // Used in an any hit shader to reject an intersection and terminate the shader
     IgnoreHit = 155, // Used in an any hit shader to reject an intersection and terminate the shader
@@ -383,7 +386,10 @@ namespace DXIL {
     // Indirect Shader Invocation
     // Indirect Shader Invocation
     CallShader = 159, // Call a shader in the callable shader table supplied through the DispatchRays() API
     CallShader = 159, // Call a shader in the callable shader table supplied through the DispatchRays() API
     ReportHit = 158, // returns true if hit was accepted
     ReportHit = 158, // returns true if hit was accepted
-    TraceRay = 157, // returns the view index
+    TraceRay = 157, // initiates raytrace
+  
+    // Inline Ray Query
+    TraceRayInline = 169, // initialize RayQuery for raytrace
   
   
     // Legacy floating-point
     // Legacy floating-point
     LegacyF16ToF32 = 131, // legacy function to convert half (f16) to float (f32) (this is not related to min-precision)
     LegacyF16ToF32 = 131, // legacy function to convert half (f16) to float (f32) (this is not related to min-precision)
@@ -562,9 +568,9 @@ namespace DXIL {
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_4 = 165,
     NumOpCodes_Dxil_1_4 = 165,
-    NumOpCodes_Dxil_1_5 = 168,
+    NumOpCodes_Dxil_1_5 = 170,
   
   
-    NumOpCodes = 168 // exclusive last value of enumeration
+    NumOpCodes = 170 // exclusive last value of enumeration
   };
   };
   // OPCODE-ENUM:END
   // OPCODE-ENUM:END
 
 
@@ -572,6 +578,9 @@ namespace DXIL {
   // OPCODECLASS-ENUM:BEGIN
   // OPCODECLASS-ENUM:BEGIN
   // Groups for DXIL operations with equivalent function templates
   // Groups for DXIL operations with equivalent function templates
   enum class OpCodeClass : unsigned {
   enum class OpCodeClass : unsigned {
+    // 
+    AllocateRayQuery,
+  
     // AnyHit Terminals
     // AnyHit Terminals
     AcceptHitAndEndSearch,
     AcceptHitAndEndSearch,
     IgnoreHit,
     IgnoreHit,
@@ -643,6 +652,9 @@ namespace DXIL {
     ReportHit,
     ReportHit,
     TraceRay,
     TraceRay,
   
   
+    // Inline Ray Query
+    TraceRayInline,
+  
     // LLVM Instructions
     // LLVM Instructions
     LlvmInst,
     LlvmInst,
   
   
@@ -778,9 +790,9 @@ namespace DXIL {
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_4 = 120,
     NumOpClasses_Dxil_1_4 = 120,
-    NumOpClasses_Dxil_1_5 = 123,
+    NumOpClasses_Dxil_1_5 = 125,
   
   
-    NumOpClasses = 123 // exclusive last value of enumeration
+    NumOpClasses = 125 // exclusive last value of enumeration
   };
   };
   // OPCODECLASS-ENUM:END
   // OPCODECLASS-ENUM:END
 
 
@@ -910,6 +922,9 @@ namespace DXIL {
     const unsigned kTraceRayPayloadOpIdx = 15;
     const unsigned kTraceRayPayloadOpIdx = 15;
     const unsigned kTraceRayNumOp = 16;
     const unsigned kTraceRayNumOp = 16;
 
 
+    // TraceRayInline
+    const unsigned kTraceRayInlineRayDescOpIdx = 5;
+    const unsigned kTraceRayInlineNumOp = 13;
 
 
     // Emit/Cut
     // Emit/Cut
     const unsigned kStreamEmitCutIDOpIdx = 1;
     const unsigned kStreamEmitCutIDOpIdx = 1;

+ 86 - 1
include/dxc/DXIL/DxilInstructions.h

@@ -5173,7 +5173,7 @@ struct DxilInst_AcceptHitAndEndSearch {
   bool requiresUniformInputs() const { return false; }
   bool requiresUniformInputs() const { return false; }
 };
 };
 
 
-/// This instruction returns the view index
+/// This instruction initiates raytrace
 struct DxilInst_TraceRay {
 struct DxilInst_TraceRay {
   llvm::Instruction *Instr;
   llvm::Instruction *Instr;
   // Construction and identification
   // Construction and identification
@@ -5549,5 +5549,90 @@ struct DxilInst_WaveMultiPrefixBitCount {
   llvm::Value *get_mask3() const { return Instr->getOperand(5); }
   llvm::Value *get_mask3() const { return Instr->getOperand(5); }
   void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
   void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
 };
 };
+
+/// This instruction allocates space for a RayQuery and returns a handle
+struct DxilInst_AllocateRayQuery {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_AllocateRayQuery(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::AllocateRayQuery);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_constRayFlags = 1,
+  };
+  // Accessors
+  llvm::Value *get_constRayFlags() const { return Instr->getOperand(1); }
+  void set_constRayFlags(llvm::Value *val) { Instr->setOperand(1, val); }
+  uint32_t get_constRayFlags_val() const { return (uint32_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(1))->getZExtValue()); }
+  void set_constRayFlags_val(uint32_t val) { Instr->setOperand(1, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 32), llvm::APInt(32, (uint64_t)val))); }
+};
+
+/// This instruction initializes a RayQuery for raytrace
+struct DxilInst_TraceRayInline {
+  llvm::Instruction *Instr;
+  // Construction and identification
+  DxilInst_TraceRayInline(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::TraceRayInline);
+  }
+  // Validation support
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (13 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false;
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes
+  enum OperandIdx {
+    arg_rayQueryHandle = 1,
+    arg_accelerationStructure = 2,
+    arg_rayFlags = 3,
+    arg_instanceInclusionMask = 4,
+    arg_origin_X = 5,
+    arg_origin_Y = 6,
+    arg_origin_Z = 7,
+    arg_tMin = 8,
+    arg_direction_X = 9,
+    arg_direction_Y = 10,
+    arg_direction_Z = 11,
+    arg_tMax = 12,
+  };
+  // Accessors
+  llvm::Value *get_rayQueryHandle() const { return Instr->getOperand(1); }
+  void set_rayQueryHandle(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_accelerationStructure() const { return Instr->getOperand(2); }
+  void set_accelerationStructure(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_rayFlags() const { return Instr->getOperand(3); }
+  void set_rayFlags(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_instanceInclusionMask() const { return Instr->getOperand(4); }
+  void set_instanceInclusionMask(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_origin_X() const { return Instr->getOperand(5); }
+  void set_origin_X(llvm::Value *val) { Instr->setOperand(5, val); }
+  llvm::Value *get_origin_Y() const { return Instr->getOperand(6); }
+  void set_origin_Y(llvm::Value *val) { Instr->setOperand(6, val); }
+  llvm::Value *get_origin_Z() const { return Instr->getOperand(7); }
+  void set_origin_Z(llvm::Value *val) { Instr->setOperand(7, val); }
+  llvm::Value *get_tMin() const { return Instr->getOperand(8); }
+  void set_tMin(llvm::Value *val) { Instr->setOperand(8, val); }
+  llvm::Value *get_direction_X() const { return Instr->getOperand(9); }
+  void set_direction_X(llvm::Value *val) { Instr->setOperand(9, val); }
+  llvm::Value *get_direction_Y() const { return Instr->getOperand(10); }
+  void set_direction_Y(llvm::Value *val) { Instr->setOperand(10, val); }
+  llvm::Value *get_direction_Z() const { return Instr->getOperand(11); }
+  void set_direction_Z(llvm::Value *val) { Instr->setOperand(11, val); }
+  llvm::Value *get_tMax() const { return Instr->getOperand(12); }
+  void set_tMax(llvm::Value *val) { Instr->setOperand(12, val); }
+};
 // INSTR-HELPER:END
 // INSTR-HELPER:END
 } // namespace hlsl
 } // namespace hlsl

+ 13 - 0
include/dxc/DXIL/DxilMetadataHelper.h

@@ -46,6 +46,7 @@ class DxilSampler;
 class DxilTypeSystem;
 class DxilTypeSystem;
 class DxilStructAnnotation;
 class DxilStructAnnotation;
 class DxilFieldAnnotation;
 class DxilFieldAnnotation;
+class DxilTemplateArgAnnotation;
 class DxilFunctionAnnotation;
 class DxilFunctionAnnotation;
 class DxilParameterAnnotation;
 class DxilParameterAnnotation;
 class RootSignatureHandle;
 class RootSignatureHandle;
@@ -191,6 +192,16 @@ public:
   static const unsigned kDxilFieldAnnotationCompTypeTag           = 7;
   static const unsigned kDxilFieldAnnotationCompTypeTag           = 7;
   static const unsigned kDxilFieldAnnotationPreciseTag            = 8;
   static const unsigned kDxilFieldAnnotationPreciseTag            = 8;
 
 
+  // StructAnnotation extended property tags (DXIL 1.5+ only, appended)
+  static const unsigned kDxilTemplateArgumentsTag                 = 0;  // Name for name-value list of extended struct properties
+  // TemplateArgument tags
+  static const unsigned kDxilTemplateArgTypeTag                   = 0;  // Type template argument, followed by undef of type
+  static const unsigned kDxilTemplateArgIntegralTag               = 1;  // Integral template argument, followed by i64 value
+  // TemplateArgument operand positions (index of the payload following the tag)
+  static const unsigned kDxilTemplateArgType                      = 1;  // Position of type for template arg that is type
+  static const unsigned kDxilTemplateArgIntegral                  = 1;  // Position of i64 for template arg that is integral
+
+
   // Control flow hint.
   // Control flow hint.
   static const char kDxilControlFlowHintMDName[];
   static const char kDxilControlFlowHintMDName[];
 
 
@@ -351,6 +362,8 @@ public:
   void LoadDxilParamAnnotation(const llvm::MDOperand &MDO, DxilParameterAnnotation &PA);
   void LoadDxilParamAnnotation(const llvm::MDOperand &MDO, DxilParameterAnnotation &PA);
   llvm::Metadata *EmitDxilParamAnnotations(const DxilFunctionAnnotation &FA);
   llvm::Metadata *EmitDxilParamAnnotations(const DxilFunctionAnnotation &FA);
   void LoadDxilParamAnnotations(const llvm::MDOperand &MDO, DxilFunctionAnnotation &FA);
   void LoadDxilParamAnnotations(const llvm::MDOperand &MDO, DxilFunctionAnnotation &FA);
+  llvm::Metadata *EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation);
+  void LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation);
 
 
   // Function props.
   // Function props.
   llvm::MDTuple *EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
   llvm::MDTuple *EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,

+ 25 - 1
include/dxc/DXIL/DxilTypeSystem.h

@@ -90,6 +90,22 @@ private:
   std::string m_FieldName;
   std::string m_FieldName;
 };
 };
 
 
+class DxilTemplateArgAnnotation : DxilFieldAnnotation {
+public:
+  DxilTemplateArgAnnotation();
+
+  bool IsType() const;
+  const llvm::Type *GetType() const;
+  void SetType(const llvm::Type *pType);
+
+  bool IsIntegral() const;
+  int64_t GetIntegral() const;
+  void SetIntegral(int64_t i64);
+
+private:
+  const llvm::Type *m_Type;
+  int64_t m_Integral;
+};
 
 
 /// Use this class to represent LLVM structure annotation.
 /// Use this class to represent LLVM structure annotation.
 class DxilStructAnnotation {
 class DxilStructAnnotation {
@@ -104,10 +120,18 @@ public:
   void SetCBufferSize(unsigned size);
   void SetCBufferSize(unsigned size);
   void MarkEmptyStruct();
   void MarkEmptyStruct();
   bool IsEmptyStruct();
   bool IsEmptyStruct();
+
+  // For template args, GetNumTemplateArgs() will return 0 if not a template
+  unsigned GetNumTemplateArgs() const;
+  void SetNumTemplateArgs(unsigned count);
+  DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx);
+  const DxilTemplateArgAnnotation &GetTemplateArgAnnotation(unsigned argIdx) const;
+
 private:
 private:
   const llvm::StructType *m_pStructType;
   const llvm::StructType *m_pStructType;
   std::vector<DxilFieldAnnotation> m_FieldAnnotations;
   std::vector<DxilFieldAnnotation> m_FieldAnnotations;
   unsigned m_CBufferSize;  // The size of struct if inside constant buffer.
   unsigned m_CBufferSize;  // The size of struct if inside constant buffer.
+  std::vector<DxilTemplateArgAnnotation> m_TemplateAnnotations;
 };
 };
 
 
 
 
@@ -163,7 +187,7 @@ public:
 
 
   DxilTypeSystem(llvm::Module *pModule);
   DxilTypeSystem(llvm::Module *pModule);
 
 
-  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType);
+  DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType, unsigned numTemplateArgs = 0);
   DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType);
   DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType);
   const DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType) const;
   const DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType) const;
   void EraseStructAnnotation(const llvm::StructType *pStructType);
   void EraseStructAnnotation(const llvm::StructType *pStructType);

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -109,6 +109,7 @@ namespace dxilutil {
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLResourceType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
   bool IsHLSLObjectType(llvm::Type *Ty);
+  bool IsHLSLRayQueryType(llvm::Type *Ty);
   bool IsSplat(llvm::ConstantDataVector *cdv);
   bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 }
 
 

+ 3 - 0
include/dxc/HLSL/HLOperations.h

@@ -333,6 +333,9 @@ const unsigned kCreateHandleIndexOpIdx = 2; // Only for array of cbuffer.
 const unsigned kTraceRayRayDescOpIdx = 7;
 const unsigned kTraceRayRayDescOpIdx = 7;
 const unsigned kTraceRayPayLoadOpIdx = 8;
 const unsigned kTraceRayPayLoadOpIdx = 8;
 
 
+// TraceRayInline.
+const unsigned kTraceRayInlineRayDescOpIdx = 5;
+
 // ReportIntersection.
 // ReportIntersection.
 const unsigned kReportIntersectionAttributeOpIdx = 3;
 const unsigned kReportIntersectionAttributeOpIdx = 3;
 
 

+ 1 - 0
include/dxc/HlslIntrinsicOp.h

@@ -260,6 +260,7 @@ import hctdb_instrhelp
   MOP_DecrementCounter,
   MOP_DecrementCounter,
   MOP_IncrementCounter,
   MOP_IncrementCounter,
   MOP_Consume,
   MOP_Consume,
+  MOP_TraceRayInline,
 #ifdef ENABLE_SPIRV_CODEGEN
 #ifdef ENABLE_SPIRV_CODEGEN
   MOP_SubpassLoad,
   MOP_SubpassLoad,
 #endif // ENABLE_SPIRV_CODEGEN
 #endif // ENABLE_SPIRV_CODEGEN

+ 71 - 2
lib/DXIL/DxilMetadataHelper.cpp

@@ -774,13 +774,62 @@ void DxilMDHelper::LoadDxilTypeSystem(DxilTypeSystem &TypeSystem) {
   }
   }
 }
 }
 
 
+Metadata *DxilMDHelper::EmitDxilTemplateArgAnnotation(const DxilTemplateArgAnnotation &annotation) {
+  SmallVector<Metadata *, 2> MDVals;
+  if (annotation.IsType()) {
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgTypeTag));
+    MDVals.emplace_back(ValueAsMetadata::get(UndefValue::get(const_cast<Type*>(annotation.GetType()))));
+  } else if (annotation.IsIntegral()) {
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgIntegralTag));
+    MDVals.emplace_back(Uint64ToConstMD((uint64_t)annotation.GetIntegral()));
+  }
+  return MDNode::get(m_Ctx, MDVals);
+}
+void DxilMDHelper::LoadDxilTemplateArgAnnotation(const llvm::MDOperand &MDO, DxilTemplateArgAnnotation &annotation) {
+  IFTBOOL(MDO.get() != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDTuple *pTupleMD = dyn_cast<MDTuple>(MDO.get());
+  IFTBOOL(pTupleMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pTupleMD->getNumOperands() >= 1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned Tag = ConstMDToUint32(pTupleMD->getOperand(0));
+  switch (Tag) {
+  case kDxilTemplateArgTypeTag:
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetType(MetadataAsValue::get(m_Ctx,
+      pTupleMD->getOperand(kDxilTemplateArgType))->getType());
+    break;
+  case kDxilTemplateArgIntegralTag:
+    IFTBOOL(pTupleMD->getNumOperands() == 2, DXC_E_INCORRECT_DXIL_METADATA);
+    annotation.SetIntegral((int64_t)ConstMDToUint64(pTupleMD->getOperand(kDxilTemplateArgType)));
+    break;
+  }
+}
+
 Metadata *DxilMDHelper::EmitDxilStructAnnotation(const DxilStructAnnotation &SA) {
 Metadata *DxilMDHelper::EmitDxilStructAnnotation(const DxilStructAnnotation &SA) {
-  vector<Metadata *> MDVals(SA.GetNumFields() + 1);
+  unsigned valMajor, valMinor;
+  m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  bool bSupportExtended = !(valMajor == 1 && valMinor < 5);
+
+  vector<Metadata *> MDVals;
+  MDVals.reserve(SA.GetNumFields() + 2);  // In case of extended 1.5 property list
+  MDVals.resize(SA.GetNumFields() + 1);
+
   MDVals[0] = Uint32ToConstMD(SA.GetCBufferSize());
   MDVals[0] = Uint32ToConstMD(SA.GetCBufferSize());
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
     MDVals[i+1] = EmitDxilFieldAnnotation(SA.GetFieldAnnotation(i));
     MDVals[i+1] = EmitDxilFieldAnnotation(SA.GetFieldAnnotation(i));
   }
   }
 
 
+  // Only add template args if shader target requires validator version that supports them.
+  if (bSupportExtended && SA.GetNumTemplateArgs()) {
+    vector<Metadata *> MDTemplateArgs(SA.GetNumTemplateArgs());
+    for (unsigned i = 0; i < SA.GetNumTemplateArgs(); ++i) {
+      MDTemplateArgs[i] = EmitDxilTemplateArgAnnotation(SA.GetTemplateArgAnnotation(i));
+    }
+    SmallVector<Metadata *, 2> MDExtraVals;
+    MDExtraVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilTemplateArgumentsTag));
+    MDExtraVals.emplace_back(MDNode::get(m_Ctx, MDTemplateArgs));
+    MDVals.emplace_back(MDNode::get(m_Ctx, MDExtraVals));
+  }
+
   return MDNode::get(m_Ctx, MDVals);
   return MDNode::get(m_Ctx, MDVals);
 }
 }
 
 
@@ -791,7 +840,27 @@ void DxilMDHelper::LoadDxilStructAnnotation(const MDOperand &MDO, DxilStructAnno
   if (pTupleMD->getNumOperands() == 1) {
   if (pTupleMD->getNumOperands() == 1) {
     SA.MarkEmptyStruct();
     SA.MarkEmptyStruct();
   }
   }
-  IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  unsigned valMajor, valMinor;
+  m_pSM->GetMinValidatorVersion(valMajor, valMinor);
+  if (!(valMajor == 1 && valMinor < 5) &&
+      (pTupleMD->getNumOperands() == SA.GetNumFields()+2)) {
+    // Load template args from extended operand
+    const MDOperand &MDOExtra = pTupleMD->getOperand(SA.GetNumFields()+1);
+    const MDTuple *pTupleMDExtra = dyn_cast_or_null<MDTuple>(MDOExtra.get());
+    if(pTupleMDExtra) {
+      IFTBOOL(pTupleMDExtra->getNumOperands() % 2 == 0, DXC_E_INCORRECT_DXIL_METADATA);
+      unsigned Tag = ConstMDToUint32(pTupleMDExtra->getOperand(0));
+      IFTBOOL(Tag == kDxilTemplateArgumentsTag, DXC_E_INCORRECT_DXIL_METADATA); // Only one allowed at this point
+      const MDTuple *pTupleTemplateArgs = dyn_cast_or_null<MDTuple>(pTupleMDExtra->getOperand(1).get());
+      IFTBOOL(pTupleTemplateArgs, DXC_E_INCORRECT_DXIL_METADATA);
+      SA.SetNumTemplateArgs(pTupleTemplateArgs->getNumOperands());
+      for (unsigned i = 0; i < pTupleTemplateArgs->getNumOperands(); ++i) {
+        LoadDxilTemplateArgAnnotation(pTupleTemplateArgs->getOperand(i), SA.GetTemplateArgAnnotation(i));
+      }
+    }
+  } else {
+    IFTBOOL(pTupleMD->getNumOperands() == SA.GetNumFields()+1, DXC_E_INCORRECT_DXIL_METADATA);
+  }
 
 
   SA.SetCBufferSize(ConstMDToUint32(pTupleMD->getOperand(0)));
   SA.SetCBufferSize(ConstMDToUint32(pTupleMD->getOperand(0)));
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {
   for (unsigned i = 0; i < SA.GetNumFields(); i++) {

+ 16 - 2
lib/DXIL/DxilOperations.cpp

@@ -321,6 +321,12 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
   {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  //                                                                                                                         void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::AllocateRayQuery,        "AllocateRayQuery",         OCC::AllocateRayQuery,         "allocateRayQuery",          {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
+
+  // Inline Ray Query                                                                                                        void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::TraceRayInline,          "TraceRayInline",           OCC::TraceRayInline,           "traceRayInline",            {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
 };
 };
 // OPCODE-OLOADS:END
 // OPCODE-OLOADS:END
 
 
@@ -657,8 +663,8 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
     return;
     return;
   }
   }
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167)) {
+  // WaveMultiPrefixBitCount=167, TraceRayInline=169
+  if ((165 <= op && op <= 167) || op == 169) {
     major = 6;  minor = 5;
     major = 6;  minor = 5;
     return;
     return;
   }
   }
@@ -1062,6 +1068,12 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
   case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
   case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
   case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
   case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
   case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
+
+    // 
+  case OpCode::AllocateRayQuery:       A(pI32);     A(pI32); A(pI32); break;
+
+    // Inline Ray Query
+  case OpCode::TraceRayInline:         A(pV);       A(pI32); A(pI32); A(pRes); A(pI32); A(pI32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); A(pF32); break;
   // OPCODE-OLOAD-FUNCS:END
   // OPCODE-OLOAD-FUNCS:END
   default: DXASSERT(false, "otherwise unhandled case"); break;
   default: DXASSERT(false, "otherwise unhandled case"); break;
   }
   }
@@ -1215,6 +1227,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::IgnoreHit:
   case OpCode::IgnoreHit:
   case OpCode::AcceptHitAndEndSearch:
   case OpCode::AcceptHitAndEndSearch:
   case OpCode::WaveMultiPrefixBitCount:
   case OpCode::WaveMultiPrefixBitCount:
+  case OpCode::AllocateRayQuery:
+  case OpCode::TraceRayInline:
     return Type::getVoidTy(m_Ctx);
     return Type::getVoidTy(m_Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:
   case OpCode::AtomicBinOp:

+ 34 - 1
lib/DXIL/DxilTypeSystem.cpp

@@ -78,6 +78,22 @@ const std::string &DxilFieldAnnotation::GetFieldName() const { return m_FieldNam
 void DxilFieldAnnotation::SetFieldName(const std::string &FieldName) { m_FieldName = FieldName; }
 void DxilFieldAnnotation::SetFieldName(const std::string &FieldName) { m_FieldName = FieldName; }
 
 
 
 
+//------------------------------------------------------------------------------
+//
+// DxilTemplateArgAnnotation class methods.
+//
+DxilTemplateArgAnnotation::DxilTemplateArgAnnotation()
+    : DxilFieldAnnotation(), m_Type(nullptr), m_Integral(0)
+{}
+
+bool DxilTemplateArgAnnotation::IsType() const { return m_Type != nullptr; }
+const llvm::Type *DxilTemplateArgAnnotation::GetType() const { return m_Type; }
+void DxilTemplateArgAnnotation::SetType(const llvm::Type *pType) { m_Type = pType; }
+
+bool DxilTemplateArgAnnotation::IsIntegral() const { return m_Type == nullptr; }
+int64_t DxilTemplateArgAnnotation::GetIntegral() const { return m_Integral; }
+void DxilTemplateArgAnnotation::SetIntegral(int64_t i64) { m_Type = nullptr; m_Integral = i64; }
+
 //------------------------------------------------------------------------------
 //------------------------------------------------------------------------------
 //
 //
 // DxilStructAnnotation class methods.
 // DxilStructAnnotation class methods.
@@ -103,6 +119,22 @@ void DxilStructAnnotation::SetCBufferSize(unsigned size) { m_CBufferSize = size;
 void DxilStructAnnotation::MarkEmptyStruct() { m_FieldAnnotations.clear(); }
 void DxilStructAnnotation::MarkEmptyStruct() { m_FieldAnnotations.clear(); }
 bool DxilStructAnnotation::IsEmptyStruct() { return m_FieldAnnotations.empty(); }
 bool DxilStructAnnotation::IsEmptyStruct() { return m_FieldAnnotations.empty(); }
 
 
+// For template args, GetNumTemplateArgs() will return 0 if not a template
+unsigned DxilStructAnnotation::GetNumTemplateArgs() const {
+  return (unsigned)m_TemplateAnnotations.size();
+}
+void DxilStructAnnotation::SetNumTemplateArgs(unsigned count) {
+  DXASSERT(m_TemplateAnnotations.empty(), "template args already initialized");
+  m_TemplateAnnotations.resize(count);
+}
+DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) {
+  return m_TemplateAnnotations[argIdx];
+}
+const DxilTemplateArgAnnotation &DxilStructAnnotation::GetTemplateArgAnnotation(unsigned argIdx) const {
+  return m_TemplateAnnotations[argIdx];
+}
+
+
 //------------------------------------------------------------------------------
 //------------------------------------------------------------------------------
 //
 //
 // DxilParameterAnnotation class methods.
 // DxilParameterAnnotation class methods.
@@ -166,12 +198,13 @@ DxilTypeSystem::DxilTypeSystem(Module *pModule)
     : m_pModule(pModule),
     : m_pModule(pModule),
       m_LowPrecisionMode(DXIL::LowPrecisionMode::Undefined) {}
       m_LowPrecisionMode(DXIL::LowPrecisionMode::Undefined) {}
 
 
-DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType) {
+DxilStructAnnotation *DxilTypeSystem::AddStructAnnotation(const StructType *pStructType, unsigned numTemplateArgs) {
   DXASSERT_NOMSG(m_StructAnnotations.find(pStructType) == m_StructAnnotations.end());
   DXASSERT_NOMSG(m_StructAnnotations.find(pStructType) == m_StructAnnotations.end());
   DxilStructAnnotation *pA = new DxilStructAnnotation();
   DxilStructAnnotation *pA = new DxilStructAnnotation();
   m_StructAnnotations[pStructType] = unique_ptr<DxilStructAnnotation>(pA);
   m_StructAnnotations[pStructType] = unique_ptr<DxilStructAnnotation>(pA);
   pA->m_pStructType = pStructType;
   pA->m_pStructType = pStructType;
   pA->m_FieldAnnotations.resize(pStructType->getNumElements());
   pA->m_FieldAnnotations.resize(pStructType->getNumElements());
+  pA->SetNumTemplateArgs(numTemplateArgs);
   return pA;
   return pA;
 }
 }
 
 

+ 14 - 0
lib/DXIL/DxilUtil.cpp

@@ -546,6 +546,20 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
       return true;
       return true;
     if (name.startswith("LineStream<"))
     if (name.startswith("LineStream<"))
       return true;
       return true;
+
+    if (name.startswith("RayQuery<"))
+      return true;
+  }
+  return false;
+}
+
+// Returns true when Ty is a struct type whose name, after removing an
+// optional leading "class." prefix, starts with "RayQuery<".
+bool IsHLSLRayQueryType(llvm::Type *Ty) {
+  if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
+    StringRef name = ST->getName();
+    // TODO: don't check names.
+    // StringRef::ltrim() takes a set of characters rather than a prefix, so
+    // it would also strip any further leading 'c', 'l', 'a', 's' or '.'
+    // characters from the type name. Remove the exact prefix instead.
+    const StringRef classPrefix("class.");
+    if (name.startswith(classPrefix))
+      name = name.substr(classPrefix.size());
+    if (name.startswith("RayQuery<"))
+      return true;
   }
   }
   return false;
   return false;
 }
 }

+ 1 - 1
lib/DxilPIXPasses/DxilShaderAccessTracking.cpp

@@ -453,7 +453,7 @@ bool DxilShaderAccessTracking::runOnModule(Module &M)
 
 
 
 
     // todo: should "GetDimensions" mean a resource access?
     // todo: should "GetDimensions" mean a resource access?
-    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(168), "Please update PIX passes if any resource access opcodes are added");
+    static_assert(DXIL::OpCode::NumOpCodes == static_cast<DXIL::OpCode>(170), "Please update PIX passes if any resource access opcodes are added");
     ResourceAccessFunction raFunctions[] = {
     ResourceAccessFunction raFunctions[] = {
       { DXIL::OpCode::CBufferLoadLegacy     , ShaderAccessFlags::Read   , false, f32i32f64 },
       { DXIL::OpCode::CBufferLoadLegacy     , ShaderAccessFlags::Read   , false, f32i32f64 },
       { DXIL::OpCode::CBufferLoad           , ShaderAccessFlags::Read   , false, f16f32f64i16i32i64 },
       { DXIL::OpCode::CBufferLoad           , ShaderAccessFlags::Read   , false, f16f32f64i16i32i64 },

+ 2 - 2
lib/HLSL/DxilValidation.cpp

@@ -856,8 +856,8 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   if ((162 <= op && op <= 164))
   if ((162 <= op && op <= 164))
     return (major > 6 || (major == 6 && minor >= 4));
     return (major > 6 || (major == 6 && minor >= 4));
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
   // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
-  // WaveMultiPrefixBitCount=167
-  if ((165 <= op && op <= 167))
+  // WaveMultiPrefixBitCount=167, TraceRayInline=169
+  if ((165 <= op && op <= 167) || op == 169)
     return (major > 6 || (major == 6 && minor >= 5));
     return (major > 6 || (major == 6 && minor >= 5));
   return true;
   return true;
   // VALOPCODESM-TEXT:END
   // VALOPCODESM-TEXT:END

+ 111 - 0
lib/HLSL/HLOperationLower.cpp

@@ -4637,6 +4637,108 @@ Value *TranslateTraceRay(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   return Builder.CreateCall(F, Args);
   return Builder.CreateCall(F, Args);
 }
 }
 
 
+// For every alloca of RayQuery type found in the entry block of each non-HL
+// function that has a body, emits an AllocateRayQuery DXIL call directly
+// after the alloca and stores the call result into element {0, 0} of the
+// allocated object. The i32 ray flags operand is taken from the integral
+// template argument recorded in the RayQuery struct annotation.
+void AllocateRayQueryObjects(llvm::Module *M, HLOperationLowerHelper &helper) {
+  hlsl::OP &dxilOps = helper.hlslOP;
+  Constant *zero = dxilOps.GetI32Const(0);
+  DXIL::OpCode allocOp = DXIL::OpCode::AllocateRayQuery;
+  llvm::Value *allocOpArg = dxilOps.GetU32Const(static_cast<unsigned>(allocOp));
+
+  for (Function &F : M->functions()) {
+    // Skip declarations, intrinsics and HL operation functions.
+    const bool isCandidate = !F.isDeclaration() && !F.isIntrinsic() &&
+                             GetHLOpcodeGroup(&F) == HLOpcodeGroup::NotHL;
+    if (!isCandidate)
+      continue;
+
+    BasicBlock &entry = F.getEntryBlock();
+    IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(&entry));
+    BasicBlock::iterator it = entry.begin();
+    BasicBlock::iterator itEnd = entry.end();
+    while (it != itEnd) {
+      // Step past the current instruction first: new instructions are
+      // inserted right after it, and the iterator must stay valid.
+      Instruction *inst = it++;
+      AllocaInst *AI = dyn_cast<AllocaInst>(inst);
+      if (!AI)
+        continue;
+
+      // Look through (possibly nested) array types to the element type.
+      llvm::Type *allocatedTy = AI->getAllocatedType();
+      llvm::Type *elemTy = allocatedTy;
+      while (elemTy->isArrayTy())
+        elemTy = elemTy->getArrayElementType();
+      if (!dxilutil::IsHLSLRayQueryType(elemTy))
+        continue;
+
+      DxilStructAnnotation *SA = helper.dxilTypeSys.GetStructAnnotation(cast<StructType>(elemTy));
+      DXASSERT(SA, "otherwise, could not find type annoation for RayQuery specialization");
+      DXASSERT(SA->GetNumTemplateArgs() == 1 && SA->GetTemplateArgAnnotation(0).IsIntegral(),
+               "otherwise, RayQuery has changed, or lacks template args");
+      Builder.SetInsertPoint(AI->getNextNode());
+      DXASSERT(!allocatedTy->isArrayTy(), "Array not handled yet");
+
+      // Call AllocateRayQuery(flags) and store the handle into the object.
+      llvm::Function *allocFn = dxilOps.GetOpFunc(allocOp, Builder.getVoidTy());
+      const int64_t flags = SA->GetTemplateArgAnnotation(0).GetIntegral();
+      llvm::Value *rayFlags = ConstantInt::get(helper.i32Ty, APInt(32, flags));
+      llvm::CallInst *hRayQuery = Builder.CreateCall(allocFn, {allocOpArg, rayFlags}, "hRayQuery");
+      llvm::Value *handlePtr = Builder.CreateGEP(AI, {zero, zero});
+      Builder.CreateStore(hRayQuery, handlePtr);
+    }
+  }
+}
+
+// Lowers an HL TraceRayInline method call to the corresponding DXIL op call.
+// The RayQuery `this` pointer is replaced by the handle loaded from element
+// {0, 0} of the object, the operands between the handle and the RayDesc are
+// forwarded unchanged, and the RayDesc is expanded into eight scalar
+// operands: Origin.xyz, TMin, Direction.xyz, TMax.
+Value *TranslateTraceRayInline(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+                         HLOperationLowerHelper &helper,
+                         HLObjectOperationLowerHelper *pObjHelper,
+                         bool &Translated) {
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  IRBuilder<> Builder(CI);
+
+  Value *Args[DXIL::OperandIndex::kTraceRayInlineNumOp];
+  Args[0] = dxilOps->GetU32Const(static_cast<unsigned>(opcode));
+
+  // Translate RayQuery `this` pointer to i32 handle by-value
+  Constant *i32Zero = dxilOps->GetI32Const(0);
+  Value *thisPtr = CI->getArgOperand(1);
+  Value *handlePtr = Builder.CreateGEP(thisPtr, {i32Zero, i32Zero});
+  Args[1] = Builder.CreateLoad(handlePtr);
+
+  // Operands located before the RayDesc are passed through as they are.
+  for (unsigned i = 2; i < HLOperandIndex::kTraceRayInlineRayDescOpIdx; i++)
+    Args[i] = CI->getArgOperand(i);
+
+  // struct RayDesc
+  //{
+  //    float3 Origin;
+  //    float  TMin;
+  //    float3 Direction;
+  //    float  TMax;
+  //};
+  Value *rayDesc = CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx);
+  Value *zeroIdx = dxilOps->GetU32Const(0);
+  unsigned index = DXIL::OperandIndex::kTraceRayInlineRayDescOpIdx;
+
+  // Loads RayDesc field `fieldIdx` (GEP followed by load).
+  auto loadField = [&](unsigned fieldIdx) -> Value * {
+    Value *fieldPtr = Builder.CreateGEP(rayDesc, {zeroIdx, dxilOps->GetU32Const(fieldIdx)});
+    return Builder.CreateLoad(fieldPtr);
+  };
+  // Appends the three components of `vec` as consecutive operands.
+  auto appendVec3 = [&](Value *vec) {
+    for (uint64_t c = 0; c < 3; ++c) {
+      Args[index] = Builder.CreateExtractElement(vec, c);
+      ++index;
+    }
+  };
+
+  appendVec3(loadField(0));      // Origin
+  Args[index++] = loadField(1);  // TMin
+  appendVec3(loadField(2));      // Direction
+  Args[index++] = loadField(3);  // TMax
+
+  DXASSERT_NOMSG(index == DXIL::OperandIndex::kTraceRayInlineNumOp);
+
+  Function *opFn = dxilOps->GetOpFunc(opcode, Builder.getVoidTy());
+  return Builder.CreateCall(opFn, Args);
+}
+
 Value *TranslateNoArgVectorOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
 Value *TranslateNoArgVectorOperation(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                          HLOperationLowerHelper &helper,
                          HLOperationLowerHelper &helper,
                          HLObjectOperationLowerHelper *pObjHelper,
                          HLObjectOperationLowerHelper *pObjHelper,
@@ -5029,6 +5131,8 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::MOP_IncrementCounter, GenerateUpdateCounter, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_IncrementCounter, GenerateUpdateCounter, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_Consume, EmptyLower, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_Consume, EmptyLower, DXIL::OpCode::NumOpCodes},
 
 
+    {IntrinsicOp::MOP_TraceRayInline, TranslateTraceRayInline, DXIL::OpCode::TraceRayInline},
+
     // SPIRV change starts
     // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
 #ifdef ENABLE_SPIRV_CODEGEN
     {IntrinsicOp::MOP_SubpassLoad, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::MOP_SubpassLoad, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes},
@@ -5769,6 +5873,11 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       }
       }
 
 
       CI->eraseFromParent();
       CI->eraseFromParent();
+    } else if (group == HLOpcodeGroup::HLIntrinsic) {
+      // FIXME: This case is hit when using built-in structures in constant
+      //        buffers passed directly to an intrinsic, such as:
+      //        RayDesc from cbuffer passed to TraceRay.
+      DXASSERT(0, "not implemented yet");
     } else {
     } else {
       DXASSERT(0, "not implemented yet");
       DXASSERT(0, "not implemented yet");
     }
     }
@@ -7318,6 +7427,8 @@ void TranslateBuiltinOperations(
 
 
   Module *M = HLM.GetModule();
   Module *M = HLM.GetModule();
 
 
+  AllocateRayQueryObjects(M, helper);
+
   SmallVector<Function *, 4> NonUniformResourceIndexIntrinsics;
   SmallVector<Function *, 4> NonUniformResourceIndexIntrinsics;
 
 
   // generate dxil operation
   // generate dxil operation

+ 17 - 4
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -1586,9 +1586,9 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       SimpleCopy(Dest, Src, idxList, Builder);
       SimpleCopy(Dest, Src, idxList, Builder);
       return;
       return;
     }
     }
+    // Built-in structs have no type annotation
     DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
     DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
-    DXASSERT(STA, "require annotation here");
-    if (STA->IsEmptyStruct())
+    if (STA && STA->IsEmptyStruct())
       return;
       return;
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
     for (uint32_t i = 0; i < ST->getNumElements(); i++) {
       llvm::Type *ET = ST->getElementType(i);
       llvm::Type *ET = ST->getElementType(i);
@@ -1598,8 +1598,8 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
       if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
       if (bEltMemCpy && IsMemCpyTy(ET, typeSys)) {
         EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
         EltMemCpy(ET, Dest, Src, idxList, Builder, DL);
       } else {
       } else {
-        DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
-        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, &EltAnnotation,
+        DxilFieldAnnotation *EltAnnotation = STA ? &STA->GetFieldAnnotation(i) : nullptr;
+        SplitCpy(ET, Dest, Src, idxList, Builder, DL, typeSys, EltAnnotation,
                  bEltMemCpy);
                  bEltMemCpy);
       }
       }
 
 
@@ -2412,6 +2412,12 @@ void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Value *OldV) {
 }
 }
 
 
 void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
 void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
+  // Unused bitcast may be leftover from temporary memcpy
+  if (BCI->use_empty()) {
+    BCI->eraseFromParent();
+    return;
+  }
+
   Type *DstTy = BCI->getType();
   Type *DstTy = BCI->getType();
   Value *Val = BCI->getOperand(0);
   Value *Val = BCI->getOperand(0);
   Type *SrcTy = Val->getType();
   Type *SrcTy = Val->getType();
@@ -2565,6 +2571,13 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
         RewriteCallArg(CI, HLOperandIndex::kBinaryOpSrc1Idx,
         RewriteCallArg(CI, HLOperandIndex::kBinaryOpSrc1Idx,
                        /*bIn*/ true, /*bOut*/ true);
                        /*bIn*/ true, /*bOut*/ true);
       } break;
       } break;
+      case IntrinsicOp::MOP_TraceRayInline: {
+        if (OldVal ==
+            CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
+          RewriteCallArg(CI, HLOperandIndex::kTraceRayInlineRayDescOpIdx,
+                         /*bIn*/ true, /*bOut*/ false);
+        }
+      } break;
       default:
       default:
         DXASSERT(0, "cannot flatten hlsl intrinsic.");
         DXASSERT(0, "cannot flatten hlsl intrinsic.");
       }
       }

+ 5 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -326,6 +326,11 @@ void AddTemplateTypeWithHandle(
             uint8_t templateArgCount,
             uint8_t templateArgCount,
   _In_opt_  clang::TypeSourceInfo* defaultTypeArgValue);
   _In_opt_  clang::TypeSourceInfo* defaultTypeArgValue);
 
 
+void AddRayQueryTemplate(
+           clang::ASTContext& context,
+  _Outptr_ clang::ClassTemplateDecl** typeDecl,
+  _Outptr_ clang::CXXRecordDecl** recordDecl);
+
 /// <summary>Create a function template declaration for the specified method.</summary>
 /// <summary>Create a function template declaration for the specified method.</summary>
 /// <param name="context">AST context in which to work.</param>
 /// <param name="context">AST context in which to work.</param>
 /// <param name="recordDecl">Class in which the function template is declared.</param>
 /// <param name="recordDecl">Class in which the function template is declared.</param>

+ 102 - 0
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -38,6 +38,7 @@ static const bool DelayTypeCreationTrue = true;   // delay type creation for a d
 static const SourceLocation NoLoc;                // no source location attribution available
 static const SourceLocation NoLoc;                // no source location attribution available
 static const bool InlineFalse = false;            // namespace is not an inline namespace
 static const bool InlineFalse = false;            // namespace is not an inline namespace
 static const bool InlineSpecifiedFalse = false;   // function was not specified as inline
 static const bool InlineSpecifiedFalse = false;   // function was not specified as inline
+static const bool ExplicitFalse = false;          // constructor was not specified as explicit
 static const bool IsConstexprFalse = false;       // function is not constexpr
 static const bool IsConstexprFalse = false;       // function is not constexpr
 static const bool VirtualFalse = false;           // whether the base class is declares 'virtual'
 static const bool VirtualFalse = false;           // whether the base class is declares 'virtual'
 static const bool BaseClassFalse = false;         // whether the base class is declared as 'class' (vs. 'struct')
 static const bool BaseClassFalse = false;         // whether the base class is declared as 'class' (vs. 'struct')
@@ -894,6 +895,28 @@ void AssociateParametersToFunctionPrototype(
   }
   }
 }
 }
 
 
+// Creates a public CXXConstructorDecl for recordDecl, using a function type
+// built from resultType and args (const-qualified when isConst is set).
+// The new declaration and its trivial TypeSourceInfo are returned through
+// constructorDecl and tinfo. The declaration is not added to recordDecl
+// here; that is left to the caller.
+static void CreateConstructorDeclaration(
+  ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
+  ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
+  _Out_ CXXConstructorDecl **constructorDecl, _Out_ TypeSourceInfo **tinfo) {
+  DXASSERT_NOMSG(recordDecl != nullptr);
+  DXASSERT_NOMSG(constructorDecl != nullptr);
+  // tinfo is written unconditionally below, so it has to be valid as well.
+  DXASSERT_NOMSG(tinfo != nullptr);
+
+  FunctionProtoType::ExtProtoInfo functionExtInfo;
+  functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
+  QualType functionQT = context.getFunctionType(
+    resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
+  DeclarationNameInfo declNameInfo(declarationName, NoLoc);
+  *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
+  DXASSERT_NOMSG(*tinfo != nullptr);
+  *constructorDecl = CXXConstructorDecl::Create(
+    context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
+    StorageClass::SC_None, ExplicitFalse, InlineSpecifiedFalse, IsConstexprFalse);
+  DXASSERT_NOMSG(*constructorDecl != nullptr);
+  (*constructorDecl)->setLexicalDeclContext(recordDecl);
+  (*constructorDecl)->setAccess(AccessSpecifier::AS_public);
+}
+
 static void CreateObjectFunctionDeclaration(
 static void CreateObjectFunctionDeclaration(
     ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
     ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
     ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
     ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
@@ -959,6 +982,85 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
   return functionDecl;
   return functionDecl;
 }
 }
 
 
+// Builds the built-in `template<uint flags = 0> class RayQuery` declaration
+// in translation unit scope. The templated record is marked final and gets a
+// public constructor and a uint handle field. The class template and its
+// templated record are returned through typeDecl and recordDecl.
+void hlsl::AddRayQueryTemplate(
+  ASTContext& context,
+  _Outptr_ ClassTemplateDecl** typeDecl,
+  _Outptr_ CXXRecordDecl** recordDecl
+)
+{
+  DXASSERT_NOMSG(typeDecl != nullptr);
+  DXASSERT_NOMSG(recordDecl != nullptr);
+
+  DeclContext* tuContext = context.getTranslationUnitDecl();
+  QualType uintType = context.UnsignedIntTy;
+
+  // Non-type template parameter: `uint flags`.
+  IdentifierInfo& flagsParamId = context.Idents.get(StringRef("flags"), tok::TokenKind::identifier);
+  NonTypeTemplateParmDecl* flagsParamDecl = NonTypeTemplateParmDecl::Create(
+    context, tuContext, NoLoc, NoLoc,
+    FirstTemplateDepth, FirstParamPosition, &flagsParamId, uintType, ParameterPackFalse, nullptr);
+
+  // Default the flags argument to zero (whether flags should have a default
+  // at all is still an open question).
+  Expr *defaultFlags = IntegerLiteral::Create(
+    context, llvm::APInt(context.getIntWidth(uintType), 0), uintType, NoLoc);
+  flagsParamDecl->setDefaultArgument(defaultFlags);
+
+  NamedDecl* templateParams[] = { flagsParamDecl };
+  TemplateParameterList* templateParamList = TemplateParameterList::Create(
+    context, NoLoc, NoLoc, templateParams, 1, NoLoc);
+
+  // Record and class template declarations named "RayQuery".
+  IdentifierInfo& rayQueryId = context.Idents.get(StringRef("RayQuery"), tok::TokenKind::identifier);
+  CXXRecordDecl* rayQueryRecord = CXXRecordDecl::Create(
+    context, TagDecl::TagKind::TTK_Class, tuContext, NoLoc, NoLoc, &rayQueryId,
+    nullptr, DelayTypeCreationTrue);
+  ClassTemplateDecl* rayQueryTemplate = ClassTemplateDecl::Create(
+    context, tuContext, NoLoc, DeclarationName(&rayQueryId),
+    templateParamList, rayQueryRecord, nullptr);
+  rayQueryRecord->setDescribedClassTemplate(rayQueryTemplate);
+  rayQueryRecord->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
+
+  // Requesting the class name specialization will fault in required types.
+  QualType injectedType = rayQueryTemplate->getInjectedClassNameSpecialization();
+  injectedType = context.getInjectedClassNameType(rayQueryRecord, injectedType);
+  assert(injectedType->isDependentType() && "Class template type is not dependent?");
+  rayQueryTemplate->setLexicalDeclContext(tuContext);
+  rayQueryRecord->setLexicalDeclContext(tuContext);
+  rayQueryRecord->startDefinition();
+
+  // Declare a public constructor taking no arguments.
+  // TODO: lower it to the intrinsic that produces the RayQuery handle for
+  // this object.
+  CanQualType recordCanonType = rayQueryRecord->getTypeForDecl()->getCanonicalTypeUnqualified();
+  DeclarationName ctorName = context.DeclarationNames.getCXXConstructorName(recordCanonType);
+  CXXConstructorDecl *ctorDecl = nullptr;
+  TypeSourceInfo *ctorTypeInfo = nullptr;
+  CreateConstructorDeclaration(context, rayQueryRecord, context.VoidTy, {}, ctorName, false, &ctorDecl, &ctorTypeInfo);
+  rayQueryRecord->addDecl(ctorDecl);
+
+  // Add an 'h' field to hold the handle.
+  AddHLSLHandleField(context, rayQueryRecord, uintType);
+
+  rayQueryRecord->completeDefinition();
+
+  // Both declarations need to be present for correct handling.
+  tuContext->addDecl(rayQueryTemplate);
+  tuContext->addDecl(rayQueryRecord);
+
+#ifdef DBG
+  // Verify that we can read the field member from the template record.
+  DeclContext::lookup_result handleLookup = rayQueryRecord->lookup(
+    DeclarationName(&context.Idents.get(StringRef("h"))));
+  DXASSERT(!handleLookup.empty(), "otherwise template object handle cannot be looked up");
+#endif
+
+  *typeDecl = rayQueryTemplate;
+  *recordDecl = rayQueryRecord;
+}
+
 bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
 bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
   return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
   return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
 }
 }

+ 43 - 12
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -238,6 +238,9 @@ private:
 
 
   std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
   std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
 
 
+  // Insert AllocateRayQuery to initialize each RayQuery alloca
+  void AllocateRayQueryObjects();
+
 public:
 public:
   CGMSHLSLRuntime(CodeGenModule &CGM);
   CGMSHLSLRuntime(CodeGenModule &CGM);
 
 
@@ -857,6 +860,27 @@ unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annota
   unsigned offset = 0;
   unsigned offset = 0;
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+
+    // If template, save template args
+    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
+      const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
+      for (unsigned i = 0; i < args.size(); ++i) {
+        DxilTemplateArgAnnotation &argAnnotation = annotation->GetTemplateArgAnnotation(i);
+        const clang::TemplateArgument &arg = args[i];
+        switch (arg.getKind()) {
+        case clang::TemplateArgument::ArgKind::Type:
+          argAnnotation.SetType(CGM.getTypes().ConvertType(arg.getAsType()));
+        break;
+        case clang::TemplateArgument::ArgKind::Integral:
+          argAnnotation.SetIntegral(arg.getAsIntegral().getExtValue());
+          break;
+        default:
+          break;
+        }
+      }
+    }
+
     if (CXXRD->getNumBases()) {
     if (CXXRD->getNumBases()) {
       // Add base as field.
       // Add base as field.
       for (const auto &I : CXXRD->bases()) {
       for (const auto &I : CXXRD->bases()) {
@@ -963,6 +987,17 @@ static bool IsElementInputOutputType(QualType Ty) {
   return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
   return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
 }
 }
 
 
+// Returns the number of template instantiation arguments of RD, or 0 when RD
+// is not a class template specialization.
+static unsigned GetNumTemplateArgsForRecordDecl(const RecordDecl *RD) {
+  const ClassTemplateSpecializationDecl *specDecl =
+      dyn_cast<ClassTemplateSpecializationDecl>(RD);
+  if (!specDecl)
+    return 0;
+  return specDecl->getTemplateInstantiationArgs().size();
+}
+
 // Return the size for constant buffer of each decl.
 // Return the size for constant buffer of each decl.
 unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
 unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
                                             DxilTypeSystem &dxilTypeSys,
                                             DxilTypeSystem &dxilTypeSys,
@@ -1001,7 +1036,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
       return structSize;
     }
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
   } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
@@ -1013,7 +1049,8 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
       unsigned structSize = annotation->GetCBufferSize();
       unsigned structSize = annotation->GetCBufferSize();
       return structSize;
       return structSize;
     }
     }
-    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
+    DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
+      GetNumTemplateArgsForRecordDecl(RT->getDecl()));
 
 
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
     return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
   } else if (IsHLSLResourceType(Ty)) {
   } else if (IsHLSLResourceType(Ty)) {
@@ -3612,10 +3649,8 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     llvm::Type *Ty = paramTyList[i];
     llvm::Type *Ty = paramTyList[i];
     if (Ty->isPointerTy()) {
     if (Ty->isPointerTy()) {
       Ty = Ty->getPointerElementType();
       Ty = Ty->getPointerElementType();
-      if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
-        // Use handle type for object type.
+      if (dxilutil::IsHLSLResourceType(Ty)) {
+        // Use handle type for resource type.
         // This will make sure temp object variable only used by createHandle.
         // This will make sure temp object variable only used by createHandle.
         paramTyList[i] = HandleTy;
         paramTyList[i] = HandleTy;
       }
       }
@@ -3673,7 +3708,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
     gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
     llvm::Type *resTy = nullptr;
     llvm::Type *resTy = nullptr;
     while (GEPIt != E) {
     while (GEPIt != E) {
-      if (dxilutil::IsHLSLObjectType(*GEPIt)) {
+      if (dxilutil::IsHLSLResourceType(*GEPIt)) {
         resTy = *GEPIt;
         resTy = *GEPIt;
         break;
         break;
       }
       }
@@ -3756,9 +3791,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
       llvm::Type *Ty = arg->getType();
       llvm::Type *Ty = arg->getType();
       if (Ty->isPointerTy()) {
       if (Ty->isPointerTy()) {
         Ty = Ty->getPointerElementType();
         Ty = Ty->getPointerElementType();
-        if (dxilutil::IsHLSLObjectType(Ty) &&
-          // StreamOutput don't need handle.
-          !HLModule::IsStreamOutputType(Ty)) {
+        if (dxilutil::IsHLSLResourceType(Ty)) {
           // Use object type directly, not by pointer.
           // Use object type directly, not by pointer.
           // This will make sure temp object variable only used by ld/st.
           // This will make sure temp object variable only used by ld/st.
           if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
           if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
@@ -4770,8 +4803,6 @@ static void CreateWriteEnabledStaticGlobals(llvm::Module *M,
   }
   }
 }
 }
 
 
-
-
 void CGMSHLSLRuntime::FinishCodeGen() {
 void CGMSHLSLRuntime::FinishCodeGen() {
   // Library don't have entry.
   // Library don't have entry.
   if (!m_bIsLib) {
   if (!m_bIsLib) {

+ 42 - 4
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -199,6 +199,9 @@ enum ArBasicKind {
   AR_OBJECT_TRIANGLE_HIT_GROUP,
   AR_OBJECT_TRIANGLE_HIT_GROUP,
   AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
   AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
 
+  // RayQuery
+  AR_OBJECT_RAY_QUERY,
+
   AR_BASIC_MAXIMUM_COUNT
   AR_BASIC_MAXIMUM_COUNT
 };
 };
 
 
@@ -476,6 +479,8 @@ const UINT g_uBasicKindProps[] =
   0,      //AR_OBJECT_TRIANGLE_HIT_GROUP,
   0,      //AR_OBJECT_TRIANGLE_HIT_GROUP,
   0,      //AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
   0,      //AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
 
+  0,      //AR_OBJECT_RAY_QUERY,
+
   // AR_BASIC_MAXIMUM_COUNT
   // AR_BASIC_MAXIMUM_COUNT
 };
 };
 
 
@@ -1286,7 +1291,9 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
   AR_OBJECT_RAYTRACING_SHADER_CONFIG,
   AR_OBJECT_RAYTRACING_SHADER_CONFIG,
   AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   AR_OBJECT_TRIANGLE_HIT_GROUP,
   AR_OBJECT_TRIANGLE_HIT_GROUP,
-  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP
+  AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  AR_OBJECT_RAY_QUERY
 };
 };
 
 
 // Count of template arguments for basic kind of objects that look like templates (one or more type arguments).
 // Count of template arguments for basic kind of objects that look like templates (one or more type arguments).
@@ -1366,6 +1373,8 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
   0, // AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   0, // AR_OBJECT_RAYTRACING_PIPELINE_CONFIG,
   0, // AR_OBJECT_TRIANGLE_HIT_GROUP,
   0, // AR_OBJECT_TRIANGLE_HIT_GROUP,
   0, // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
   0, // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
+
+  1, // AR_OBJECT_RAY_QUERY,
 };
 };
 
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsTemplateCount));
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsTemplateCount));
@@ -1456,6 +1465,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_HIT_GROUP,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_TRIANGLE_HIT_GROUP,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
   { 0, MipsFalse, SampleFalse },  // AR_OBJECT_PROCEDURAL_PRIMITIVE_HIT_GROUP,
 
 
+  { 0, MipsFalse, SampleFalse },  // AR_OBJECT_RAY_QUERY,
 };
 };
 
 
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts));
 C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts));
@@ -1568,7 +1578,9 @@ const char* g_ArBasicTypeNames[] =
   "RaytracingShaderConfig",
   "RaytracingShaderConfig",
   "RaytracingPipelineConfig",
   "RaytracingPipelineConfig",
   "TriangleHitGroup",
   "TriangleHitGroup",
-  "ProceduralPrimitiveHitGroup"
+  "ProceduralPrimitiveHitGroup",
+
+  "RayQuery"
 };
 };
 
 
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
 C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT);
@@ -2120,7 +2132,11 @@ void GetIntrinsicMethods(ArBasicKind kind, _Outptr_result_buffer_(*intrinsicCoun
     *intrinsics = g_ConsumeStructuredBufferMethods;
     *intrinsics = g_ConsumeStructuredBufferMethods;
     *intrinsicCount = _countof(g_ConsumeStructuredBufferMethods);
     *intrinsicCount = _countof(g_ConsumeStructuredBufferMethods);
     break;
     break;
-  // SPIRV change starts
+  case AR_OBJECT_RAY_QUERY:
+    *intrinsics = g_RayQueryMethods;
+    *intrinsicCount = _countof(g_RayQueryMethods);
+    break;
+    // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
 #ifdef ENABLE_SPIRV_CODEGEN
   case AR_OBJECT_VK_SUBPASS_INPUT:
   case AR_OBJECT_VK_SUBPASS_INPUT:
     *intrinsics = g_VkSubpassInputMethods;
     *intrinsics = g_VkSubpassInputMethods;
@@ -3219,6 +3235,12 @@ private:
           recordDecl = CreateSubobjectProceduralPrimitiveHitGroup(*m_context);
           recordDecl = CreateSubobjectProceduralPrimitiveHitGroup(*m_context);
           break;
           break;
         }
         }
+      } else if (kind == AR_OBJECT_RAY_QUERY) {
+        ClassTemplateDecl* typeDecl = nullptr;
+        AddRayQueryTemplate(*m_context, &typeDecl, &recordDecl);
+        DXASSERT(typeDecl != nullptr, "AddRayQueryTemplate failed to return the object declaration");
+        typeDecl->setImplicit(true);
+        recordDecl->setImplicit(true);
       }
       }
       else if (templateArgCount == 0)
       else if (templateArgCount == 0)
       {
       {
@@ -3419,6 +3441,13 @@ public:
     return IsSubobjectBasicKind(GetTypeElementKind(type));
     return IsSubobjectBasicKind(GetTypeElementKind(type));
   }
   }
 
 
+  bool IsRayQueryBasicKind(ArBasicKind kind) {
+    return kind == AR_OBJECT_RAY_QUERY;
+  }
+  bool IsRayQueryType(QualType type) {
+    return IsRayQueryBasicKind(GetTypeElementKind(type));
+  }
+
   void WarnMinPrecision(HLSLScalarType type, SourceLocation loc) {
   void WarnMinPrecision(HLSLScalarType type, SourceLocation loc) {
     // TODO: enable this once we introduce precise master option
     // TODO: enable this once we introduce precise master option
     bool UseMinPrecision = m_context->getLangOpts().UseMinPrecision;
     bool UseMinPrecision = m_context->getLangOpts().UseMinPrecision;
@@ -4882,7 +4911,8 @@ QualType GetFirstElementTypeFromDecl(const Decl* decl)
   if (specialization) {
   if (specialization) {
     const TemplateArgumentList& list = specialization->getTemplateArgs();
     const TemplateArgumentList& list = specialization->getTemplateArgs();
     if (list.size()) {
     if (list.size()) {
-      return list[0].getAsType();
+      if (list[0].getKind() == TemplateArgument::ArgKind::Type)
+        return list[0].getAsType();
     }
     }
   }
   }
 
 
@@ -6744,6 +6774,14 @@ void HLSLExternalSource::InitializeInitSequenceForHLSL(
 
 
   // In HLSL there are no default initializers, eg float4x4 m();
   // In HLSL there are no default initializers, eg float4x4 m();
   if (Kind.getKind() == InitializationKind::IK_Default) {
   if (Kind.getKind() == InitializationKind::IK_Default) {
+    // Except for RayQuery.
+    if (GetTypeElementKind(Entity.getType()) == AR_OBJECT_RAY_QUERY) {
+      // RayQuery handle initialization
+      // TODO: Try generating an intrinsic method call for AllocateRayQuery
+      // - pass the flags from the intrinsic argument.
+      // Lower to intrinsic that takes flags and returns i32 value,
+      // which is then stored in the RayQuery alloca as the handle.
+    }
     return;
     return;
   }
   }
 
 

+ 19 - 0
tools/clang/lib/Sema/gen_intrin_main_tables_15.h

@@ -5787,6 +5787,24 @@ static const HLSL_INTRINSIC g_ConsumeStructuredBufferMethods[] =
     {(UINT)hlsl::IntrinsicOp::MOP_GetDimensions, false, false, -1, 3, g_ConsumeStructuredBufferMethods_Args1},
     {(UINT)hlsl::IntrinsicOp::MOP_GetDimensions, false, false, -1, 3, g_ConsumeStructuredBufferMethods_Args1},
 };
 };
 
 
+//
+// Start of RayQueryMethods
+//
+
+static const HLSL_INTRINSIC_ARGUMENT g_RayQueryMethods_Args0[] =
+{
+    {"TraceRayInline", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0},
+    {"AccelerationStructure", AR_QUAL_IN, 1, LITEMPLATE_OBJECT, 1, LICOMPTYPE_ACCELERATION_STRUCT, 1, 1},
+    {"RayFlags", AR_QUAL_IN, 2, LITEMPLATE_SCALAR, 2, LICOMPTYPE_UINT, 1, 1},
+    {"InstanceInclusionMask", AR_QUAL_IN, 3, LITEMPLATE_SCALAR, 3, LICOMPTYPE_UINT, 1, 1},
+    {"Ray", AR_QUAL_IN, 4, LITEMPLATE_OBJECT, 4, LICOMPTYPE_RAYDESC, 1, 1},
+};
+
+static const HLSL_INTRINSIC g_RayQueryMethods[] =
+{
+    {(UINT)hlsl::IntrinsicOp::MOP_TraceRayInline, false, false, -1, 5, g_RayQueryMethods_Args0},
+};
+
 //
 //
 // Start of VkSubpassInputMethods
 // Start of VkSubpassInputMethods
 //
 //
@@ -5840,6 +5858,7 @@ static const UINT g_uRWTexture1DMethodsCount = 4;
 static const UINT g_uRWTexture2DArrayMethodsCount = 4;
 static const UINT g_uRWTexture2DArrayMethodsCount = 4;
 static const UINT g_uRWTexture2DMethodsCount = 4;
 static const UINT g_uRWTexture2DMethodsCount = 4;
 static const UINT g_uRWTexture3DMethodsCount = 4;
 static const UINT g_uRWTexture3DMethodsCount = 4;
+static const UINT g_uRayQueryMethodsCount = 1;
 static const UINT g_uStreamMethodsCount = 2;
 static const UINT g_uStreamMethodsCount = 2;
 static const UINT g_uStructuredBufferMethodsCount = 3;
 static const UINT g_uStructuredBufferMethodsCount = 3;
 static const UINT g_uTexture1DArrayMethodsCount = 31;
 static const UINT g_uTexture1DArrayMethodsCount = 31;

+ 19 - 0
tools/clang/test/CodeGenHLSL/batch/shader_stages/raytracing/rayquery/tracerayinline.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T vs_6_5 -E main %s | FileCheck %s
+
+// CHECK: %[[RTAS:[^ ]+]] = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 0, i1 false)
+// CHECK: %[[RQ:[^ ]+]] = call i32 @dx.op.allocateRayQuery(i32 168, i32 1)
+// CHECK: call void @dx.op.traceRayInline(i32 169, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 0, i32 1,
+// CHECK: call void @dx.op.traceRayInline(i32 169, i32 %[[RQ]], %dx.types.Handle %[[RTAS]], i32 1, i32 2,
+
+RaytracingAccelerationStructure RTAS;
+
+void DoTrace(RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery, RayDesc rayDesc) {
+  rayQuery.TraceRayInline(RTAS, 0, 1, rayDesc);
+}
+
+float main(RayDesc rayDesc : RAYDESC) : OUT {
+  RayQuery<RAY_FLAG_FORCE_OPAQUE> rayQuery;
+  DoTrace(rayQuery, rayDesc);
+  rayQuery.TraceRayInline(RTAS, 1, 2, rayDesc);
+  return 0;
+}

+ 3 - 1
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -1182,7 +1182,9 @@ static const char *OpCodeSignatures[] = {
   "(acc,a,b)",  // Dot4AddU8Packed
   "(acc,a,b)",  // Dot4AddU8Packed
   "(value)",  // WaveMatch
   "(value)",  // WaveMatch
   "(value,mask0,mask1,mask2,mask3,op,sop)",  // WaveMultiPrefixOp
   "(value,mask0,mask1,mask2,mask3,op,sop)",  // WaveMultiPrefixOp
-  "(value,mask0,mask1,mask2,mask3)"  // WaveMultiPrefixBitCount
+  "(value,mask0,mask1,mask2,mask3)",  // WaveMultiPrefixBitCount
+  "(constRayFlags)",  // AllocateRayQuery
+  "(rayQueryHandle,accelerationStructure,rayFlags,instanceInclusionMask,origin_X,origin_Y,origin_Z,tMin,direction_X,direction_Y,direction_Z,tMax)"  // TraceRayInline
 };
 };
 // OPCODE-SIGS:END
 // OPCODE-SIGS:END
 
 

+ 6 - 0
utils/hct/gen_intrin_main.txt

@@ -834,6 +834,12 @@ $classT [[]] Consume() : structuredbuffer_consume;
 
 
 } namespace
 } namespace
 
 
+namespace RayQueryMethods {
+
+void [[]]  TraceRayInline(in acceleration_struct AccelerationStructure, in uint RayFlags, in uint InstanceInclusionMask, in ray_desc Ray);
+
+} namespace
+
 // SPIRV Change Starts
 // SPIRV Change Starts
 
 
 namespace VkSubpassInputMethods {
 namespace VkSubpassInputMethods {

+ 26 - 2
utils/hct/hctdb.py

@@ -380,6 +380,9 @@ class db_dxil(object):
         for i in "WaveMatch,WaveMultiPrefixOp,WaveMultiPrefixBitCount".split(","):
         for i in "WaveMatch,WaveMultiPrefixOp,WaveMultiPrefixBitCount".split(","):
             self.name_idx[i].category = "Wave"
             self.name_idx[i].category = "Wave"
             self.name_idx[i].shader_model = 6,5
             self.name_idx[i].shader_model = 6,5
+        for i in "AllocateRayQuery,TraceRayInline".split(","):
+            self.name_idx[i].category = "Inline Ray Query"
+            self.name_idx[i].shader_model = 6,5
 
 
     def populate_llvm_instructions(self):
     def populate_llvm_instructions(self):
         # Add instructions that map to LLVM instructions.
         # Add instructions that map to LLVM instructions.
@@ -1269,7 +1272,7 @@ class db_dxil(object):
             db_dxil_param(0, "v", "", "")])
             db_dxil_param(0, "v", "", "")])
         next_op_idx += 1
         next_op_idx += 1
 
 
-        self.add_dxil_op("TraceRay", next_op_idx, "TraceRay", "returns the view index", "u", "", [
+        self.add_dxil_op("TraceRay", next_op_idx, "TraceRay", "initiates raytrace", "u", "", [
             db_dxil_param(0, "v", "", ""),
             db_dxil_param(0, "v", "", ""),
             db_dxil_param(2, "res", "AccelerationStructure", "Top-level acceleration structure to use"),
             db_dxil_param(2, "res", "AccelerationStructure", "Top-level acceleration structure to use"),
             db_dxil_param(3, "i32", "RayFlags", "Valid combination of Ray_flags"),
             db_dxil_param(3, "i32", "RayFlags", "Valid combination of Ray_flags"),
@@ -1373,9 +1376,30 @@ class db_dxil(object):
             db_dxil_param(6, "i32", "mask3", "mask 3")])
             db_dxil_param(6, "i32", "mask3", "mask 3")])
         next_op_idx += 1
         next_op_idx += 1
 
 
+        self.add_dxil_op("AllocateRayQuery", next_op_idx, "AllocateRayQuery", "allocate space for RayQuery and return handle", "v", "", [
+            db_dxil_param(0, "i32", "", "handle to RayQuery state"),
+            db_dxil_param(2, "u32", "constRayFlags", "Valid combination of RAY_FLAGS", is_const=True)])
+        next_op_idx += 1
+
+        self.add_dxil_op("TraceRayInline", next_op_idx, "TraceRayInline", "initialize RayQuery for raytrace", "v", "", [
+            db_dxil_param(0, "v", "", ""),
+            db_dxil_param(2, "i32", "rayQueryHandle", "RayQuery handle"),
+            db_dxil_param(3, "res", "accelerationStructure", "Top-level acceleration structure to use"),
+            db_dxil_param(4, "i32", "rayFlags", "Valid combination of RAY_FLAGS, combined with constRayFlags provided to AllocateRayQuery"),
+            db_dxil_param(5, "i32", "instanceInclusionMask", "Bottom 8 bits of InstanceInclusionMask are used to include/reject geometry instances based on the InstanceMask in each instance: if(!((InstanceInclusionMask & InstanceMask) & 0xff)) { ignore intersection }"),
+            db_dxil_param(6, "f", "origin_X", "Origin x of the ray"),
+            db_dxil_param(7, "f", "origin_Y", "Origin y of the ray"),
+            db_dxil_param(8, "f", "origin_Z", "Origin z of the ray"),
+            db_dxil_param(9, "f", "tMin", "Tmin of the ray"),
+            db_dxil_param(10, "f", "direction_X", "Direction x of the ray"),
+            db_dxil_param(11, "f", "direction_Y", "Direction y of the ray"),
+            db_dxil_param(12, "f", "direction_Z", "Direction z of the ray"),
+            db_dxil_param(13, "f", "tMax", "Tmax of the ray")])
+        next_op_idx += 1
+
         # End of DXIL 1.5 opcodes.
         # End of DXIL 1.5 opcodes.
         self.set_op_count_for_version(1, 5, next_op_idx)
         self.set_op_count_for_version(1, 5, next_op_idx)
-        assert next_op_idx == 168, "next operation index is %d rather than 168 and thus opcodes are broken" % next_op_idx
+        assert next_op_idx == 170, "next operation index is %d rather than 170 and thus opcodes are broken" % next_op_idx
 
 
         # Set interesting properties.
         # Set interesting properties.
         self.build_indices()
         self.build_indices()