Explorar el Código

Add support to convert DXR HLSL to SPV_NV_ray_tracing (#1920)

* Add support to convert DXR HLSL to SPV_NV_ray_tracing

* Fix multiple typos and cleanup using clang-format.

* Update tests to verify transpose and custom matrix type
Update tests to add multple entry functions of same shader stage

* Replace ExecutionModel class with ShaderModel::Kind

- This change removes ExecutionModel class and relies on ShaderModel::Kind to track current entry point shader stage
- Also instead of declaring it in SpirvEmitter, DeclResultIdMapper & GlPerVertex, we declare it only once in common object SpirvContext

* Dont create a stageVar for raytracing interface variables.

* Don't perform 'new' memory allocation for FunctionInfo object

This change also -
- removes invalid "SpirvEmitter::" from function declarations in SpirvEmitter class.
- fix build errors by adding a default constructor in FunctionInfo struct to allow functionInfoMap allocate an empty object for no search results.

* Fix some more typos and fomatting errors.

* Update RST with raytracing stage info

* In SpirvContext.h, replace unsigned by uint32_t

* Test add ascii art and fixup grammar mistakes.

* Use placement new to allocate FunctionInfo objects.

Also bundle the insertion into functionInfoMap and workQueue together.

* Remove outdated comment.

* Update RST with intrinsic mapping and typo fixes.

* Some more wording fixes to RST for raytracing.

* Accidently broke table in RST due to missing '-'

* Add tick marks for supported stages in RST

* Final clang-format formatting fixes.

* Add missing labels to flowchart and spacing.
alelenv hace 6 años
padre
commit
e58ecd4c2f
Se han modificado 33 ficheros con 1941 adiciones y 143 borrados
  1. 282 0
      docs/SPIR-V.rst
  2. 6 0
      tools/clang/include/clang/SPIRV/AstTypeProbe.h
  3. 1 0
      tools/clang/include/clang/SPIRV/FeatureManager.h
  4. 6 0
      tools/clang/include/clang/SPIRV/SpirvBuilder.h
  5. 37 0
      tools/clang/include/clang/SPIRV/SpirvContext.h
  6. 24 0
      tools/clang/include/clang/SPIRV/SpirvInstruction.h
  7. 12 0
      tools/clang/include/clang/SPIRV/SpirvType.h
  8. 1 0
      tools/clang/include/clang/SPIRV/SpirvVisitor.h
  9. 80 0
      tools/clang/lib/SPIRV/AstTypeProbe.cpp
  10. 10 0
      tools/clang/lib/SPIRV/CapabilityVisitor.cpp
  11. 84 39
      tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
  12. 24 16
      tools/clang/lib/SPIRV/DeclResultIdMapper.h
  13. 20 0
      tools/clang/lib/SPIRV/EmitVisitor.cpp
  14. 1 0
      tools/clang/lib/SPIRV/EmitVisitor.h
  15. 4 2
      tools/clang/lib/SPIRV/FeatureManager.cpp
  16. 9 9
      tools/clang/lib/SPIRV/GlPerVertex.cpp
  17. 2 4
      tools/clang/lib/SPIRV/GlPerVertex.h
  18. 4 0
      tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
  19. 11 0
      tools/clang/lib/SPIRV/SpirvBuilder.cpp
  20. 4 1
      tools/clang/lib/SPIRV/SpirvContext.cpp
  21. 647 57
      tools/clang/lib/SPIRV/SpirvEmitter.cpp
  22. 65 11
      tools/clang/lib/SPIRV/SpirvEmitter.h
  23. 6 0
      tools/clang/lib/SPIRV/SpirvInstruction.cpp
  24. 2 1
      tools/clang/lib/SPIRV/SpirvType.cpp
  25. 78 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.anyhit.hlsl
  26. 20 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.callable.hlsl
  27. 79 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.closesthit.hlsl
  28. 62 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.intersection.hlsl
  29. 260 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.library.hlsl
  30. 32 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.miss.hlsl
  31. 38 0
      tools/clang/test/CodeGenSPIRV/raytracing.nv.raygen.hlsl
  32. 23 0
      tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
  33. 7 3
      tools/clang/unittests/SPIRV/FileTestUtils.cpp

+ 282 - 0
docs/SPIR-V.rst

@@ -2821,6 +2821,288 @@ behind ``T`` will be flushed before SPIR-V ``OpEmitVertex`` instruction is
 generated. ``.RestartStrip()`` method calls will be translated into the SPIR-V
 ``OpEndPrimitive`` instruction.
 
+Raytracing Shader Stages
+------------------------
+
+DirectX Raytracing adds six new shader stages for raytracing namely ray generation, intersection, closest-hit,
+any-hit, miss and callable.
+
+| Refer to following pages for details:
+| https://docs.microsoft.com/en-us/windows/desktop/direct3d12/direct3d-12-raytracing
+| https://docs.microsoft.com/en-us/windows/desktop/direct3d12/direct3d-12-raytracing-hlsl-reference
+
+
+Flow chart for various stages in a raytracing pipeline is as follows:
+::
+
+          +---------------------+
+          |   Ray generation    |
+          +---------------------+
+                     |
+          TraceRay() |                      +--------------+
+                     |      _ _ _ _ _ _ _ _ |   Any Hit    |
+                     |     |                +--------------+
+                     V     V                       ^
+          +---------------------+                  |
+          |    Acceleration     |           +--------------+
+          |     Structure       |           | Intersection |
+          |     Traversal       |           +--------------+
+          +---------------------+                  ^
+                    |        |                     |
+                    |        |_ _ _ _ _ _ _ _ _ _ _|
+                    |
+                    |
+                    V
+          +--------------------+            +-------------+
+          |      Is Hit ?      |            |  Callable   |
+          +--------------------+            +-------------+
+              |            |
+          Yes |            | No
+              V            V
+         +---------+    +------+
+         | Closest |    | Miss |
+         |   Hit   |    |      |
+         +---------+    +------+
+
+
+| *Note : DXC does not add special shader profiles for raytracing under -T option.*
+| *All raytracing shaders must be compiled as library using lib_6_3/lib_6_4 profile option.*
+
+Ray Generation Stage
+~~~~~~~~~~~~~~~~~~~~
+
+| Ray generation shaders start ray tracing work and work on a compute-like 3D grid of threads.
+| Entry functions of this stage type are annotated with **[shader("raygeneration")]** in HLSL source.
+| Such entry functions must return void and do not accept any arguments.
+
+| For example:
+
+.. code:: hlsl
+
+  RaytracingAccelerationStructure rs;
+  struct Payload
+  {
+  float4 color;
+  };
+  [shader("raygeneration")]
+  void main() {
+    Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+    RayDesc rayDesc;
+    rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+    rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+    rayDesc.TMin = 0.0f;
+    rayDesc.TMax = 1000.0f;
+    TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+  }
+
+Intersection Stage
+~~~~~~~~~~~~~~~~~~
+
+| Intersection shader stage is used to implement arbitrary ray-primitive intersections such spheres or axis-aligned bounding boxes (AABB). Triangle primitives do not require a custom intersection shader.
+| Entry functions of this stage are annotated with **[shader("intersection")]** in HLSL source.
+| Such entry functions must return void and do not accept any arguments.
+
+| For example:
+
+.. code:: hlsl
+
+  struct Attribute
+  {
+    float2 bary;
+  };
+
+  [shader("intersection")]
+  void main() {
+  Attribute myHitAttribute = { float2(0.0f,0.0f) };
+  ReportHit(0.0f, 0U, myHitAttribute);
+  }
+
+
+Closest-Hit Stage
+~~~~~~~~~~~~~~~~~
+
+| Hit shaders are invoked when a ray primitive intersection is found. A closest-hit shader
+| is invoked for the closest intersection point along a ray and can be used to compute interactions
+| at intersection point or spawn secondary rays.
+| Entry functions of this stage are annotated with **[shader("closesthit")]** in HLSL source.
+| Such entry functions must return void and accept exactly two arguments. First argument must be an inout
+| variable of user defined structure type and second argument must be a in variable of user defined structure type.
+
+| For example:
+
+.. code:: hlsl
+
+  struct Attribute
+  {
+    float2 bary;
+  };
+  struct Payload {
+    float4 color;
+  };
+  [shader("closesthit")]
+  void main(inout Payload a, in Attribute b) {
+    a.color = float4(0.0f,1.0f,0.0f,0.0f);
+  }
+
+Any-Hit Stage
+~~~~~~~~~~~~~~~~~
+
+| Hit shaders are invoked when a ray primitive intersection is found. An any-hit shader
+| is invoked for all intersections along a ray with a primitive.
+| Entry functions of this stage are annotated with **[shader("anyhit")]** in HLSL source.
+| Such entry functions must return void and accept exactly two arguments. First argument must be an inout
+| variable of user defined structure type and second argument must be an in variable of user defined structure type.
+
+| For example:
+
+.. code:: hlsl
+
+  struct Attribute
+  {
+    float2 bary;
+  };
+  struct Payload {
+    float4 color;
+  };
+  [shader("anyhit")]
+  void main(inout Payload a, in Attribute b) {
+    a.color = float4(0.0f,1.0f,0.0f,0.0f);
+  }
+
+Miss Stage
+~~~~~~~~~~
+
+| Miss shaders are invoked when no intersection is found.
+| Entry functions of this stage are annotated with **[shader("miss")]** in HLSL source.
+| Such entry functions return void and accept exactly one argument. First argument must be an inout variable of user defined structure type.
+
+| For example:
+
+.. code:: hlsl
+
+  struct Payload {
+    float4 color;
+  };
+  [shader("miss")]
+  void main(inout Payload a) {
+    a.color = float4(0.0f,1.0f,0.0f,0.0f);
+  }
+
+Callable Stage
+~~~~~~~~~~~~~~
+
+| Callables are generic function calls which can be invoked from either raygeneration, closest-hit, 
+| miss or callable shader stages.
+| Entry functions of this stage are annotated with **[shader("callable")]** in HLSL source.
+| Such entry functions must return void and accept exactly one argument. First argument must be an inout
+| variable of user defined structure type.
+
+| For example:
+
+.. code:: hlsl
+
+  struct CallData {
+    float4 data;
+  };
+  [shader("callable")]
+  void main(inout CallData a) {
+    a.color = float4(0.0f,1.0f,0.0f,0.0f);
+  }
+
+
+Raytracing in Vulkan and SPIRV
+==============================
+
+| SPIR-V codegen is currently supported for NVIDIA platforms via SPV_NV_ray_tracing extension
+| SPIR-V specification for reference:
+| https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_ray_tracing.asciidoc
+
+| Vulkan ray tracing samples:
+| https://developer.nvidia.com/rtx/raytracing/vkray
+
+
+Raytracing Mapping to SPIR-V
+----------------------------
+
+Intrinsics
+~~~~~~~~~~
+
+
+| Following table provides mapping for system value intrinsics along with supported shader stages.
+
+============================    ============================ ====== ============ =========== ======= ======== ========
+        HLSL                               SPIR-V                             HLSL Shader Stage
+----------------------------    ---------------------------- ---------------------------------------------------------
+  System Value Intrinsic               Builtin               Raygen Intersection Closest Hit Any Hit   Miss   Callable
+============================    ============================ ====== ============ =========== ======= ======== ========
+``DispatchRaysIndex()``         ``LaunchIdNV``                 ✓         ✓            ✓        ✓      ✓        ✓
+``DispatchRaysDimensions()``    ``LaunchSizeNV``               ✓         ✓            ✓        ✓      ✓        ✓
+``WorldRayOrigin()``            ``WorldRayOriginNV``                     ✓            ✓        ✓      ✓
+``WorldRayDirection()``         ``WorldRayDirectionNV``                  ✓            ✓        ✓      ✓
+``RayTMin()``                   ``RayTminNV``                            ✓            ✓        ✓      ✓
+``RayTCurrent()``               ``HitTNV``                               ✓            ✓        ✓      ✓
+``RayFlags()``                  ``IncomingRayFlagsNV``                   ✓            ✓        ✓      ✓
+``InstanceIndex()``             ``InstanceId``                           ✓            ✓        ✓
+``InstanceID()``                ``InstanceCustomIndexNV``                ✓            ✓        ✓
+``PrimitiveIndex()``            ``PrimitiveId``                          ✓            ✓        ✓
+``ObjectRayOrigin()``           ``ObjectRayOriginNV``                    ✓            ✓        ✓
+``ObjectRayDirection()``        ``ObjectRayDirectionNV``                 ✓            ✓        ✓
+``ObjectToWorld3x4()``          ``ObjectToWorldNV``                      ✓            ✓        ✓
+``ObjectToWorld4x3()``          ``ObjectToWorldNV``                      ✓            ✓        ✓
+``WorldToObject3x4()``          ``WorldToObjectNV``                      ✓            ✓        ✓
+``WorldToObject4x3()``          ``WorldToObjectNV``                      ✓            ✓        ✓
+``HitKind()``                   ``HitKindNV``                            ✓            ✓        ✓
+============================    ============================ ====== ============ =========== ======= ======== ========
+
+| *There is no separate builtin for transposed matrices ObjectToWorld3x4 and WorldToObject3x4 in SPIR-V hence we internally transpose during translation*
+
+
+| Following table provides mapping for other intrinsics along with supported shader stages.
+
+
+===========================     ============================ ====== ============ =========== ======= ===== ========
+        HLSL                               SPIR-V                             HLSL Shader Stage
+---------------------------     ---------------------------- ------------------------------------------------------
+   Intrinsic                              Opcode             Raygen Intersection Closest Hit Any Hit  Miss Callable
+===========================     ============================ ====== ============ =========== ======= ===== ========
+``TraceRay``                    ``OpTraceNV``                  ✓                     ✓                ✓
+``ReportHit``                   ``OpReportIntersectionNV``     ✓         ✓
+``IgnoreHit``                   ``OpIgnoreIntersectionNV``     ✓                                ✓
+``AcceptHitAndEndSearch``       ``OpTerminateRayNV``           ✓                                ✓
+``CallShader``                  ``OpExecuteCallable``          ✓                     ✓                ✓      ✓
+===========================     ============================ ====== ============ =========== ======= ===== ========
+
+
+Resource Types
+~~~~~~~~~~~~~~
+
+| Following table provides mapping for new resource types supported in all raytracing shaders.
+
+
+===================================     =================================
+        HLSL Type                               SPIR-V Opcode
+-----------------------------------     ---------------------------------
+``RaytracingAccelerationStructure``     ``OpTypeAccelerationStructureNV``
+===================================     =================================
+
+Interface Variables
+~~~~~~~~~~~~~~~~~~~
+
+| Interface variables are created for various ray tracing storage classes based on intrinsic/shader stage
+| Following table gives high level overview of the mapping.
+
+
+===========================     ===========================================================
+   SPIR-V Storage Class                Created For
+---------------------------     -----------------------------------------------------------
+``RayPayloadNV``                Last argument to TraceRay
+``IncomingRayPayloadNV``        First argument of entry for AnyHit/ClosestHit & Miss stage
+``HitAttributeNV``              Last argument to ReportHit
+``CallableDataNV``              Last argument to CallShader
+``IncomingCallableDataNV``      First argument of entry for Callable stage
+===========================     ===========================================================
+
+
 Shader Model 6.0 Wave Intrinsics
 ================================
 

+ 6 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -13,6 +13,7 @@
 #include "dxc/Support/SPIRVOptions.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/Type.h"
+#include "clang/Sema/Sema.h"
 
 namespace clang {
 namespace spirv {
@@ -268,6 +269,11 @@ bool isOrContainsNonFpColMajorMatrix(const ASTContext &,
 /// matrix type.
 QualType getComponentVectorType(const ASTContext &, QualType matrixType);
 
+/// \brief Returns a QualType corresponding to HLSL matrix of given element type
+/// and rows/columns.
+QualType getHLSLMatrixType(ASTContext &, Sema &, ClassTemplateDecl *,
+                           QualType elemType, int rows, int columns);
+
 } // namespace spirv
 } // namespace clang
 

+ 1 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -41,6 +41,7 @@ enum class Extension {
   AMD_gpu_shader_half_float,
   AMD_shader_explicit_vertex_parameter,
   GOOGLE_hlsl_functionality1,
+  NV_ray_tracing,
   Unknown,
 };
 

+ 6 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -420,6 +420,12 @@ public:
 
   void createLineInfo(SpirvString *file, uint32_t line, uint32_t column);
 
+  /// \brief Creates SPIR-V instructions for NV raytracing ops.
+  SpirvInstruction *
+  createRayTracingOpsNV(spv::Op opcode, QualType resultType,
+                        llvm::ArrayRef<SpirvInstruction *> operands,
+                        SourceLocation loc);
+
   // === SPIR-V Module Structure ===
 
   inline void requireCapability(spv::Capability, SourceLocation loc = {});

+ 37 - 0
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -12,6 +12,7 @@
 #include <array>
 #include <unordered_map>
 
+#include "dxc/DXIL/DxilShaderModel.h"
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvType.h"
@@ -119,6 +120,7 @@ struct FunctionTypeMapInfo {
 /// the SPIR-V entities allocated in memory.
 class SpirvContext {
 public:
+  using ShaderModelKind = hlsl::ShaderModel::Kind;
   SpirvContext();
   ~SpirvContext() = default;
 
@@ -180,6 +182,10 @@ public:
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getACSBufferCounterType();
 
+  const AccelerationStructureTypeNV *getAccelerationStructureTypeNV() const {
+    return accelerationStructureTypeNV;
+  }
+
   /// --- Hybrid type getter functions ---
   ///
   /// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid
@@ -200,6 +206,30 @@ public:
   HybridFunctionType *getFunctionType(QualType ret,
                                       llvm::ArrayRef<QualType> param);
 
+  /// Functions to get/set current entry point ShaderModelKind.
+  ShaderModelKind getCurrentShaderModelKind() { return curShaderModelKind; }
+  void setCurrentShaderModelKind(ShaderModelKind smk) {
+    curShaderModelKind = smk;
+  }
+  /// Functions to get/set hlsl profile version.
+  uint32_t getMajorVersion() const { return majorVersion; }
+  void setMajorVersion(uint32_t major) { majorVersion = major; }
+  uint32_t getMinorVersion() const { return minorVersion; }
+  void setMinorVersion(uint32_t minor) { minorVersion = minor; }
+
+  /// Functions to query current entry point ShaderModelKind.
+  bool isPS() const { return curShaderModelKind == ShaderModelKind::Pixel; }
+  bool isVS() const { return curShaderModelKind == ShaderModelKind::Vertex; }
+  bool isGS() const { return curShaderModelKind == ShaderModelKind::Geometry; }
+  bool isHS() const { return curShaderModelKind == ShaderModelKind::Hull; }
+  bool isDS() const { return curShaderModelKind == ShaderModelKind::Domain; }
+  bool isCS() const { return curShaderModelKind == ShaderModelKind::Compute; }
+  bool isLib() const { return curShaderModelKind == ShaderModelKind::Library; }
+  bool isRay() const {
+    return curShaderModelKind >= ShaderModelKind::RayGeneration &&
+           curShaderModelKind <= ShaderModelKind::Callable;
+  }
+
 private:
   /// \brief The allocator used to create SPIR-V entity objects.
   ///
@@ -242,6 +272,13 @@ private:
   llvm::SmallVector<const StructType *, 8> structTypes;
   llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
   llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
+  const AccelerationStructureTypeNV *accelerationStructureTypeNV;
+
+  // Current ShaderModelKind for entry point.
+  ShaderModelKind curShaderModelKind;
+  // Major/Minor hlsl profile version.
+  uint32_t majorVersion;
+  uint32_t minorVersion;
 };
 
 } // end namespace spirv

+ 24 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -116,6 +116,7 @@ public:
     IK_UnaryOp,                   // Unary operations
     IK_VectorShuffle,             // OpVectorShuffle
     IK_ArrayLength,               // OpArrayLength
+    IK_RayTracingOpNV,            // NV raytracing ops
   };
 
   virtual ~SpirvInstruction() = default;
@@ -1757,6 +1758,29 @@ private:
   uint32_t column;
 };
 
+/// \brief Base class for all NV raytracing instructions.
+/// These include following SPIR-V opcodes:
+/// OpTraceNV, OpReportIntersectionNV, OpIgnoreIntersectionNV, OpTerminateRayNV,
+/// OpExecuteCallableNV
+class SpirvRayTracingOpNV : public SpirvInstruction {
+public:
+  SpirvRayTracingOpNV(QualType resultType, spv::Op opcode,
+                      llvm::ArrayRef<SpirvInstruction *> vecOperands,
+                      SourceLocation loc);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_RayTracingOpNV;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingOpNV)
+
+  llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }
+
+private:
+  llvm::SmallVector<SpirvInstruction *, 4> operands;
+};
+
 #undef DECLARE_INVOKE_VISITOR_FOR_CLASS
 
 } // namespace spirv

+ 12 - 0
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -47,6 +47,7 @@ public:
     TK_Struct,
     TK_Pointer,
     TK_Function,
+    TK_AccelerationStructureNV,
     // Order matters: all the following are hybrid types
     TK_HybridStruct,
     TK_HybridPointer,
@@ -380,6 +381,17 @@ private:
   llvm::SmallVector<const SpirvType *, 8> paramTypes;
 };
 
+/// Represents accleration structure type as defined in SPV_NV_ray_tracing.
+class AccelerationStructureTypeNV : public SpirvType {
+public:
+  AccelerationStructureTypeNV()
+      : SpirvType(TK_AccelerationStructureNV, "accelerationStructureNV") {}
+
+  static bool classof(const SpirvType *t) {
+    return t->getKind() == TK_AccelerationStructureNV;
+  }
+};
+
 class HybridType : public SpirvType {
 public:
   static bool classof(const SpirvType *t) {

+ 1 - 0
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -113,6 +113,7 @@ public:
   DEFINE_VISIT_METHOD(SpirvUnaryOp)
   DEFINE_VISIT_METHOD(SpirvVectorShuffle)
   DEFINE_VISIT_METHOD(SpirvArrayLength)
+  DEFINE_VISIT_METHOD(SpirvRayTracingOpNV)
 
 #undef DEFINE_VISIT_METHOD
 

+ 80 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -11,6 +11,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/AST/HlslTypes.h"
 
 namespace {
@@ -850,6 +851,9 @@ bool isOpaqueType(QualType type) {
 
     if (name == "SamplerState" || name == "SamplerComparisonState")
       return true;
+
+    if (name == "RaytracingAccelerationStructure")
+      return true;
   }
   return false;
 }
@@ -1012,5 +1016,81 @@ QualType getComponentVectorType(const ASTContext &astContext,
   return astContext.getExtVectorType(elemType, colCount);
 }
 
+QualType getHLSLMatrixType(ASTContext &astContext, Sema &S,
+                           ClassTemplateDecl *templateDecl, QualType elemType,
+                           int rows, int columns) {
+  const SourceLocation noLoc;
+  TemplateArgument templateArgs[3] = {
+      TemplateArgument(elemType),
+      TemplateArgument(
+          astContext,
+          llvm::APSInt(
+              llvm::APInt(astContext.getIntWidth(astContext.IntTy), rows),
+              false),
+          astContext.IntTy),
+      TemplateArgument(
+          astContext,
+          llvm::APSInt(
+              llvm::APInt(astContext.getIntWidth(astContext.IntTy), columns),
+              false),
+          astContext.IntTy)};
+
+  SmallVector<TemplateArgument, 4> args;
+  args.push_back(templateArgs[0]);
+  args.push_back(templateArgs[1]);
+  args.push_back(templateArgs[2]);
+
+  DeclContext *currentDeclContext = astContext.getTranslationUnitDecl();
+  SmallVector<TemplateArgument, 3> templateArgsForDecl;
+
+  for (const TemplateArgument &Arg : templateArgs) {
+    if (Arg.getKind() == TemplateArgument::Type) {
+      // the class template need to use CanonicalType
+      templateArgsForDecl.emplace_back(
+          TemplateArgument(Arg.getAsType().getCanonicalType()));
+    } else
+      templateArgsForDecl.emplace_back(Arg);
+  }
+
+  // First, try looking up existing specialization
+  void *insertPos = nullptr;
+  ClassTemplateSpecializationDecl *specializationDecl =
+      templateDecl->findSpecialization(templateArgsForDecl, insertPos);
+
+  if (specializationDecl) {
+    // Instantiate the class template if not done yet.
+    if (specializationDecl->getInstantiatedFrom().isNull()) {
+      S.InstantiateClassTemplateSpecialization(
+          noLoc, specializationDecl,
+          TemplateSpecializationKind::TSK_ImplicitInstantiation, true);
+    }
+    return astContext.getTemplateSpecializationType(
+        TemplateName(templateDecl), args.data(), args.size(),
+        astContext.getTypeDeclType(specializationDecl));
+  }
+
+  specializationDecl = ClassTemplateSpecializationDecl::Create(
+      astContext, TagDecl::TagKind::TTK_Class, currentDeclContext, noLoc, noLoc,
+      templateDecl, templateArgsForDecl.data(), templateArgsForDecl.size(),
+      nullptr);
+  S.InstantiateClassTemplateSpecialization(
+      noLoc, specializationDecl,
+      TemplateSpecializationKind::TSK_ImplicitInstantiation, true);
+  templateDecl->AddSpecialization(specializationDecl, insertPos);
+  specializationDecl->setImplicit(true);
+
+  QualType canonType = astContext.getTypeDeclType(specializationDecl);
+  TemplateArgumentListInfo templateArgumentList(noLoc, noLoc);
+  TemplateArgumentLocInfo noTemplateArgumentLocInfo;
+
+  for (unsigned i = 0; i < args.size(); i++) {
+    templateArgumentList.addArgument(
+        TemplateArgumentLoc(args[i], noTemplateArgumentLocInfo));
+  }
+
+  return astContext.getTemplateSpecializationType(
+      TemplateName(templateDecl), templateArgumentList, canonType);
+}
+
 } // namespace spirv
 } // namespace clang

+ 10 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -459,6 +459,16 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
   case spv::ExecutionModel::TessellationEvaluation:
     spvBuilder.requireCapability(spv::Capability::Tessellation);
     break;
+  case spv::ExecutionModel::RayGenerationNV:
+  case spv::ExecutionModel::IntersectionNV:
+  case spv::ExecutionModel::ClosestHitNV:
+  case spv::ExecutionModel::AnyHitNV:
+  case spv::ExecutionModel::MissNV:
+  case spv::ExecutionModel::CallableNV:
+    spvBuilder.requireCapability(spv::Capability::RayTracingNV);
+    spvBuilder.addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing",
+                            {});
+    break;
   default:
     llvm_unreachable("found unknown shader model");
     break;

+ 84 - 39
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -431,8 +431,8 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   if (hlsl::IsHLSLStreamOutputType(type))
     type = hlsl::GetHLSLResourceResultType(type);
 
-  const auto *sigPoint =
-      deduceSigPoint(decl, /*asInput=*/false, shaderModel.GetKind(), forPCF);
+  const auto *sigPoint = deduceSigPoint(
+      decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
 
   // HS output variables are created using the other overload. For the rest,
   // none of them should be created as arrays.
@@ -446,7 +446,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   // Write back of stage output variables in GS is manually controlled by
   // .Append() intrinsic method, implemented in writeBackOutputStream(). So
   // ignoreValue should be set to true for GS.
-  const bool noWriteBack = storedValue == nullptr || shaderModel.IsGS();
+  const bool noWriteBack = storedValue == nullptr || spvContext.isGS();
 
   return createStageVars(sigPoint, decl, /*asInput=*/false, type,
                          /*arraySize=*/0, "out.var", llvm::None, &storedValue,
@@ -457,7 +457,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                               uint32_t arraySize,
                                               SpirvInstruction *invocationId,
                                               SpirvInstruction *storedValue) {
-  assert(shaderModel.IsHS());
+  assert(spvContext.isHS());
 
   QualType type = getTypeOrFnRetType(decl);
 
@@ -492,8 +492,9 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
     type = typeDecl->getElementType();
   }
 
-  const auto *sigPoint = deduceSigPoint(paramDecl, /*asInput=*/true,
-                                        shaderModel.GetKind(), forPCF);
+  const auto *sigPoint =
+      deduceSigPoint(paramDecl, /*asInput=*/true,
+                     spvContext.getCurrentShaderModelKind(), forPCF);
 
   SemanticInfo inheritSemantic = {};
 
@@ -1245,7 +1246,7 @@ bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
   // likely. In order to avoid location mismatches between HS and DS, use
   // alphabetical ordering.
   if (spirvOptions.stageIoOrder == "alpha" ||
-      (!forInput && shaderModel.IsHS()) || (forInput && shaderModel.IsDS())) {
+      (!forInput && spvContext.isHS()) || (forInput && spvContext.isDS())) {
     // Sort stage input/output variables alphabetically
     std::sort(vars.begin(), vars.end(),
               [](const StageVar *a, const StageVar *b) {
@@ -1508,7 +1509,7 @@ bool DeclResultIdMapper::createStageVars(
   assert(value);
   // invocationId should only be used for handling HS per-vertex output.
   if (invocationId.hasValue()) {
-    assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
+    assert(spvContext.isHS() && arraySize != 0 && !asInput);
   }
 
   assert(inheritSemantic);
@@ -1562,12 +1563,14 @@ bool DeclResultIdMapper::createStageVars(
 
     // Error out when the given semantic is invalid in this shader model
     if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPoint->GetKind(),
-                                          shaderModel.GetMajor(),
-                                          shaderModel.GetMinor()) ==
+                                          spvContext.getMajorVersion(),
+                                          spvContext.getMinorVersion()) ==
         hlsl::DXIL::SemanticInterpretationKind::NA) {
       emitError("invalid usage of semantic '%0' in shader profile %1",
                 decl->getLocation())
-          << semanticToUse->str << shaderModel.GetName();
+          << semanticToUse->str
+          << hlsl::ShaderModel::GetKindName(
+                 spvContext.getCurrentShaderModelKind());
       return false;
     }
 
@@ -1682,7 +1685,7 @@ bool DeclResultIdMapper::createStageVars(
       spvBuilder.decoratePatch(varInstr);
 
     // Decorate with interpolation modes for pixel shader input variables
-    if (shaderModel.IsPS() && sigPoint->IsInput() &&
+    if (spvContext.isPS() && sigPoint->IsInput() &&
         // BaryCoord*AMD buitins already encode the interpolation mode.
         semanticKind != hlsl::Semantic::Kind::Barycentrics)
       decoratePSInterpolationMode(decl, type, varInstr);
@@ -2037,7 +2040,7 @@ bool DeclResultIdMapper::createStageVars(
 bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
                                                QualType type,
                                                SpirvInstruction *value) {
-  assert(shaderModel.IsGS()); // Only for GS use
+  assert(spvContext.isGS()); // Only for GS use
 
   if (hlsl::IsHLSLStreamOutputType(type))
     type = hlsl::GetHLSLResourceResultType(type);
@@ -2135,7 +2138,7 @@ DeclResultIdMapper::invertYIfRequested(SpirvInstruction *position) {
 SpirvInstruction *
 DeclResultIdMapper::invertWIfRequested(SpirvInstruction *position) {
   // Reciprocate SV_Position.w if requested
-  if (spirvOptions.invertW && shaderModel.IsPS()) {
+  if (spirvOptions.invertW && spvContext.isPS()) {
     const auto oldW =
         spvBuilder.createCompositeExtract(astContext.FloatTy, position, {3});
     const auto newW = spvBuilder.createBinaryOp(
@@ -2183,52 +2186,57 @@ void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
 }
 
 SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
+                                                 QualType type,
                                                  SourceLocation loc) {
   // Guarantee uniqueness
+  uint32_t spvBuiltinId = static_cast<uint32_t>(builtIn);
+  const auto builtInVar = builtinToVarMap.find(spvBuiltinId);
+  if (builtInVar != builtinToVarMap.end()) {
+    return builtInVar->second;
+  }
   switch (builtIn) {
   case spv::BuiltIn::SubgroupSize:
-    if (laneCountBuiltinVar)
-      return laneCountBuiltinVar;
-    break;
   case spv::BuiltIn::SubgroupLocalInvocationId:
-    if (laneIndexBuiltinVar)
-      return laneIndexBuiltinVar;
+  case spv::BuiltIn::HitTNV:
+  case spv::BuiltIn::RayTminNV:
+  case spv::BuiltIn::HitKindNV:
+  case spv::BuiltIn::IncomingRayFlagsNV:
+  case spv::BuiltIn::InstanceCustomIndexNV:
+  case spv::BuiltIn::PrimitiveId:
+  case spv::BuiltIn::InstanceId:
+  case spv::BuiltIn::WorldRayDirectionNV:
+  case spv::BuiltIn::WorldRayOriginNV:
+  case spv::BuiltIn::ObjectRayDirectionNV:
+  case spv::BuiltIn::ObjectRayOriginNV:
+  case spv::BuiltIn::ObjectToWorldNV:
+  case spv::BuiltIn::WorldToObjectNV:
+  case spv::BuiltIn::LaunchIdNV:
+  case spv::BuiltIn::LaunchSizeNV:
+    // Valid builtins supported
     break;
   default:
-    // Only allow the two cases we know about
-    assert(false && "unsupported builtin case");
+    assert(false && "unsupported SPIR-V builtin");
     return nullptr;
   }
 
   // Create a dummy StageVar for this builtin variable
-  auto var = spvBuilder.addStageBuiltinVar(
-      astContext.UnsignedIntTy, spv::StorageClass::Input, builtIn, loc);
+  auto var = spvBuilder.addStageBuiltinVar(type, spv::StorageClass::Input,
+                                           builtIn, loc);
 
   const hlsl::SigPoint *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
-          hlsl::DxilParamInputQual::In, shaderModel.GetKind(),
+          hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
           /*isPatchConstant=*/false));
 
-  StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr,
-                    astContext.UnsignedIntTy,
+  StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
                     /*locCount=*/0);
 
   stageVar.setIsSpirvBuiltin();
   stageVar.setSpirvInstr(var);
   stageVars.push_back(stageVar);
 
-  switch (builtIn) {
-  case spv::BuiltIn::SubgroupSize:
-    laneCountBuiltinVar = var;
-    break;
-  case spv::BuiltIn::SubgroupLocalInvocationId:
-    laneIndexBuiltinVar = var;
-    break;
-  default:
-    // Only relevant to subgroup builtins.
-    break;
-  }
-
+  // Store in map for re-use
+  builtinToVarMap[spvBuiltinId] = var;
   return var;
 }
 
@@ -2601,7 +2609,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
 bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) {
   bool success = true;
   if (const auto *idxAttr = decl->getAttr<VKIndexAttr>()) {
-    if (!shaderModel.IsPS()) {
+    if (!spvContext.isPS()) {
       emitError("vk::index only allowed in pixel shader",
                 idxAttr->getLocation());
       success = false;
@@ -2799,5 +2807,42 @@ QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
   return type;
 }
 
+SpirvVariable *
+DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
+                                               const VarDecl *decl) {
+  QualType type = decl->getType();
+  SpirvVariable *retVal = nullptr;
+
+  // Raytracing interface variables are special since they do not participate
+  // in any interface matching and hence do not create StageVar and
+  // track them under StageVars vector
+
+  const auto name = decl->getName();
+
+  switch (sc) {
+  case spv::StorageClass::IncomingRayPayloadNV:
+  case spv::StorageClass::IncomingCallableDataNV:
+  case spv::StorageClass::HitAttributeNV:
+  case spv::StorageClass::RayPayloadNV:
+  case spv::StorageClass::CallableDataNV:
+    retVal = spvBuilder.addModuleVar(type, sc, name.str());
+    break;
+
+  default:
+    assert(false && "Unsupported SPIR-V storage class for raytracing");
+  }
+
+  return retVal;
+}
+
+void DeclResultIdMapper::createRayTracingNVImplicitVar(const VarDecl *varDecl) {
+  APValue *val = varDecl->evaluateValue();
+  assert(val);
+  SpirvInstruction *constVal =
+      spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
+  constVal->setRValue(true);
+  astDecls[varDecl].instr = constVal;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 24 - 16
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -14,7 +14,6 @@
 #include <vector>
 
 #include "dxc/DXIL/DxilSemantic.h"
-#include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilSigPoint.h"
 #include "dxc/Support/SPIRVOptions.h"
 #include "spirv/unified1/spirv.hpp11"
@@ -259,14 +258,14 @@ private:
 /// stage variables per Vulkan's requirements.
 class DeclResultIdMapper {
 public:
-  inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            SpirvContext &spirvContext,
+  inline DeclResultIdMapper(ASTContext &context, SpirvContext &spirvContext,
                             SpirvBuilder &spirvBuilder, SpirvEmitter &emitter,
                             FeatureManager &features,
                             const SpirvCodeGenOptions &spirvOptions);
 
   /// \brief Returns the SPIR-V builtin variable.
-  SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, SourceLocation);
+  SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn, QualType type,
+                               SourceLocation);
 
   /// \brief Creates the stage output variables by parsing the semantics
   /// attached to the given function's parameter or return value and returns
@@ -292,6 +291,10 @@ public:
   bool createStageInputVar(const ParmVarDecl *paramDecl,
                            SpirvInstruction **loadedValue, bool forPCF);
 
+  /// \brief Creates stage variables for raytracing.
+  SpirvVariable *createRayTracingNVStageVar(spv::StorageClass sc,
+                                            const VarDecl *decl);
+
   /// \brief Creates a function-scope paramter in the current function and
   /// returns its instruction.
   SpirvFunctionParameter *createFnParam(const ParmVarDecl *param);
@@ -359,6 +362,11 @@ public:
   /// \brief Sets the entry function.
   void setEntryFunction(SpirvFunction *fn) { entryFunction = fn; }
 
+  /// Raytracing specific functions
+  /// \brief Handle specific implicit declarations present only in raytracing
+  /// stages.
+  void createRayTracingNVImplicitVar(const VarDecl *varDecl);
+
 private:
   /// The struct containing SPIR-V information of a AST Decl.
   struct DeclSpirvInfo {
@@ -630,7 +638,6 @@ private:
   inline bool isInputStorageClass(const StageVar &v);
 
 private:
-  const hlsl::ShaderModel &shaderModel;
   SpirvBuilder &spvBuilder;
   SpirvEmitter &theEmitter;
   const SpirvCodeGenOptions &spirvOptions;
@@ -665,14 +672,13 @@ private:
   /// to the SPIR-V type.
   llvm::DenseMap<const DeclContext *, const SpirvType *> ctBufferPCTypes;
 
-  /// The SPIR-V builtin variables accessed by WaveGetLaneCount() and
-  /// WaveGetLaneIndex().
+  /// The SPIR-V builtin variables accessed by WaveGetLaneCount(),
+  /// WaveGetLaneIndex() and ray tracing builtins.
   ///
-  /// These are the only two cases that SPIR-V builtin variables are accessed
+  /// These are the only few cases where SPIR-V builtin variables are accessed
   /// using HLSL intrinsic function calls. All other builtin variables are
   /// accessed using stage IO variables.
-  SpirvVariable *laneCountBuiltinVar;
-  SpirvVariable *laneIndexBuiltinVar;
+  llvm::DenseMap<uint32_t, SpirvVariable *> builtinToVarMap;
 
   /// Whether the translated SPIR-V binary needs legalization.
   ///
@@ -745,21 +751,23 @@ void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
   builder.createStore(counterVar, srcPair.get(builder, context));
 }
 
-DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
-                                       ASTContext &context,
+DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,
                                        SpirvContext &spirvContext,
                                        SpirvBuilder &spirvBuilder,
                                        SpirvEmitter &emitter,
                                        FeatureManager &features,
                                        const SpirvCodeGenOptions &options)
-    : shaderModel(model), spvBuilder(spirvBuilder), theEmitter(emitter),
-      spirvOptions(options), astContext(context), spvContext(spirvContext),
+    : spvBuilder(spirvBuilder), theEmitter(emitter), spirvOptions(options),
+      astContext(context), spvContext(spirvContext),
       diags(context.getDiagnostics()), entryFunction(nullptr),
-      laneCountBuiltinVar(nullptr), laneIndexBuiltinVar(nullptr),
       needsLegalization(false),
-      glPerVertex(model, context, spirvContext, spirvBuilder) {}
+      glPerVertex(context, spirvContext, spirvBuilder) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {
+  if (spvContext.isRay()) {
+    // No location assignment for any raytracing stage variables
+    return true;
+  }
   // Try both input and output even if input location assignment failed
   return finalizeStageIOLocations(true) & finalizeStageIOLocations(false);
 }

+ 20 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -968,6 +968,20 @@ bool EmitVisitor::visit(SpirvArrayLength *inst) {
   return true;
 }
 
+bool EmitVisitor::visit(SpirvRayTracingOpNV *inst) {
+  initInstruction(inst);
+  if (inst->hasResultType()) {
+    curInst.push_back(inst->getResultTypeId());
+    curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
+  }
+  for (const auto operand : inst->getOperands())
+    curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
+  finalizeInstruction();
+  emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
+                              inst->getDebugName());
+  return true;
+}
+
 // EmitTypeHandler ------
 
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {
@@ -1471,6 +1485,12 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
       curTypeInst.push_back(paramTypeId);
     finalizeTypeInstruction();
   }
+  // Acceleration Structure NV type
+  else if (const auto *accType = dyn_cast<AccelerationStructureTypeNV>(type)) {
+    initTypeInstruction(spv::Op::OpTypeAccelerationStructureNV);
+    curTypeInst.push_back(id);
+    finalizeTypeInstruction();
+  }
   // Hybrid Types
   // Note: The type lowering pass should lower all types to SpirvTypes.
   // Therefore, if we find a hybrid type when going through the emitting pass,

+ 1 - 0
tools/clang/lib/SPIRV/EmitVisitor.h

@@ -252,6 +252,7 @@ public:
   bool visit(SpirvUnaryOp *);
   bool visit(SpirvVectorShuffle *);
   bool visit(SpirvArrayLength *);
+  bool visit(SpirvRayTracingOpNV *);
 
   // Returns the assembled binary built up in this visitor.
   std::vector<uint32_t> takeBinary();

+ 4 - 2
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -113,8 +113,8 @@ Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
             Extension::AMD_shader_explicit_vertex_parameter)
       .Case("SPV_GOOGLE_hlsl_functionality1",
             Extension::GOOGLE_hlsl_functionality1)
-      .Case("SPV_KHR_post_depth_coverage",
-            Extension::KHR_post_depth_coverage)
+      .Case("SPV_KHR_post_depth_coverage", Extension::KHR_post_depth_coverage)
+      .Case("SPV_NV_ray_tracing", Extension::NV_ray_tracing)
       .Default(Extension::Unknown);
 }
 
@@ -146,6 +146,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
     return "SPV_AMD_shader_explicit_vertex_parameter";
   case Extension::GOOGLE_hlsl_functionality1:
     return "SPV_GOOGLE_hlsl_functionality1";
+  case Extension::NV_ray_tracing:
+    return "SPV_NV_ray_tracing";
   default:
     break;
   }

+ 9 - 9
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -60,13 +60,13 @@ inline bool hasGSPrimitiveTypeQualifier(const DeclaratorDecl *decl) {
 }
 } // anonymous namespace
 
-GlPerVertex::GlPerVertex(const hlsl::ShaderModel &sm, ASTContext &context,
-                         SpirvContext &spirvContext, SpirvBuilder &spirvBuilder)
-    : shaderModel(sm), astContext(context), spvContext(spirvContext),
-      spvBuilder(spirvBuilder), inClipVar(nullptr), inCullVar(nullptr),
-      outClipVar(nullptr), outCullVar(nullptr), inArraySize(0), outArraySize(0),
-      inClipArraySize(1), outClipArraySize(1), inCullArraySize(1),
-      outCullArraySize(1), inSemanticStrs(2, ""), outSemanticStrs(2, "") {}
+GlPerVertex::GlPerVertex(ASTContext &context, SpirvContext &spirvContext,
+                         SpirvBuilder &spirvBuilder)
+    : astContext(context), spvContext(spirvContext), spvBuilder(spirvBuilder),
+      inClipVar(nullptr), inCullVar(nullptr), outClipVar(nullptr),
+      outCullVar(nullptr), inArraySize(0), outArraySize(0), inClipArraySize(1),
+      outClipArraySize(1), inCullArraySize(1), outCullArraySize(1),
+      inSemanticStrs(2, ""), outSemanticStrs(2, "") {}
 
 void GlPerVertex::generateVars(uint32_t inArrayLen, uint32_t outArrayLen) {
   inArraySize = inArrayLen;
@@ -220,7 +220,7 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
   }
 
   if (baseType->isConstantArrayType()) {
-    if (shaderModel.IsHS() || shaderModel.IsDS() || shaderModel.IsGS()) {
+    if (spvContext.isHS() || spvContext.isDS() || spvContext.isGS()) {
       // Ignore the outermost arrayness and check the inner type to be
       // (vector of) floats
 
@@ -555,7 +555,7 @@ void GlPerVertex::writeClipCullArrayFromType(
   }
 
   // Writing to an array only happens in HSCPOut.
-  assert(shaderModel.IsHS());
+  assert(spvContext.isHS());
   // And we are only writing to the array element with InvocationId as index.
   assert(invocationId.hasValue());
 

+ 2 - 4
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -11,7 +11,6 @@
 #define LLVM_CLANG_LIB_SPIRV_GLPERVERTEX_H
 
 #include "dxc/DXIL/DxilSemantic.h"
-#include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilSigPoint.h"
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/DenseMap.h"
@@ -42,8 +41,8 @@ namespace spirv {
 /// array for ClipDistance builtin.
 class GlPerVertex {
 public:
-  GlPerVertex(const hlsl::ShaderModel &sm, ASTContext &context,
-              SpirvContext &spvContext, SpirvBuilder &spvBuilder);
+  GlPerVertex(ASTContext &context, SpirvContext &spvContext,
+              SpirvBuilder &spvBuilder);
 
   /// Records a declaration of SV_ClipDistance/SV_CullDistance so later
   /// we can caculate the ClipDistance/CullDistance array layout.
@@ -128,7 +127,6 @@ private:
   using SemanticIndexToTypeMap = llvm::DenseMap<uint32_t, QualType>;
   using SemanticIndexToArrayOffsetMap = llvm::DenseMap<uint32_t, uint32_t>;
 
-  const hlsl::ShaderModel &shaderModel;
   ASTContext &astContext;
   SpirvContext &spvContext;
   SpirvBuilder &spvBuilder;

+ 4 - 0
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -491,6 +491,10 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     return spvContext.getSamplerType();
   }
 
+  if (name == "RaytracingAccelerationStructure") {
+    return spvContext.getAccelerationStructureTypeNV();
+  }
+
   if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
       name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
     // StructureBuffer<S> will be translated into an OpTypeStruct with one

+ 11 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -803,6 +803,17 @@ void SpirvBuilder::createLineInfo(SpirvString *file, uint32_t line,
   insertPoint->addInstruction(inst);
 }
 
+SpirvInstruction *
+SpirvBuilder::createRayTracingOpsNV(spv::Op opcode, QualType resultType,
+                                    ArrayRef<SpirvInstruction *> operands,
+                                    SourceLocation loc) {
+  assert(insertPoint && "null insert point");
+  auto *inst =
+      new (context) SpirvRayTracingOpNV(resultType, opcode, operands, loc);
+  insertPoint->addInstruction(inst);
+  return inst;
+}
+
 void SpirvBuilder::addExtension(Extension ext, llvm::StringRef target,
                                 SourceLocation loc) {
   // TODO: The extension management should be removed from here and added as a

+ 4 - 1
tools/clang/lib/SPIRV/SpirvContext.cpp

@@ -17,10 +17,13 @@ namespace spirv {
 
 SpirvContext::SpirvContext()
     : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
-      uintTypes({}), floatTypes({}), samplerType(nullptr) {
+      uintTypes({}), floatTypes({}), samplerType(nullptr),
+      curShaderModelKind(ShaderModelKind::Invalid), majorVersion(0),
+      minorVersion(0) {
   voidType = new (this) VoidType;
   boolType = new (this) BoolType;
   samplerType = new (this) SamplerType;
+  accelerationStructureTypeNV = new (this) AccelerationStructureTypeNV;
 }
 
 inline uint32_t log2ForBitwidth(uint32_t bitwidth) {

+ 647 - 57
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -17,6 +17,7 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "spirv-tools/optimizer.hpp"
 #include "clang/SPIRV/AstTypeProbe.h"
+#include "clang/Sema/Sema.h"
 #include "llvm/ADT/StringExtras.h"
 
 #include "InitListHandler.h"
@@ -478,27 +479,34 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
     : theCompilerInstance(ci), astContext(ci.getASTContext()),
       diags(ci.getDiagnostics()),
       spirvOptions(ci.getCodeGenOpts().SpirvOptions),
-      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
-      shaderModel(*hlsl::ShaderModel::GetByName(
-          ci.getCodeGenOpts().HLSLProfile.c_str())),
-      spvContext(), featureManager(diags, spirvOptions),
+      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(),
+      featureManager(diags, spirvOptions),
       spvBuilder(astContext, spvContext, &featureManager, spirvOptions),
-      declIdMapper(shaderModel, astContext, spvContext, spvBuilder, *this,
-                   featureManager, spirvOptions),
+      declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
+                   spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
       mainSourceFile(nullptr) {
-  if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
-    emitError("unknown shader module: %0", {}) << shaderModel.GetName();
 
-  if (spirvOptions.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
-      !shaderModel.IsGS())
+  // Get ShaderModel from command line hlsl profile option.
+  const hlsl::ShaderModel *shaderModel =
+      hlsl::ShaderModel::GetByName(ci.getCodeGenOpts().HLSLProfile.c_str());
+  if (shaderModel->GetKind() == hlsl::ShaderModel::Kind::Invalid)
+    emitError("unknown shader module: %0", {}) << shaderModel->GetName();
+
+  if (spirvOptions.invertY && !shaderModel->IsVS() && !shaderModel->IsDS() &&
+      !shaderModel->IsGS())
     emitError("-fvk-invert-y can only be used in VS/DS/GS", {});
 
   if (spirvOptions.useGlLayout && spirvOptions.useDxLayout)
     emitError("cannot specify both -fvk-use-dx-layout and -fvk-use-gl-layout",
               {});
 
+  // Set shader model kind and hlsl major/minor version.
+  spvContext.setCurrentShaderModelKind(shaderModel->GetKind());
+  spvContext.setMajorVersion(shaderModel->GetMajor());
+  spvContext.setMinorVersion(shaderModel->GetMinor());
+
   if (spirvOptions.useDxLayout) {
     spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
     spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
@@ -533,8 +541,9 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
         sm.getBuffer(sm.getMainFileID(), SourceLocation());
     source = StringRef(mainFile->getBufferStart(), mainFile->getBufferSize());
   }
-  mainSourceFile = spvBuilder.setDebugSource(
-      shaderModel.GetMajor(), shaderModel.GetMinor(), fileName, source);
+  mainSourceFile =
+      spvBuilder.setDebugSource(spvContext.getMajorVersion(),
+                                spvContext.getMinorVersion(), fileName, source);
 
   if (spirvOptions.debugInfoTool && spirvOptions.targetEnv == "vulkan1.1") {
     // Emit OpModuleProcessed to indicate the commit information.
@@ -558,12 +567,25 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
     return;
 
   TranslationUnitDecl *tu = context.getTranslationUnitDecl();
+  uint32_t numEntryPoints = 0;
 
   // The entry function is the seed of the queue.
   for (auto *decl : tu->decls()) {
     if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
-      if (funcDecl->getName() == entryFunctionName) {
-        workQueue.insert(funcDecl);
+      if (spvContext.isLib()) {
+        if (const auto *shaderAttr = funcDecl->getAttr<HLSLShaderAttr>()) {
+          // If we are compiling as a library then add everything that has a
+          // ShaderAttr.
+          addFunctionToWorkQueue(getShaderModelKind(shaderAttr->getStage()),
+                                 funcDecl, /*isEntryFunction*/ true);
+          numEntryPoints++;
+        }
+      } else {
+        if (funcDecl->getName() == entryFunctionName) {
+          addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(),
+                                 funcDecl, /*isEntryFunction*/ true);
+          numEntryPoints++;
+        }
       }
     } else {
       doDecl(decl);
@@ -574,7 +596,9 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   // The queue can grow in the meanwhile; so need to keep evaluating
   // workQueue.size().
   for (uint32_t i = 0; i < workQueue.size(); ++i) {
-    doDecl(workQueue[i]);
+    const FunctionInfo *curEntryOrCallee = workQueue[i];
+    spvContext.setCurrentShaderModelKind(curEntryOrCallee->shaderModelKind);
+    doDecl(curEntryOrCallee->funcDecl);
   }
 
   if (context.getDiagnostics().hasErrorOccurred())
@@ -586,8 +610,20 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   spvBuilder.setMemoryModel(spv::AddressingModel::Logical,
                             spv::MemoryModel::GLSL450);
 
-  spvBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunction,
-                           entryFunctionName, declIdMapper.collectStageVars());
+  // Even though the 'workQueue' grows due to the above loop, the first
+  // 'numEntryPoints' entries in the 'workQueue' are the ones with the HLSL
+  // 'shader' attribute, and must therefore be entry functions.
+  assert(numEntryPoints <= workQueue.size());
+
+  for (uint32_t i = 0; i < numEntryPoints; ++i) {
+    // TODO: assign specific StageVars w.r.t. to entry point
+    const FunctionInfo *entryInfo = workQueue[i];
+    assert(entryInfo->isEntryFunction);
+    spvBuilder.addEntryPoint(getSpirvShaderStage(entryInfo->shaderModelKind),
+                             entryInfo->entryFunction,
+                             entryInfo->funcDecl->getName(),
+                             declIdMapper.collectStageVars());
+  }
 
   // Add Location decorations to stage input/output variables.
   if (!declIdMapper.decorateStageIOLocations())
@@ -651,8 +687,13 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
 }
 
 void SpirvEmitter::doDecl(const Decl *decl) {
-  if (decl->isImplicit() || isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl))
+  if (isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl))
+    return;
+
+  if (decl->isImplicit()) {
+    doImplicitDecl(decl);
     return;
+  }
 
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
@@ -953,11 +994,15 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
 
   SpirvFunction *func = declIdMapper.getOrRegisterFn(decl);
 
-  if (funcName == entryFunctionName) {
-    funcName = "src." + funcName;
-    // Create wrapper for the entry function
-    if (!emitEntryFunctionWrapper(decl, func))
-      return;
+  const auto iter = functionInfoMap.find(decl);
+  if (iter != functionInfoMap.end()) {
+    const auto &entryInfo = iter->second;
+    if (entryInfo->isEntryFunction) {
+      funcName = "src." + funcName;
+      // Create wrapper for the entry function
+      if (!emitEntryFunctionWrapper(decl, func))
+        return;
+    }
   }
 
   const QualType retType =
@@ -1043,7 +1088,7 @@ bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
   }
 
   if (decl->getAttr<VKInputAttachmentIndexAttr>()) {
-    if (!shaderModel.IsPS()) {
+    if (!spvContext.isPS()) {
       emitError("SubpassInput(MS) only allowed in pixel shader",
                 decl->getLocation());
       success = false;
@@ -1135,6 +1180,19 @@ void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
   (void)declIdMapper.createCTBuffer(bufferDecl);
 }
 
+void SpirvEmitter::doImplicitDecl(const Decl *decl) {
+  // We only handle specific implicit declaration for raytracing
+  // which are RayFlag/HitKind constant unsigned integers
+  // Ignore others
+  if (spvContext.isLib() || spvContext.isRay()) {
+    const VarDecl *implDecl = dyn_cast<VarDecl>(decl);
+    if (implDecl && (implDecl->getName().startswith(StringRef("RAY_FLAG")) ||
+                     implDecl->getName().startswith(StringRef("HIT_KIND")))) {
+      (void)declIdMapper.createRayTracingNVImplicitVar(implDecl);
+    }
+  }
+}
+
 void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
   // Ignore implict records
   // Somehow we'll have implicit records with:
@@ -2000,9 +2058,8 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   assert(vars.size() == args.size());
 
   // Push the callee into the work queue if it is not there.
-  if (!workQueue.count(callee)) {
-    workQueue.insert(callee);
-  }
+  addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), callee,
+                         /*isEntryFunction*/ false);
 
   const QualType retType =
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee);
@@ -3781,7 +3838,7 @@ SpirvInstruction *SpirvEmitter::createImageSample(
   const bool isExplicit = lod || (grad.first && grad.second);
 
   // Implicit-lod instructions are only allowed in pixel shader.
-  if (!shaderModel.IsPS() && !isExplicit)
+  if (!spvContext.isPS() && !isExplicit)
     needsLegalization = true;
 
   auto *retVal = spvBuilder.createImageSample(
@@ -6360,16 +6417,18 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneCount",
                                     callExpr->getExprLoc());
     const QualType retType = callExpr->getCallReturnType(astContext);
-    auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize,
+    auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize, retType,
                                            callExpr->getExprLoc());
+
     retVal = spvBuilder.createLoad(retType, var);
   } break;
   case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
     featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneIndex",
                                     callExpr->getExprLoc());
     const QualType retType = callExpr->getCallReturnType(astContext);
-    auto *var = declIdMapper.getBuiltinVar(
-        spv::BuiltIn::SubgroupLocalInvocationId, callExpr->getExprLoc());
+    auto *var =
+        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId,
+                                   retType, callExpr->getExprLoc());
     retVal = spvBuilder.createLoad(retType, var);
   } break;
   case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
@@ -6446,6 +6505,49 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
       retVal = processNonFpMatrixTranspose(matType, doExpr(mat));
 
     break;
+  }
+  // DXR raytracing intrinsics
+  case hlsl::IntrinsicOp::IOP_DispatchRaysDimensions:
+  case hlsl::IntrinsicOp::IOP_DispatchRaysIndex:
+  case hlsl::IntrinsicOp::IOP_HitKind:
+  case hlsl::IntrinsicOp::IOP_InstanceIndex:
+  case hlsl::IntrinsicOp::IOP_InstanceID:
+  case hlsl::IntrinsicOp::IOP_ObjectRayDirection:
+  case hlsl::IntrinsicOp::IOP_ObjectRayOrigin:
+  case hlsl::IntrinsicOp::IOP_ObjectToWorld3x4:
+  case hlsl::IntrinsicOp::IOP_ObjectToWorld4x3:
+  case hlsl::IntrinsicOp::IOP_PrimitiveIndex:
+  case hlsl::IntrinsicOp::IOP_RayFlags:
+  case hlsl::IntrinsicOp::IOP_RayTCurrent:
+  case hlsl::IntrinsicOp::IOP_RayTMin:
+  case hlsl::IntrinsicOp::IOP_WorldRayDirection:
+  case hlsl::IntrinsicOp::IOP_WorldRayOrigin:
+  case hlsl::IntrinsicOp::IOP_WorldToObject3x4:
+  case hlsl::IntrinsicOp::IOP_WorldToObject4x3: {
+    retVal = processRayBuiltins(callExpr, hlslOpcode);
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch: {
+    spvBuilder.createRayTracingOpsNV(spv::Op::OpTerminateRayNV, QualType(), {},
+                                     callExpr->getExprLoc());
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_IgnoreHit: {
+    spvBuilder.createRayTracingOpsNV(spv::Op::OpIgnoreIntersectionNV,
+                                     QualType(), {}, callExpr->getExprLoc());
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_ReportHit: {
+    retVal = processReportHit(callExpr);
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_TraceRay: {
+    processTraceRay(callExpr);
+    break;
+  }
+  case hlsl::IntrinsicOp::IOP_CallShader: {
+    processCallShader(callExpr);
+    break;
   }
     INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
     INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
@@ -8416,7 +8518,7 @@ SpirvEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
 SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
     const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
   // Certain opcodes are only allowed in pixel shader
-  if (!shaderModel.IsPS())
+  if (!spvContext.isPS())
     switch (opcode) {
     case spv::Op::OpDPdx:
     case spv::Op::OpDPdy:
@@ -8564,6 +8666,309 @@ SpirvEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
   return spvBuilder.createBinaryOp(scaleOp, returnType, log2, scale);
 }
 
+SpirvInstruction *SpirvEmitter::processRayBuiltins(const CallExpr *callExpr,
+                                                   hlsl::IntrinsicOp op) {
+  spv::BuiltIn builtin = spv::BuiltIn::Max;
+  bool transposeMatrix = false;
+  switch (op) {
+  case hlsl::IntrinsicOp::IOP_DispatchRaysDimensions:
+    builtin = spv::BuiltIn::LaunchSizeNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_DispatchRaysIndex:
+    builtin = spv::BuiltIn::LaunchIdNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_RayTCurrent:
+    builtin = spv::BuiltIn::HitTNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_RayTMin:
+    builtin = spv::BuiltIn::RayTminNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_HitKind:
+    builtin = spv::BuiltIn::HitKindNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_WorldRayDirection:
+    builtin = spv::BuiltIn::WorldRayDirectionNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_WorldRayOrigin:
+    builtin = spv::BuiltIn::WorldRayOriginNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_ObjectRayDirection:
+    builtin = spv::BuiltIn::ObjectRayDirectionNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_ObjectRayOrigin:
+    builtin = spv::BuiltIn::ObjectRayOriginNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_InstanceIndex:
+    builtin = spv::BuiltIn::InstanceId;
+    break;
+  case hlsl::IntrinsicOp::IOP_PrimitiveIndex:
+    builtin = spv::BuiltIn::PrimitiveId;
+    break;
+  case hlsl::IntrinsicOp::IOP_InstanceID:
+    builtin = spv::BuiltIn::InstanceCustomIndexNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_RayFlags:
+    builtin = spv::BuiltIn::IncomingRayFlagsNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_ObjectToWorld3x4:
+    transposeMatrix = true;
+  case hlsl::IntrinsicOp::IOP_ObjectToWorld4x3:
+    builtin = spv::BuiltIn::ObjectToWorldNV;
+    break;
+  case hlsl::IntrinsicOp::IOP_WorldToObject3x4:
+    transposeMatrix = true;
+  case hlsl::IntrinsicOp::IOP_WorldToObject4x3:
+    builtin = spv::BuiltIn::WorldToObjectNV;
+    break;
+  default:
+    emitError("ray intrinsic function unimplemented", callExpr->getExprLoc());
+    return nullptr;
+  }
+
+  QualType builtinType = callExpr->getType();
+  if (transposeMatrix) {
+    // DXR defines ObjectToWorld3x4, WorldToObject3x4 as transposed matrices.
+    // SPIR-V has only non tranposed variant defined as a builtin
+    // So perform read of original non transposed builtin and perform transpose.
+    assert(hlsl::IsHLSLMatType(builtinType) && "Builtin should be matrix");
+    const clang::Type *type = builtinType.getCanonicalType().getTypePtr();
+    const RecordType *RT = cast<RecordType>(type);
+    const ClassTemplateSpecializationDecl *templateSpecDecl =
+        cast<ClassTemplateSpecializationDecl>(RT->getDecl());
+    ClassTemplateDecl *templateDecl =
+        templateSpecDecl->getSpecializedTemplate();
+    builtinType = getHLSLMatrixType(astContext, theCompilerInstance.getSema(),
+                                    templateDecl, astContext.FloatTy, 4, 3);
+  }
+  SpirvInstruction *retVal =
+      declIdMapper.getBuiltinVar(builtin, builtinType, callExpr->getExprLoc());
+  retVal = spvBuilder.createLoad(builtinType, retVal);
+  if (transposeMatrix)
+    retVal = spvBuilder.createUnaryOp(spv::Op::OpTranspose, callExpr->getType(),
+                                      retVal);
+  return retVal;
+}
+
+SpirvInstruction *SpirvEmitter::processReportHit(const CallExpr *callExpr) {
+  SpirvInstruction *hitAttributeStageVar = nullptr;
+  const VarDecl *hitAttributeArg = nullptr;
+  QualType hitAttributeType;
+  const auto args = callExpr->getArgs();
+
+  if (callExpr->getNumArgs() != 3) {
+    emitError("invalid number of arguments to ReportHit",
+              callExpr->getExprLoc());
+  }
+
+  // HLSL Function :
+  // template<typename hitAttr>
+  // ReportHit(in float, in uint, in hitAttr)
+  if (const auto *implCastExpr = dyn_cast<CastExpr>(callExpr->getArg(2))) {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
+      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+        hitAttributeType = varDecl->getType();
+        hitAttributeArg = varDecl;
+        // Check if same type of hit attribute stage variable was already
+        // created, if so re-use
+        const auto iter = hitAttributeMap.find(hitAttributeType);
+        if (iter == hitAttributeMap.end()) {
+          hitAttributeStageVar = declIdMapper.createRayTracingNVStageVar(
+              spv::StorageClass::HitAttributeNV, varDecl);
+          hitAttributeMap[hitAttributeType] = hitAttributeStageVar;
+        } else {
+          hitAttributeStageVar = iter->second;
+        }
+      }
+    }
+  }
+
+  assert(hitAttributeStageVar && hitAttributeArg);
+
+  // Copy argument to stage variable
+  const auto hitAttributeArgInst =
+      declIdMapper.getDeclEvalInfo(hitAttributeArg);
+  auto tempLoad =
+      spvBuilder.createLoad(hitAttributeArg->getType(), hitAttributeArgInst);
+  spvBuilder.createStore(hitAttributeStageVar, tempLoad);
+
+  // SPIR-V Instruction :
+  // bool OpReportIntersection(<id> float Hit, <id> uint HitKind)
+  llvm::SmallVector<SpirvInstruction *, 4> reportHitArgs;
+  reportHitArgs.push_back(doExpr(args[0])); // Hit
+  reportHitArgs.push_back(doExpr(args[1])); // HitKind
+  return spvBuilder.createRayTracingOpsNV(spv::Op::OpReportIntersectionNV,
+                                          astContext.BoolTy, reportHitArgs,
+                                          callExpr->getExprLoc());
+}
+void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
+  SpirvInstruction *callDataLocInst = nullptr;
+  SpirvInstruction *callDataStageVar = nullptr;
+  const VarDecl *callDataArg = nullptr;
+  QualType callDataType;
+  const auto args = callExpr->getArgs();
+
+  if (callExpr->getNumArgs() != 2) {
+    emitError("invalid number of arguments to CallShader",
+              callExpr->getExprLoc());
+  }
+
+  // HLSL Func :
+  // template<typename CallData>
+  // void CallShader(in int sbtIndex, inout CallData arg)
+  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[1])) {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
+      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+        callDataType = varDecl->getType();
+        callDataArg = varDecl;
+        // Check if same type of callable data stage variable was already
+        // created, if so re-use
+        const auto callDataPair = callDataMap.find(callDataType);
+        if (callDataPair == callDataMap.end()) {
+          int numCallDataVars = callDataMap.size();
+          callDataStageVar = declIdMapper.createRayTracingNVStageVar(
+              spv::StorageClass::CallableDataNV, varDecl);
+          // Decorate unique location id for each created stage var
+          spvBuilder.decorateLocation(callDataStageVar, numCallDataVars);
+          callDataLocInst = spvBuilder.getConstantInt(
+              astContext.UnsignedIntTy, llvm::APInt(32, numCallDataVars));
+          callDataMap[callDataType] =
+              std::make_pair(callDataStageVar, callDataLocInst);
+        } else {
+          callDataStageVar = callDataPair->second.first;
+          callDataLocInst = callDataPair->second.second;
+        }
+      }
+    }
+  }
+
+  assert(callDataStageVar && callDataArg);
+
+  // Copy argument to stage variable
+  const auto callDataArgInst = declIdMapper.getDeclEvalInfo(callDataArg);
+  auto tempLoad =
+      spvBuilder.createLoad(callDataArg->getType(), callDataArgInst);
+  spvBuilder.createStore(callDataStageVar, tempLoad);
+
+  // SPIR-V Instruction
+  // void OpExecuteCallable(<id> int SBT Index, <id> uint Callable Data Location
+  // Id)
+  llvm::SmallVector<SpirvInstruction *, 2> callShaderArgs;
+  callShaderArgs.push_back(doExpr(args[0]));
+  callShaderArgs.push_back(callDataLocInst);
+
+  spvBuilder.createRayTracingOpsNV(spv::Op::OpExecuteCallableNV, QualType(),
+                                   callShaderArgs, callExpr->getExprLoc());
+
+  // Copy data back to argument
+  tempLoad = spvBuilder.createLoad(callDataArg->getType(), callDataStageVar);
+  spvBuilder.createStore(callDataArgInst, tempLoad);
+  return;
+}
+void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
+  SpirvInstruction *payloadLocInst = nullptr;
+  SpirvInstruction *payloadStageVar = nullptr;
+  const VarDecl *payloadArg = nullptr;
+  QualType payloadType;
+
+  const auto args = callExpr->getArgs();
+
+  if (callExpr->getNumArgs() != 8) {
+    emitError("invalid number of arguments to TraceRay",
+              callExpr->getExprLoc());
+  }
+
+  // HLSL Func
+  // template<typename Payload>
+  // void TraceRay(RaytracingAccelerationStructure rs,
+  //              uint rayflags,
+  //              uint InstanceInclusionMask
+  //              uint RayContributionToHitGroupIndex,
+  //              uint MultiplierForGeometryContributionToHitGroupIndex,
+  //              uint MissShaderIndex,
+  //              RayDesc ray,
+  //              inout Payload p)
+  // where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
+
+  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
+    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
+      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
+        payloadType = varDecl->getType();
+        payloadArg = varDecl;
+        const auto payloadPair = payloadMap.find(payloadType);
+        // Check if same type of payload stage variable was already
+        // created, if so re-use
+        if (payloadPair == payloadMap.end()) {
+          int numPayloadVars = payloadMap.size();
+          payloadStageVar = declIdMapper.createRayTracingNVStageVar(
+              spv::StorageClass::RayPayloadNV, varDecl);
+          // Decorate unique location id for each created stage var
+          spvBuilder.decorateLocation(payloadStageVar, numPayloadVars);
+          payloadLocInst = spvBuilder.getConstantInt(
+              astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
+          payloadMap[payloadType] =
+              std::make_pair(payloadStageVar, payloadLocInst);
+        } else {
+          payloadStageVar = payloadPair->second.first;
+          payloadLocInst = payloadPair->second.second;
+        }
+      }
+    }
+  }
+
+  assert(payloadStageVar && payloadArg);
+
+  const auto floatType = astContext.FloatTy;
+  const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
+
+  // Extract the ray description to match SPIR-V
+  SpirvInstruction *rayDescArg = doExpr(args[6]);
+  const auto origin =
+      spvBuilder.createCompositeExtract(vecType, rayDescArg, {0});
+  const auto tMin =
+      spvBuilder.createCompositeExtract(floatType, rayDescArg, {1});
+  const auto direction =
+      spvBuilder.createCompositeExtract(vecType, rayDescArg, {2});
+  const auto tMax =
+      spvBuilder.createCompositeExtract(floatType, rayDescArg, {3});
+
+  // Copy argument to stage variable
+  const auto payloadArgInst = declIdMapper.getDeclEvalInfo(payloadArg);
+  auto tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadArgInst);
+  spvBuilder.createStore(payloadStageVar, tempLoad);
+
+  // SPIR-V Instruction
+  // void OpTraceNV ( <id> AccelerationStructureNV acStruct,
+  //                 <id> uint Ray Flags,
+  //                 <id> uint Cull Mask,
+  //                 <id> uint SBT Offset,
+  //                 <id> uint SBT Stride,
+  //                 <id> uint Miss Index,
+  //                 <id> vec4 Ray Origin,
+  //                 <id> float Ray Tmin,
+  //                 <id> vec3 Ray Direction,
+  //                 <id> float Ray Tmax,
+  //                 <id> uint Payload number)
+
+  llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
+  for (int ii = 0; ii < 6; ii++) {
+    traceArgs.push_back(doExpr(args[ii]));
+  }
+
+  traceArgs.push_back(origin);
+  traceArgs.push_back(tMin);
+  traceArgs.push_back(direction);
+  traceArgs.push_back(tMax);
+  traceArgs.push_back(payloadLocInst);
+
+  spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
+                                   callExpr->getExprLoc());
+
+  // Copy arguments back to stage variable
+  tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadStageVar);
+  spvBuilder.createStore(payloadArgInst, tempLoad);
+  return;
+}
+
 SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
   {
     QualType scalarType = {};
@@ -8817,17 +9222,65 @@ SpirvConstant *SpirvEmitter::tryToEvaluateAsConst(const Expr *expr) {
   return nullptr;
 }
 
+hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
+  hlsl::ShaderModel::Kind smk;
+  switch (stageName[0]) {
+  case 'c':
+    switch (stageName[1]) {
+    case 'o':
+      smk = hlsl::ShaderModel::Kind::Compute;
+      break;
+    case 'l':
+      smk = hlsl::ShaderModel::Kind::ClosestHit;
+      break;
+    case 'a':
+      smk = hlsl::ShaderModel::Kind::Callable;
+      break;
+    default:
+      smk = hlsl::ShaderModel::Kind::Invalid;
+      break;
+    }
+    break;
+  case 'v':
+    smk = hlsl::ShaderModel::Kind::Vertex;
+    break;
+  case 'h':
+    smk = hlsl::ShaderModel::Kind::Hull;
+    break;
+  case 'd':
+    smk = hlsl::ShaderModel::Kind::Domain;
+    break;
+  case 'g':
+    smk = hlsl::ShaderModel::Kind::Geometry;
+    break;
+  case 'p':
+    smk = hlsl::ShaderModel::Kind::Pixel;
+    break;
+  case 'r':
+    smk = hlsl::ShaderModel::Kind::RayGeneration;
+    break;
+  case 'i':
+    smk = hlsl::ShaderModel::Kind::Intersection;
+    break;
+  case 'a':
+    smk = hlsl::ShaderModel::Kind::AnyHit;
+    break;
+  case 'm':
+    smk = hlsl::ShaderModel::Kind::Miss;
+    break;
+  default:
+    smk = hlsl::ShaderModel::Kind::Invalid;
+    break;
+  }
+  if (smk == hlsl::ShaderModel::Kind::Invalid) {
+    llvm_unreachable("unknown stage name");
+  }
+  return smk;
+}
+
 spv::ExecutionModel
-SpirvEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
-  // DXIL Models are:
-  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
-  // vs_<version>         : Vertex Shader    : Vertex Shader
-  // hs_<version>         : Hull Shader      : Tassellation Control Shader
-  // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
-  // gs_<version>         : Geometry Shader  : Geometry Shader
-  // ps_<version>         : Pixel Shader     : Fragment Shader
-  // cs_<version>         : Compute Shader   : Compute Shader
-  switch (model.GetKind()) {
+SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
+  switch (smk) {
   case hlsl::ShaderModel::Kind::Vertex:
     return spv::ExecutionModel::Vertex;
   case hlsl::ShaderModel::Kind::Hull:
@@ -8840,16 +9293,28 @@ SpirvEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
     return spv::ExecutionModel::Fragment;
   case hlsl::ShaderModel::Kind::Compute:
     return spv::ExecutionModel::GLCompute;
+  case hlsl::ShaderModel::Kind::RayGeneration:
+    return spv::ExecutionModel::RayGenerationNV;
+  case hlsl::ShaderModel::Kind::Intersection:
+    return spv::ExecutionModel::IntersectionNV;
+  case hlsl::ShaderModel::Kind::AnyHit:
+    return spv::ExecutionModel::AnyHitNV;
+  case hlsl::ShaderModel::Kind::ClosestHit:
+    return spv::ExecutionModel::ClosestHitNV;
+  case hlsl::ShaderModel::Kind::Miss:
+    return spv::ExecutionModel::MissNV;
+  case hlsl::ShaderModel::Kind::Callable:
+    return spv::ExecutionModel::CallableNV;
   default:
+    llvm_unreachable("invalid shader model kind");
     break;
   }
-  llvm_unreachable("unknown shader model");
 }
 
 bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
                                                    uint32_t *arraySize) {
   bool success = true;
-  assert(shaderModel.IsGS());
+  assert(spvContext.isGS());
   if (auto *vcAttr = decl->getAttr<HLSLMaxVertexCountAttr>()) {
     spvBuilder.addExecutionMode(
         entryFunction, spv::ExecutionMode::OutputVertices,
@@ -8975,7 +9440,7 @@ void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
 
 bool SpirvEmitter::processTessellationShaderAttributes(
     const FunctionDecl *decl, uint32_t *numOutputControlPoints) {
-  assert(shaderModel.IsHS() || shaderModel.IsDS());
+  assert(spvContext.isHS() || spvContext.isDS());
   using namespace spv;
 
   if (auto *domain = decl->getAttr<HLSLDomainAttr>()) {
@@ -8997,7 +9462,7 @@ bool SpirvEmitter::processTessellationShaderAttributes(
 
   // Early return for domain shaders as domain shaders only takes the 'domain'
   // attribute.
-  if (shaderModel.IsDS())
+  if (spvContext.isDS())
     return true;
 
   if (auto *partitioning = decl->getAttr<HLSLPartitioningAttr>()) {
@@ -9061,6 +9526,105 @@ bool SpirvEmitter::processTessellationShaderAttributes(
   return true;
 }
 
+bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
+    const FunctionDecl *decl, SpirvFunction *entryFuncInstr) {
+  // The entry basic block.
+  auto *entryLabel = spvBuilder.createBasicBlock();
+  spvBuilder.setInsertPoint(entryLabel);
+
+  // Initialize all global variables at the beginning of the wrapper
+  for (const VarDecl *varDecl : toInitGloalVars) {
+    const auto varInfo = declIdMapper.getDeclEvalInfo(varDecl);
+    if (const auto *init = varDecl->getInit()) {
+      storeValue(varInfo, doExpr(init), varDecl->getType());
+
+      // Update counter variable associated with global variables
+      tryToAssignCounterVar(varDecl, init);
+    }
+    // If not explicitly initialized, initialize with their zero values if not
+    // resource objects
+    else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
+      auto *nullValue = spvBuilder.getConstantNull(varDecl->getType());
+      spvBuilder.createStore(varInfo, nullValue);
+    }
+  }
+
+  // Create temporary variables for holding function call arguments
+  llvm::SmallVector<SpirvInstruction *, 4> params;
+  llvm::SmallVector<QualType, 4> paramTypes;
+  llvm::SmallVector<SpirvInstruction *, 4> stageVars;
+  hlsl::ShaderModel::Kind sKind = spvContext.getCurrentShaderModelKind();
+  for (uint32_t i = 0; i < decl->getNumParams(); i++) {
+    const auto param = decl->getParamDecl(i);
+    const auto paramType = param->getType();
+    std::string tempVarName = "param.var." + param->getNameAsString();
+    auto *tempVar =
+        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName);
+
+    SpirvVariable *curStageVar = nullptr;
+
+    params.push_back(tempVar);
+    paramTypes.push_back(paramType);
+
+    // Order of arguments is fixed
+    // Any-Hit/Closest-Hit : Arg 0 = payload(inout), Arg1 = attribute(in)
+    // Miss : Arg 0 = payload(inout)
+    // Callable : Arg 0 = callable data(inout)
+    // Raygeneration/Intersection : No Args allowed
+    if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
+      assert("Raygeneration shaders have no arguments of entry function");
+    } else if (sKind == hlsl::ShaderModel::Kind::Intersection) {
+      assert("Intersection shaders have no arguments of entry function");
+    } else if (sKind == hlsl::ShaderModel::Kind::ClosestHit ||
+               sKind == hlsl::ShaderModel::Kind::AnyHit) {
+      // Generate rayPayloadInNV and hitAttributeNV stage variables
+      if (i == 0) {
+        // First argument is always payload
+        curStageVar = declIdMapper.createRayTracingNVStageVar(
+            spv::StorageClass::IncomingRayPayloadNV, param);
+      } else {
+        // Second argument is always attribute
+        curStageVar = declIdMapper.createRayTracingNVStageVar(
+            spv::StorageClass::HitAttributeNV, param);
+      }
+    } else if (sKind == hlsl::ShaderModel::Kind::Miss) {
+      // Generate rayPayloadInNV stage variable
+      // First and only argument is payload
+      curStageVar = declIdMapper.createRayTracingNVStageVar(
+          spv::StorageClass::IncomingRayPayloadNV, param);
+    } else if (sKind == hlsl::ShaderModel::Kind::Callable) {
+      curStageVar = declIdMapper.createRayTracingNVStageVar(
+          spv::StorageClass::IncomingCallableDataNV, param);
+    }
+
+    if (curStageVar != nullptr) {
+      stageVars.push_back(curStageVar);
+      // Copy data to temporary
+      auto *tempLoadInst = spvBuilder.createLoad(paramType, curStageVar);
+      spvBuilder.createStore(tempVar, tempLoadInst);
+    }
+  }
+
+  // Call the original entry function
+  const QualType retType = decl->getReturnType();
+  spvBuilder.createFunctionCall(retType, entryFuncInstr, params);
+
+  // Write certain output variables back
+  if (sKind == hlsl::ShaderModel::Kind::ClosestHit ||
+      sKind == hlsl::ShaderModel::Kind::AnyHit ||
+      sKind == hlsl::ShaderModel::Kind::Miss ||
+      sKind == hlsl::ShaderModel::Kind::Callable) {
+    // Write back results to IncomingRayPayloadNV/IncomingCallableDataNV
+    auto *tempLoad = spvBuilder.createLoad(paramTypes[0], params[0]);
+    spvBuilder.createStore(stageVars[0], tempLoad);
+  }
+
+  spvBuilder.createReturn();
+  spvBuilder.endFunction();
+
+  return true;
+}
+
 bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                             SpirvFunction *entryFuncInstr) {
   // HS specific attributes
@@ -9089,12 +9653,22 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // Note this should happen before using declIdMapper for other tasks.
   declIdMapper.setEntryFunction(entryFunction);
 
+  // Set entryFunction for current entry point.
+  auto iter = functionInfoMap.find(decl);
+  assert(iter != functionInfoMap.end());
+  auto &entryInfo = iter->second;
+  assert(entryInfo->isEntryFunction);
+  entryInfo->entryFunction = entryFunction;
+
+  if (spvContext.isRay()) {
+    return emitEntryFunctionWrapperForRayTracing(decl, entryFuncInstr);
+  }
   // Handle attributes specific to each shader stage
-  if (shaderModel.IsPS()) {
+  if (spvContext.isPS()) {
     processPixelShaderAttributes(decl);
-  } else if (shaderModel.IsCS()) {
+  } else if (spvContext.isCS()) {
     processComputeShaderAttributes(decl);
-  } else if (shaderModel.IsHS()) {
+  } else if (spvContext.isHS()) {
     if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
       return false;
 
@@ -9106,7 +9680,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
       }
 
     outputArraySize = numOutputControlPoints;
-  } else if (shaderModel.IsDS()) {
+  } else if (spvContext.isDS()) {
     if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
       return false;
 
@@ -9117,7 +9691,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
         break;
       }
     // The per-vertex output of DS is not an array.
-  } else if (shaderModel.IsGS()) {
+  } else if (spvContext.isGS()) {
     if (!processGeometryShaderAttributes(decl, &inputArraySize))
       return false;
     // The per-vertex output of GS is not an array.
@@ -9149,7 +9723,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // offset of SV_ClipDistance/SV_CullDistance variables within the array.
   declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();
 
-  if (!shaderModel.IsCS()) {
+  if (!spvContext.isCS()) {
     // Generate stand-alone builtins of Position, ClipDistance, and
     // CullDistance, which belongs to gl_PerVertex.
     declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
@@ -9191,7 +9765,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     // Also do not create input variables for output stream objects of geometry
     // shaders (e.g. TriangleStream) which are required to be marked as 'inout'.
     if (canActAsInParmVar(param)) {
-      if (shaderModel.IsHS() && hlsl::IsHLSLInputPatchType(paramType)) {
+      if (spvContext.isHS() && hlsl::IsHLSLInputPatchType(paramType)) {
         // Record the temporary variable holding InputPatch. It may be used
         // later in the patch constant function.
         hullMainInputPatchParam = tempVar;
@@ -9229,7 +9803,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   //    to the proper offset in the array.
   // 2- The patch constant function must be called *once* after all invocations
   //    of the main entry point function is done.
-  if (shaderModel.IsHS()) {
+  if (spvContext.isHS()) {
     // Create stage output variables out of the return type.
     if (!declIdMapper.createStageOutputVar(decl, numOutputControlPoints,
                                            outputControlPointIdVal, retVal))
@@ -9258,7 +9832,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
       // Write back of stage output variables in GS is manually controlled by
       // .Append() intrinsic method. No need to load the parameter since we
       // won't need to write back here.
-      if (param->isUsed() && !shaderModel.IsGS())
+      if (param->isUsed() && !spvContext.isGS())
         loadedParam = spvBuilder.createLoad(param->getType(), params[i]);
 
       if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
@@ -9271,7 +9845,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
 
   // For Hull shaders, there is no explicit call to the PCF in the HLSL source.
   // We should invoke a translation of the PCF manually.
-  if (shaderModel.IsHS())
+  if (spvContext.isHS())
     doDecl(patchConstFunc);
 
   return true;
@@ -9283,7 +9857,7 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
     SpirvInstruction *outputControlPointId, SpirvInstruction *primitiveId,
     SpirvInstruction *viewId, SpirvInstruction *hullMainInputPatch) {
   // This method may only be called for Hull shaders.
-  assert(shaderModel.IsHS());
+  assert(spvContext.isHS());
 
   // For Hull shaders, the real output is an array of size
   // numOutputControlPoints. The results of the main should be written to the
@@ -9687,5 +10261,21 @@ void SpirvEmitter::emitDebugLine(SourceLocation loc) {
   }
 }
 
+void SpirvEmitter::addFunctionToWorkQueue(hlsl::DXIL::ShaderKind shaderKind,
+                                          const clang::FunctionDecl *fnDecl,
+                                          bool isEntryFunction) {
+  // Only update the workQueue and the function info map if the given
+  // FunctionDecl hasn't been added already.
+  if (functionInfoMap.find(fnDecl) == functionInfoMap.end()) {
+    // Note: The function is just discovered and is being added to the
+    // workQueue, therefore it does not have the entryFunction SPIR-V
+    // instruction yet (use nullptr).
+    auto *fnInfo = new (spvContext) FunctionInfo(
+        shaderKind, fnDecl, /*entryFunction*/ nullptr, isEntryFunction);
+    functionInfoMap[fnDecl] = fnInfo;
+    workQueue.push_back(fnInfo);
+  }
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 65 - 11
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -25,11 +25,12 @@
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/FeatureManager.h"
-#include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvBuilder.h"
+#include "clang/SPIRV/SpirvContext.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 
@@ -77,6 +78,7 @@ private:
   void doVarDecl(const VarDecl *decl);
   void doRecordDecl(const RecordDecl *decl);
   void doHLSLBufferDecl(const HLSLBufferDecl *decl);
+  void doImplicitDecl(const Decl *decl);
 
   void doBreakStmt(const BreakStmt *stmt);
   void doDiscardStmt(const DiscardStmt *stmt);
@@ -503,6 +505,14 @@ private:
   /// Processes the NonUniformResourceIndex intrinsic function.
   SpirvInstruction *processIntrinsicNonUniformResourceIndex(const CallExpr *);
 
+  /// Process builtins specific to raytracing.
+  SpirvInstruction *processRayBuiltins(const CallExpr *, hlsl::IntrinsicOp op);
+
+  /// Process raytracing intrinsics.
+  SpirvInstruction *processReportHit(const CallExpr *);
+  void processCallShader(const CallExpr *callExpr);
+  void processTraceRay(const CallExpr *callExpr);
+
 private:
   /// Returns the <result-id> for constant value 0 of the given type.
   SpirvConstant *getValueZero(QualType type);
@@ -582,8 +592,8 @@ private:
   /// Emits an error if the given attribute is not a loop attribute.
   spv::LoopControlMask translateLoopAttribute(const Stmt *, const Attr &);
 
-  static spv::ExecutionModel
-  getSpirvShaderStage(const hlsl::ShaderModel &model);
+  static hlsl::ShaderModel::Kind getShaderModelKind(StringRef stageName);
+  static spv::ExecutionModel getSpirvShaderStage(hlsl::ShaderModel::Kind smk);
 
   /// \brief Adds necessary execution modes for the hull/domain shaders based on
   /// the HLSL attributes of the entry point function.
@@ -621,6 +631,15 @@ private:
   bool emitEntryFunctionWrapper(const FunctionDecl *entryFunction,
                                 SpirvFunction *entryFuncId);
 
+  /// \brief Emits a wrapper function for the entry functions for raytracing
+  /// stages and returns true on success.
+  ///
+  /// Wrapper is specific to raytracing stages since for specific stages we
+  /// create specific module scoped stage variables and perform copies to them.
+  /// The wrapper function is also responsible for initializing global static
+  /// variables for some cases.
+  bool emitEntryFunctionWrapperForRayTracing(const FunctionDecl *entryFunction,
+                                             SpirvFunction *entryFuncId);
   /// \brief Performs the following operations for the Hull shader:
   /// * Creates an output variable which is an Array containing results for all
   /// control points.
@@ -913,6 +932,14 @@ private:
   /// \brief Emit an OpLine instruction for the given source location.
   void emitDebugLine(SourceLocation);
 
+private:
+  /// \brief If the given FunctionDecl is not already in the workQueue, creates
+  /// a FunctionInfo object for it, and inserts it into the workQueue. It also
+  /// updates the functionInfoMap with the proper mapping.
+  void addFunctionToWorkQueue(hlsl::DXIL::ShaderKind,
+                              const clang::FunctionDecl *,
+                              bool isEntryFunction);
+
 private:
   /// \brief Wrapper method to create a fatal error message and report it
   /// in the diagnostic engine associated with this consumer.
@@ -958,21 +985,37 @@ private:
 
   SpirvCodeGenOptions &spirvOptions;
 
-  /// Entry function name and shader stage. Both of them are derived from the
-  /// command line and should be const.
+  /// \brief Entry function name, derived from the command line
+  /// and should be const.
   const llvm::StringRef entryFunctionName;
-  const hlsl::ShaderModel &shaderModel;
+
+  /// \brief Structure to maintain record of all entry functions and any
+  /// reachable functions.
+  struct FunctionInfo {
+  public:
+    hlsl::ShaderModel::Kind shaderModelKind;
+    const DeclaratorDecl *funcDecl;
+    SpirvFunction *entryFunction;
+    bool isEntryFunction;
+
+    FunctionInfo() = default;
+    FunctionInfo(hlsl::ShaderModel::Kind smk, const DeclaratorDecl *fDecl,
+                 SpirvFunction *entryFunc, bool isEntryFunc)
+        : shaderModelKind(smk), funcDecl(fDecl), entryFunction(entryFunc),
+          isEntryFunction(isEntryFunc) {}
+  };
 
   SpirvContext spvContext;
   FeatureManager featureManager;
   SpirvBuilder spvBuilder;
   DeclResultIdMapper declIdMapper;
 
-  /// A queue of decls reachable from the entry function. Decls inserted into
-  /// this queue will persist to avoid duplicated translations. And we'd like
-  /// a deterministic order of iterating the queue for finding the next decl
-  /// to translate. So we need SetVector here.
-  llvm::SetVector<const DeclaratorDecl *> workQueue;
+  /// \brief A map of funcDecl to its FunctionInfo. Consists of all entry
+  /// functions followed by all reachable functions from the entry functions.
+  llvm::DenseMap<const DeclaratorDecl *, FunctionInfo *> functionInfoMap;
+
+  /// A queue of FunctionInfo reachable from all the entry functions.
+  std::vector<const FunctionInfo *> workQueue;
 
   /// <result-id> for the entry function. Initially it is zero and will be reset
   /// when starting to translate the entry function.
@@ -1046,6 +1089,17 @@ private:
   /// Maps a given statement to the basic block that is associated with it.
   llvm::DenseMap<const Stmt *, SpirvBasicBlock *> stmtBasicBlock;
 
+  /// Maintains mapping from a type to SPIR-V variable along with SPIR-V
+  /// instruction for id of location decoration Used for raytracing stage
+  /// variables of storage class RayPayloadNV, CallableDataNV and
+  /// HitAttributeNV.
+  llvm::SmallDenseMap<QualType,
+                      std::pair<SpirvInstruction *, SpirvInstruction *>, 4>
+      payloadMap;
+  llvm::SmallDenseMap<QualType, SpirvInstruction *, 4> hitAttributeMap;
+  llvm::SmallDenseMap<QualType,
+                      std::pair<SpirvInstruction *, SpirvInstruction *>, 4>
+      callDataMap;
   /// This is the Patch Constant Function. This function is not explicitly
   /// called from the entry point function.
   FunctionDecl *patchConstFunc;

+ 6 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -81,6 +81,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvStore)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvArrayLength)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingOpNV)
 
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 
@@ -719,5 +720,10 @@ SpirvArrayLength::SpirvArrayLength(QualType resultType, SourceLocation loc,
     : SpirvInstruction(IK_ArrayLength, spv::Op::OpArrayLength, resultType, loc),
       structure(structure_), arrayMember(memberLiteral) {}
 
+SpirvRayTracingOpNV::SpirvRayTracingOpNV(
+    QualType resultType, spv::Op opcode,
+    llvm::ArrayRef<SpirvInstruction *> vecOperands, SourceLocation loc)
+    : SpirvInstruction(IK_RayTracingOpNV, opcode, resultType, loc),
+      operands(vecOperands.begin(), vecOperands.end()) {}
 } // namespace spirv
 } // namespace clang

+ 2 - 1
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -88,7 +88,8 @@ bool SpirvType::isSubpassInputMS(const SpirvType *type) {
 }
 
 bool SpirvType::isResourceType(const SpirvType *type) {
-  if (isa<ImageType>(type) || isa<SamplerType>(type))
+  if (isa<ImageType>(type) || isa<SamplerType>(type) ||
+      isa<AccelerationStructureTypeNV>(type))
     return true;
 
   if (const auto *structType = dyn_cast<StructType>(type))

+ 78 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.anyhit.hlsl

@@ -0,0 +1,78 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+// CHECK:  OpDecorate [[c:%\d+]] BuiltIn WorldRayOriginNV
+// CHECK:  OpDecorate [[d:%\d+]] BuiltIn WorldRayDirectionNV
+// CHECK:  OpDecorate [[e:%\d+]] BuiltIn RayTminNV
+// CHECK:  OpDecorate [[f:%\d+]] BuiltIn IncomingRayFlagsNV
+// CHECK:  OpDecorate %gl_InstanceID BuiltIn InstanceId
+// CHECK:  OpDecorate [[g:%\d+]] BuiltIn InstanceCustomIndexNV
+// CHECK:  OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+// CHECK:  OpDecorate [[h:%\d+]] BuiltIn ObjectRayOriginNV
+// CHECK:  OpDecorate [[i:%\d+]] BuiltIn ObjectRayDirectionNV
+// CHECK:  OpDecorate [[j:%\d+]] BuiltIn ObjectToWorldNV
+// CHECK:  OpDecorate [[k:%\d+]] BuiltIn WorldToObjectNV
+// CHECK:  OpDecorate [[l:%\d+]] BuiltIn HitKindNV
+
+// CHECK:  OpTypePointer IncomingRayPayloadNV %Payload
+struct Payload
+{
+  float4 color;
+};
+// CHECK:  OpTypePointer HitAttributeNV %Attribute
+struct Attribute
+{
+  float2 bary;
+};
+
+// CHECK-COUNT-1: [[rstype:%\d+]] = OpTypeAccelerationStructureNV
+RaytracingAccelerationStructure rs;
+
+[shader("anyhit")]
+void main(inout Payload MyPayload, in Attribute MyAttr) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK: [[matotw:%\d+]] = OpLoad %mat4v3float [[j]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matotw]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK: [[matwto:%\d+]] = OpLoad %mat4v3float [[k]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matwto]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+// CHECK:  OpLoad %uint [[l]]
+  uint _16 = HitKind();
+
+  if (_16 == 1U) {
+// CHECK:  OpIgnoreIntersectionNV
+    IgnoreHit();
+  } else {
+// CHECK:  OpTerminateRayNV
+    AcceptHitAndEndSearch();
+  }
+}

+ 20 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.callable.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+
+// CHECK:  OpTypePointer IncomingCallableDataNV %CallData
+struct CallData
+{
+  float4 data;
+};
+
+[shader("callable")]
+void main(inout CallData myCallData) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 a = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 b = DispatchRaysDimensions();
+}

+ 79 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.closesthit.hlsl

@@ -0,0 +1,79 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+// CHECK:  OpDecorate [[c:%\d+]] BuiltIn WorldRayOriginNV
+// CHECK:  OpDecorate [[d:%\d+]] BuiltIn WorldRayDirectionNV
+// CHECK:  OpDecorate [[e:%\d+]] BuiltIn RayTminNV
+// CHECK:  OpDecorate [[f:%\d+]] BuiltIn IncomingRayFlagsNV
+// CHECK:  OpDecorate %gl_InstanceID BuiltIn InstanceId
+// CHECK:  OpDecorate [[g:%\d+]] BuiltIn InstanceCustomIndexNV
+// CHECK:  OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+// CHECK:  OpDecorate [[h:%\d+]] BuiltIn ObjectRayOriginNV
+// CHECK:  OpDecorate [[i:%\d+]] BuiltIn ObjectRayDirectionNV
+// CHECK:  OpDecorate [[j:%\d+]] BuiltIn ObjectToWorldNV
+// CHECK:  OpDecorate [[k:%\d+]] BuiltIn WorldToObjectNV
+// CHECK:  OpDecorate [[l:%\d+]] BuiltIn HitKindNV
+
+// CHECK:  OpTypePointer IncomingRayPayloadNV %Payload
+struct Payload
+{
+  float4 color;
+};
+// CHECK:  OpTypePointer HitAttributeNV %Attribute
+struct Attribute
+{
+  float2 bary;
+};
+
+// CHECK-COUNT-1: [[rstype:%\d+]] = OpTypeAccelerationStructureNV
+RaytracingAccelerationStructure rs;
+
+[shader("closesthit")]
+void main(inout Payload MyPayload, in Attribute MyAttr) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK: [[matotw:%\d+]] = OpLoad %mat4v3float [[j]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matotw]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK: [[matwto:%\d+]] = OpLoad %mat4v3float [[k]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matwto]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+// CHECK:  OpLoad %uint [[l]]
+  uint _16 = HitKind();
+
+  Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+  RayDesc rayDesc;
+  rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+  rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+  rayDesc.TMin = 0.0f;
+  rayDesc.TMax = 1000.0f;
+// CHECK: OpTraceNV {{%\d+}} %uint_0 %uint_255 %uint_0 %uint_1 %uint_0 {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %uint_0
+  TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+}

+ 62 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.intersection.hlsl

@@ -0,0 +1,62 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+// CHECK:  OpDecorate [[c:%\d+]] BuiltIn WorldRayOriginNV
+// CHECK:  OpDecorate [[d:%\d+]] BuiltIn WorldRayDirectionNV
+// CHECK:  OpDecorate [[e:%\d+]] BuiltIn RayTminNV
+// CHECK:  OpDecorate [[f:%\d+]] BuiltIn IncomingRayFlagsNV
+// CHECK:  OpDecorate %gl_InstanceID BuiltIn InstanceId
+// CHECK:  OpDecorate [[g:%\d+]] BuiltIn InstanceCustomIndexNV
+// CHECK:  OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+// CHECK:  OpDecorate [[h:%\d+]] BuiltIn ObjectRayOriginNV
+// CHECK:  OpDecorate [[i:%\d+]] BuiltIn ObjectRayDirectionNV
+// CHECK:  OpDecorate [[j:%\d+]] BuiltIn ObjectToWorldNV
+// CHECK:  OpDecorate [[k:%\d+]] BuiltIn WorldToObjectNV
+
+struct Attribute
+{
+  float2 bary;
+};
+
+[shader("intersection")]
+void main() {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK: [[matotw:%\d+]] = OpLoad %mat4v3float [[j]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matotw]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK: [[matwto:%\d+]] = OpLoad %mat4v3float [[k]]
+// CHECK-NEXT: OpTranspose %mat3v4float [[matwto]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+
+  Attribute myHitAttribute = { float2(0.0f,0.0f) };
+// CHECK: OpReportIntersectionNV %bool %float_0 %uint_0
+  ReportHit(0.0f, 0U, myHitAttribute);
+}

+ 260 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.library.hlsl

@@ -0,0 +1,260 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpEntryPoint RayGenerationNV %MyRayGenMain "MyRayGenMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint RayGenerationNV %MyRayGenMain2 "MyRayGenMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint MissNV %MyMissMain "MyMissMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint MissNV %MyMissMain2 "MyMissMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint IntersectionNV %MyISecMain "MyISecMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint IntersectionNV %MyISecMain2 "MyISecMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint AnyHitNV %MyAHitMain "MyAHitMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint AnyHitNV %MyAHitMain2 "MyAHitMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint ClosestHitNV %MyCHitMain "MyCHitMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint ClosestHitNV %MyCHitMain2 "MyCHitMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint CallableNV %MyCallMain "MyCallMain" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpEntryPoint CallableNV %MyCallMain2 "MyCallMain2" {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %gl_InstanceID {{%\d+}} %gl_PrimitiveID {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}}
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+// CHECK:  OpDecorate [[c:%\d+]] BuiltIn WorldRayOriginNV
+// CHECK:  OpDecorate [[d:%\d+]] BuiltIn WorldRayDirectionNV
+// CHECK:  OpDecorate [[e:%\d+]] BuiltIn RayTminNV
+// CHECK:  OpDecorate [[f:%\d+]] BuiltIn IncomingRayFlagsNV
+// CHECK:  OpDecorate %gl_InstanceID BuiltIn InstanceId
+// CHECK:  OpDecorate [[g:%\d+]] BuiltIn InstanceCustomIndexNV
+// CHECK:  OpDecorate %gl_PrimitiveID BuiltIn PrimitiveId
+// CHECK:  OpDecorate [[h:%\d+]] BuiltIn ObjectRayOriginNV
+// CHECK:  OpDecorate [[i:%\d+]] BuiltIn ObjectRayDirectionNV
+// CHECK:  OpDecorate [[j:%\d+]] BuiltIn ObjectToWorldNV
+// CHECK:  OpDecorate [[k:%\d+]] BuiltIn WorldToObjectNV
+// CHECK:  OpDecorate [[l:%\d+]] BuiltIn HitKindNV
+
+
+// CHECK: OpTypePointer CallableDataNV %CallData
+struct CallData
+{
+  float4 data;
+};
+// CHECK:  OpTypePointer IncomingRayPayloadNV %Payload
+struct Payload
+{
+  float4 color;
+};
+// CHECK:  OpTypePointer HitAttributeNV %Attribute
+struct Attribute
+{
+  float2 bary;
+};
+// CHECK-COUNT-1: [[rstype:%\d+]] = OpTypeAccelerationStructureNV
+RaytracingAccelerationStructure rs;
+
+
+[shader("raygeneration")]
+void MyRayGenMain() {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 a = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 b = DispatchRaysDimensions();
+
+  Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+  CallData myCallData = { float4(0.0f,0.0f,0.0f,0.0f) };
+  RayDesc rayDesc;
+  rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+  rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+  rayDesc.TMin = 0.0f;
+  rayDesc.TMax = 1000.0f;
+// CHECK: OpTraceNV {{%\d+}} %uint_0 %uint_255 %uint_0 %uint_1 %uint_0 {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %uint_0
+  TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+// CHECK: OpExecuteCallableNV %uint_0 %uint_0
+  CallShader(0, myCallData);
+}
+
+[shader("raygeneration")]
+void MyRayGenMain2() {
+    CallData myCallData = { float4(0.0f,0.0f,0.0f,0.0f) };
+    CallShader(0, myCallData);
+}
+
+[shader("miss")]
+void MyMissMain(inout Payload MyPayload) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+}
+
+[shader("miss")]
+void MyMissMain2(inout Payload MyPayload) {
+    MyPayload.color = float4(0.0f,1.0f,0.0f,1.0f);
+}
+
+[shader("intersection")]
+void MyISecMain() {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+
+  Attribute myHitAttribute = { float2(0.0f,0.0f) };
+// CHECK: OpReportIntersectionNV %bool %float_0 %uint_0
+  ReportHit(0.0f, 0U, myHitAttribute);
+}
+
+[shader("intersection")]
+void MyISecMain2() {
+  Attribute myHitAttribute = { float2(0.0f,1.0f) };
+// CHECK: OpReportIntersectionNV %bool %float_0 %uint_0
+  ReportHit(0.0f, 0U, myHitAttribute);
+}
+
+[shader("anyhit")]
+void MyAHitMain(inout Payload MyPayload, in Attribute MyAttr) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+// CHECK:  OpLoad %uint [[l]]
+  uint _16 = HitKind();
+
+  if (_16 == 1U) {
+// CHECK:  OpIgnoreIntersectionNV
+    IgnoreHit();
+  } else {
+// CHECK:  OpTerminateRayNV
+    AcceptHitAndEndSearch();
+  }
+}
+
+[shader("anyhit")]
+void MyAHitMain2(inout Payload MyPayload, in Attribute MyAttr) {
+// CHECK:  OpTerminateRayNV
+    AcceptHitAndEndSearch();
+}
+
+[shader("closesthit")]
+void MyCHitMain(inout Payload MyPayload, in Attribute MyAttr) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+// CHECK:  OpLoad %uint %gl_InstanceID
+  uint _7 = InstanceIndex();
+// CHECK:  OpLoad %uint [[g]]
+  uint _8 = InstanceID();
+// CHECK:  OpLoad %uint %gl_PrimitiveID
+  uint _9 = PrimitiveIndex();
+// CHECK:  OpLoad %v3float [[h]]
+  float3 _10 = ObjectRayOrigin();
+// CHECK:  OpLoad %v3float [[i]]
+  float3 _11 = ObjectRayDirection();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float3x4 _12 = ObjectToWorld3x4();
+// CHECK:  OpLoad %mat4v3float [[j]]
+  float4x3 _13 = ObjectToWorld4x3();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float3x4 _14 = WorldToObject3x4();
+// CHECK:  OpLoad %mat4v3float [[k]]
+  float4x3 _15 = WorldToObject4x3();
+// CHECK:  OpLoad %uint [[l]]
+  uint _16 = HitKind();
+
+  Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+  RayDesc rayDesc;
+  rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+  rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+  rayDesc.TMin = 0.0f;
+  rayDesc.TMax = 1000.0f;
+// CHECK: OpTraceNV {{%\d+}} %uint_0 %uint_255 %uint_0 %uint_1 %uint_0 {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %uint_0
+  TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+}
+
+[shader("closesthit")]
+void MyCHitMain2(inout Payload MyPayload, in Attribute MyAttr) {
+    MyPayload.color = float4(0.0f,1.0f,0.0f,1.0f);
+}
+
+[shader("callable")]
+void MyCallMain(inout CallData myCallData) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 a = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 b = DispatchRaysDimensions();
+}
+
+[shader("callable")]
+void MyCallMain2(inout CallData myCallData) {
+    myCallData.data = float4(0.0f,1.0f,0.0f,1.0f);
+}

+ 32 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.miss.hlsl

@@ -0,0 +1,32 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+// CHECK:  OpDecorate [[c:%\d+]] BuiltIn WorldRayOriginNV
+// CHECK:  OpDecorate [[d:%\d+]] BuiltIn WorldRayDirectionNV
+// CHECK:  OpDecorate [[e:%\d+]] BuiltIn RayTminNV
+// CHECK:  OpDecorate [[f:%\d+]] BuiltIn IncomingRayFlagsNV
+
+// CHECK:  OpTypePointer IncomingRayPayloadNV %Payload
+struct Payload
+{
+  float4 color;
+};
+
+[shader("miss")]
+void main(inout Payload MyPayload) {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 _1 = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 _2 = DispatchRaysDimensions();
+// CHECK:  OpLoad %v3float [[c]]
+  float3 _3 = WorldRayOrigin();
+// CHECK:  OpLoad %v3float [[d]]
+  float3 _4 = WorldRayDirection();
+// CHECK:  OpLoad %float [[e]]
+  float _5 = RayTMin();
+// CHECK:  OpLoad %uint [[f]]
+  uint _6 = RayFlags();
+}

+ 38 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.raygen.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+
+// CHECK-COUNT-1: [[rstype:%\d+]] = OpTypeAccelerationStructureNV
+RaytracingAccelerationStructure rs;
+
+struct Payload
+{
+  float4 color;
+};
+struct CallData
+{
+  float4 data;
+};
+
+[shader("raygeneration")]
+void main() {
+
+// CHECK:  OpLoad %v3uint [[a]]
+  uint3 a = DispatchRaysIndex();
+// CHECK:  OpLoad %v3uint [[b]]
+  uint3 b = DispatchRaysDimensions();
+
+  Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+  CallData myCallData = { float4(0.0f,0.0f,0.0f,0.0f) };
+  RayDesc rayDesc;
+  rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+  rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+  rayDesc.TMin = 0.0f;
+  rayDesc.TMax = 1000.0f;
+  // CHECK: OpTraceNV {{%\d+}} %uint_0 %uint_255 %uint_0 %uint_1 %uint_0 {{%\d+}} {{%\d+}} {{%\d+}} {{%\d+}} %uint_0
+  TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+  // CHECK: OpExecuteCallableNV %uint_0 %uint_0
+  CallShader(0, myCallData);
+}

+ 23 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1789,4 +1789,27 @@ TEST_F(FileTest, PreprocessorError) {
   runFileTest("preprocess.error.hlsl", Expect::Failure);
 }
 
+// === Raytracing NV examples ===
+TEST_F(FileTest, RayTracingNVRaygen) {
+  runFileTest("raytracing.nv.raygen.hlsl");
+}
+TEST_F(FileTest, RayTracingNVIntersection) {
+  runFileTest("raytracing.nv.intersection.hlsl");
+}
+TEST_F(FileTest, RayTracingNVAnyHit) {
+  runFileTest("raytracing.nv.anyhit.hlsl");
+}
+TEST_F(FileTest, RayTracingNVClosestHit) {
+  runFileTest("raytracing.nv.closesthit.hlsl");
+}
+TEST_F(FileTest, RayTracingNVMiss) {
+  runFileTest("raytracing.nv.miss.hlsl");
+}
+TEST_F(FileTest, RayTracingNVCallable) {
+  runFileTest("raytracing.nv.callable.hlsl");
+}
+TEST_F(FileTest, RayTracingNVLibrary) {
+  runFileTest("raytracing.nv.library.hlsl");
+}
+
 } // namespace

+ 7 - 3
tools/clang/unittests/SPIRV/FileTestUtils.cpp

@@ -85,7 +85,8 @@ bool processRunCommandArgs(const llvm::StringRef runCommandLine,
     fprintf(stderr, "Error: Missing target profile argument (-T).\n");
     return false;
   }
-  if (entryPoint->empty()) {
+  // lib_6_* profile doesn't need an entryPoint
+  if (targetProfile->c_str()[0] != 'l' && entryPoint->empty()) {
     fprintf(stderr, "Error: Missing entry point argument (-E).\n");
     return false;
   }
@@ -159,8 +160,11 @@ bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
         requires_opt = true;
 
     std::vector<LPCWSTR> flags;
-    flags.push_back(L"-E");
-    flags.push_back(entry.c_str());
+    // lib_6_* profile doesn't need an entryPoint
+    if (profile.c_str()[0] != 'l') {
+      flags.push_back(L"-E");
+      flags.push_back(entry.c_str());
+    }
     flags.push_back(L"-T");
     flags.push_back(profile.c_str());
     flags.push_back(L"-spirv");