Explorar el Código

[spirv] Support HLSL 'precise' keyword (#2024)

* [spirv] Support 'precise' keyword.
* [spirv] Support 'precise' on struct members.
Ehsan hace 6 años
padre
commit
8275b8103c

+ 11 - 9
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -9,8 +9,8 @@
 #ifndef LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 #define LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 
-#include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
+#include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvModule.h"
@@ -56,12 +56,13 @@ public:
   /// At any time, there can only exist at most one function under building.
   SpirvFunction *beginFunction(QualType returnType, SpirvType *functionType,
                                SourceLocation, llvm::StringRef name = "",
+                               bool isPrecise = false,
                                SpirvFunction *func = nullptr);
 
   /// \brief Creates and registers a function parameter of the given pointer
   /// type in the current function and returns its pointer.
-  SpirvFunctionParameter *addFnParam(QualType ptrType, SourceLocation,
-                                     llvm::StringRef name = "");
+  SpirvFunctionParameter *addFnParam(QualType ptrType, bool isPrecise,
+                                     SourceLocation, llvm::StringRef name = "");
 
   /// \brief Creates a local variable of the given type in the current
   /// function and returns it.
@@ -69,7 +70,7 @@ public:
   /// The corresponding pointer type of the given type will be constructed in
   /// this method for the variable itself.
   SpirvVariable *addFnVar(QualType valueType, SourceLocation,
-                          llvm::StringRef name = "",
+                          llvm::StringRef name = "", bool isPrecise = false,
                           SpirvInstruction *init = nullptr);
 
   /// \brief Ends building of the current function. All basic blocks constructed
@@ -460,7 +461,8 @@ public:
   /// Note: the corresponding pointer type of the given type will not be
   /// constructed in this method.
   SpirvVariable *addStageIOVar(QualType type, spv::StorageClass storageClass,
-                               std::string name, SourceLocation loc = {});
+                               std::string name, bool isPrecise,
+                               SourceLocation loc = {});
 
   /// \brief Adds a stage builtin variable whose value is of the given type.
   ///
@@ -468,7 +470,8 @@ public:
   /// constructed in this method.
   SpirvVariable *addStageBuiltinVar(QualType type,
                                     spv::StorageClass storageClass,
-                                    spv::BuiltIn, SourceLocation loc = {});
+                                    spv::BuiltIn, bool isPrecise,
+                                    SourceLocation loc = {});
 
   /// \brief Adds a module variable. This variable should not have the Function
   /// storage class.
@@ -477,12 +480,12 @@ public:
   /// constructed in this method.
   SpirvVariable *
   addModuleVar(QualType valueType, spv::StorageClass storageClass,
-               llvm::StringRef name = "",
+               bool isPrecise, llvm::StringRef name = "",
                llvm::Optional<SpirvInstruction *> init = llvm::None,
                SourceLocation loc = {});
   SpirvVariable *
   addModuleVar(const SpirvType *valueType, spv::StorageClass storageClass,
-               llvm::StringRef name = "",
+               bool isPrecise, llvm::StringRef name = "",
                llvm::Optional<SpirvInstruction *> init = llvm::None,
                SourceLocation loc = {});
 
@@ -558,7 +561,6 @@ public:
 public:
   std::vector<uint32_t> takeModule();
 
-
 protected:
   /// Only friend classes are allowed to add capability/extension to the module
   /// under construction.

+ 11 - 7
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -24,9 +24,8 @@ class SpirvVisitor;
 /// The class representing a SPIR-V function in memory.
 class SpirvFunction {
 public:
-  SpirvFunction(QualType astReturnType, SpirvType *fnSpirvType,
-                spv::FunctionControlMask, SourceLocation,
-                llvm::StringRef name = "");
+  SpirvFunction(QualType astReturnType, SpirvType *fnSpirvType, SourceLocation,
+                llvm::StringRef name = "", bool precise = false);
   ~SpirvFunction() = default;
 
   // Forbid copy construction and assignment
@@ -59,7 +58,12 @@ public:
   // Store that the return type is at relaxed precision.
   void setRelaxedPrecision() { relaxedPrecision = true; }
   // Returns whether the return type has relaxed precision.
-  uint32_t isRelaxedPrecision() const { return relaxedPrecision; }
+  bool isRelaxedPrecision() const { return relaxedPrecision; }
+
+  // Store that the return value is precise.
+  void setPrecise(bool p = true) { precise = p; }
+  // Returns whether the return value is precise.
+  bool isPrecise() const { return precise; }
 
   void setSourceLocation(SourceLocation loc) { functionLoc = loc; }
   SourceLocation getSourceLocation() const { return functionLoc; }
@@ -89,6 +93,7 @@ private:
   SpirvType *fnType;      ///< The SPIR-V function type
 
   bool relaxedPrecision; ///< Whether the return type is at relaxed precision
+  bool precise;          ///< Whether the return value is 'precise'
 
   /// Legalization-specific code
   ///
@@ -99,9 +104,8 @@ private:
   bool containsAlias; ///< Whether function return type is aliased
   bool rvalue;        ///< Whether the return value is an rvalue
 
-  spv::FunctionControlMask functionControl; ///< SPIR-V function control
-  SourceLocation functionLoc;               ///< Location in source code
-  std::string functionName;                 ///< This function's name
+  SourceLocation functionLoc; ///< Location in source code
+  std::string functionName;   ///< This function's name
 
   /// Parameters to this function.
   llvm::SmallVector<SpirvFunctionParameter *, 8> parameters;

+ 9 - 2
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -147,6 +147,8 @@ public:
   void setDebugName(llvm::StringRef name) { debugName = name; }
   llvm::StringRef getDebugName() const { return debugName; }
 
+  bool isArithmeticInstruction() const;
+
   SpirvLayoutRule getLayoutRule() const { return layoutRule; }
   void setLayoutRule(SpirvLayoutRule rule) { layoutRule = rule; }
 
@@ -162,6 +164,9 @@ public:
   void setNonUniform(bool nu = true) { isNonUniform_ = nu; }
   bool isNonUniform() const { return isNonUniform_; }
 
+  void setPrecise(bool p = true) { isPrecise_ = p; }
+  bool isPrecise() const { return isPrecise_; }
+
   /// Legalization-specific code
   ///
   /// Note: the following two functions are currently needed in order to support
@@ -202,6 +207,7 @@ protected:
   bool isRValue_;
   bool isRelaxedPrecision_;
   bool isNonUniform_;
+  bool isPrecise_;
 };
 
 #define DECLARE_INVOKE_VISITOR_FOR_CLASS(cls)                                  \
@@ -450,7 +456,7 @@ public:
   };
 
   SpirvVariable(QualType resultType, SourceLocation loc, spv::StorageClass sc,
-                SpirvInstruction *initializerId = 0);
+                bool isPrecise, SpirvInstruction *initializerId = 0);
 
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
@@ -471,7 +477,8 @@ private:
 
 class SpirvFunctionParameter : public SpirvInstruction {
 public:
-  SpirvFunctionParameter(QualType resultType, SourceLocation loc);
+  SpirvFunctionParameter(QualType resultType, bool isPrecise,
+                         SourceLocation loc);
 
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {

+ 9 - 4
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -291,10 +291,10 @@ public:
               llvm::Optional<uint32_t> offset_ = llvm::None,
               llvm::Optional<uint32_t> matrixStride_ = llvm::None,
               llvm::Optional<bool> isRowMajor_ = llvm::None,
-              bool relaxedPrecision = false)
+              bool relaxedPrecision = false, bool precise = false)
         : type(type_), name(name_), offset(offset_),
           matrixStride(matrixStride_), isRowMajor(isRowMajor_),
-          isRelaxedPrecision(relaxedPrecision) {
+          isRelaxedPrecision(relaxedPrecision), isPrecise(precise) {
       // A StructType may not contain any hybrid types.
       assert(!isa<HybridType>(type_));
     }
@@ -313,6 +313,8 @@ public:
     llvm::Optional<bool> isRowMajor;
     // Whether this field is a RelaxedPrecision field.
     bool isRelaxedPrecision;
+    // Whether this field is marked as 'precise'.
+    bool isPrecise;
   };
 
   StructType(
@@ -418,9 +420,10 @@ public:
     FieldInfo(QualType astType_, llvm::StringRef name_ = "",
               clang::VKOffsetAttr *offset = nullptr,
               hlsl::ConstantPacking *packOffset = nullptr,
-              const hlsl::RegisterAssignment *regC = nullptr)
+              const hlsl::RegisterAssignment *regC = nullptr,
+              bool precise = false)
         : astType(astType_), name(name_), vkOffsetAttr(offset),
-          packOffsetAttr(packOffset), registerC(regC) {}
+          packOffsetAttr(packOffset), registerC(regC), isPrecise(precise) {}
 
     // The field's type.
     QualType astType;
@@ -432,6 +435,8 @@ public:
     hlsl::ConstantPacking *packOffsetAttr;
     // :register(c#) annotations associated with this field.
     const hlsl::RegisterAssignment *registerC;
+    // Whether this field is marked as 'precise'.
+    bool isPrecise;
   };
 
   HybridStructType(

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -15,6 +15,7 @@ add_clang_library(clangSPIRV
   InitListHandler.cpp
   LiteralTypeVisitor.cpp
   LowerTypeVisitor.cpp
+  PreciseVisitor.cpp
   RelaxedPrecisionVisitor.cpp
   SpirvBasicBlock.cpp
   SpirvBuilder.cpp

+ 68 - 49
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -543,8 +543,8 @@ SpirvFunctionParameter *
 DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   const auto type = getTypeOrFnRetType(param);
   const auto loc = param->getLocation();
-  SpirvFunctionParameter *fnParamInstr =
-      spvBuilder.addFnParam(type, loc, param->getName());
+  SpirvFunctionParameter *fnParamInstr = spvBuilder.addFnParam(
+      type, param->hasAttr<HLSLPreciseAttr>(), loc, param->getName());
 
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
@@ -574,8 +574,9 @@ DeclResultIdMapper::createFnVar(const VarDecl *var,
   const auto type = getTypeOrFnRetType(var);
   const auto loc = var->getLocation();
   const auto name = var->getName();
+  const bool isPrecise = var->hasAttr<HLSLPreciseAttr>();
   SpirvVariable *varInstr = spvBuilder.addFnVar(
-      type, loc, name, init.hasValue() ? init.getValue() : nullptr);
+      type, loc, name, isPrecise, init.hasValue() ? init.getValue() : nullptr);
 
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
@@ -593,7 +594,8 @@ DeclResultIdMapper::createFileVar(const VarDecl *var,
   const auto type = getTypeOrFnRetType(var);
   const auto loc = var->getLocation();
   SpirvVariable *varInstr = spvBuilder.addModuleVar(
-      type, spv::StorageClass::Private, var->getName(), init, loc);
+      type, spv::StorageClass::Private, var->hasAttr<HLSLPreciseAttr>(),
+      var->getName(), init, loc);
 
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
@@ -651,7 +653,8 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
   const auto type = var->getType();
   const auto loc = var->getLocation();
   SpirvVariable *varInstr = spvBuilder.addModuleVar(
-      type, storageClass, var->getName(), llvm::None, loc);
+      type, storageClass, var->hasAttr<HLSLPreciseAttr>(), var->getName(),
+      llvm::None, loc);
   varInstr->setLayoutRule(rule);
   DeclSpirvInfo info(varInstr);
   astDecls[var] = info;
@@ -716,7 +719,8 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     varType.removeLocalConst();
     HybridStructType::FieldInfo info(varType, declDecl->getName(),
                                      declDecl->getAttr<VKOffsetAttr>(),
-                                     getPackOffset(declDecl), registerC);
+                                     getPackOffset(declDecl), registerC,
+                                     declDecl->hasAttr<HLSLPreciseAttr>());
     fields.push_back(info);
   }
 
@@ -743,7 +747,10 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
       forPC ? spv::StorageClass::PushConstant : spv::StorageClass::Uniform;
 
   // Create the variable for the whole struct / struct array.
-  SpirvVariable *var = spvBuilder.addModuleVar(resultType, sc, varName);
+  // The fields may be 'precise', but the structure itself is not.
+  SpirvVariable *var =
+      spvBuilder.addModuleVar(resultType, sc, /*isPrecise*/ false, varName);
+
   const SpirvLayoutRule layoutRule =
       (forCBuffer || forGlobals)
           ? spirvOptions.cBufferLayoutRule
@@ -889,9 +896,10 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
 
-  SpirvFunction *spirvFunction = new (spvContext) SpirvFunction(
-      fn->getReturnType(), /*functionType*/ nullptr,
-      spv::FunctionControlMask::MaskNone, fn->getLocation(), fn->getName());
+  const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
+  SpirvFunction *spirvFunction = new (spvContext)
+      SpirvFunction(fn->getReturnType(), /*functionType*/ nullptr,
+                    fn->getLocation(), fn->getName(), isPrecise);
 
   // No need to dereference to get the pointer. Function returns that are
   // stand-alone aliases are already pointers to values. All other cases should
@@ -965,8 +973,8 @@ void DeclResultIdMapper::createCounterVar(
         spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
   }
 
-  SpirvVariable *counterInstr =
-      spvBuilder.addModuleVar(counterType, sc, counterName);
+  SpirvVariable *counterInstr = spvBuilder.addModuleVar(
+      counterType, sc, /*isPrecise*/ false, counterName);
 
   if (!isAlias) {
     // Non-alias counter variables should be put in to resourceVars so that
@@ -2218,7 +2226,7 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
 
   // Create a dummy StageVar for this builtin variable
   auto var = spvBuilder.addStageBuiltinVar(type, spv::StorageClass::Input,
-                                           builtIn, loc);
+                                           builtIn, /*isPrecise*/ false, loc);
 
   const hlsl::SigPoint *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
@@ -2246,6 +2254,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   const auto semanticKind = stageVar->getSemanticInfo().getKind();
   const auto sigPointKind = sigPoint->GetKind();
   const auto type = stageVar->getAstType();
+  const auto isPrecise = decl->hasAttr<HLSLPreciseAttr>();
 
   spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
   if (sc == spv::StorageClass::Max)
@@ -2265,7 +2274,8 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
             .Default(BuiltIn::Max);
 
     assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
-    return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn, isPrecise,
+                                         srcLoc);
   }
 
   // The following translation assumes that semantic validity in the current
@@ -2281,7 +2291,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::VSIn:
     case hlsl::SigPoint::Kind::PCOut:
     case hlsl::SigPoint::Kind::DSIn:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::HSCPIn:
     case hlsl::SigPoint::Kind::HSCPOut:
@@ -2290,11 +2300,12 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
       stageVar->setIsSpirvBuiltin();
-      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position, srcLoc);
+      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
+                                           isPrecise, srcLoc);
     case hlsl::SigPoint::Kind::PSIn:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord,
-                                           srcLoc);
+                                           isPrecise, srcLoc);
     default:
       llvm_unreachable("invalid usage of SV_Position sneaked in");
     }
@@ -2305,7 +2316,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::VertexID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
   // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
@@ -2316,7 +2327,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::VSIn:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex,
-                                           srcLoc);
+                                           isPrecise, srcLoc);
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::HSCPIn:
     case hlsl::SigPoint::Kind::HSCPOut:
@@ -2325,7 +2336,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::GSVIn:
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     default:
       llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
     }
@@ -2346,7 +2357,8 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
       spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::DepthLess,
                                   {});
-    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth,
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
   // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn.
@@ -2358,7 +2370,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::VSIn:
     case hlsl::SigPoint::Kind::PCOut:
     case hlsl::SigPoint::Kind::DSIn:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::HSCPIn:
     case hlsl::SigPoint::Kind::HSCPOut:
@@ -2379,11 +2391,11 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::IsFrontFace: {
     switch (sigPointKind) {
     case hlsl::SigPoint::Kind::GSOut:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     case hlsl::SigPoint::Kind::PSIn:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing,
-                                           srcLoc);
+                                           isPrecise, srcLoc);
     default:
       llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
     }
@@ -2395,7 +2407,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // An arbitrary semantic is defined by users. Generate normal Vulkan stage
   // input/output variables.
   case hlsl::Semantic::Kind::Arbitrary: {
-    return spvBuilder.addStageIOVar(type, sc, name.str());
+    return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     // TODO: patch constant function in hull shader
   }
   // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
@@ -2403,29 +2415,29 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::DispatchThreadID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the GroupID SV can only be used by CSIn.
   // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
   case hlsl::Semantic::Kind::GroupID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
   // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
   case hlsl::Semantic::Kind::GroupThreadID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
   // According to Vulkan spec, the LocalInvocationIndex can only be used in
   // CSIn.
   case hlsl::Semantic::Kind::GroupIndex: {
     stageVar->setIsSpirvBuiltin();
-    return spvBuilder.addStageBuiltinVar(type, sc,
-                                         BuiltIn::LocalInvocationIndex, srcLoc);
+    return spvBuilder.addStageBuiltinVar(
+        type, sc, BuiltIn::LocalInvocationIndex, isPrecise, srcLoc);
   }
   // According to DXIL spec, the OutputControlID SV can only be used by HSIn.
   // According to Vulkan spec, the InvocationId BuiltIn can only be used in
@@ -2433,7 +2445,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::OutputControlPointID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
   // DSIn, GSIn, GSOut, and PSIn.
@@ -2443,7 +2455,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     // Translate to PrimitiveId BuiltIn for all valid SigPoints.
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the TessFactor SV can only be used by PCOut and
   // DSIn.
@@ -2452,7 +2464,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::TessFactor: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
   // and DSIn.
@@ -2461,13 +2473,14 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::InsideTessFactor: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
   // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
   case hlsl::Semantic::Kind::DomainLocation: {
     stageVar->setIsSpirvBuiltin();
-    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord,
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
   // According to Vulkan spec, the InvocationId BuiltIn can only be used in
@@ -2475,19 +2488,20 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::GSInstanceID: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
   // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
   case hlsl::Semantic::Kind::SampleIndex: {
     stageVar->setIsSpirvBuiltin();
-    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId, isPrecise,
+                                         srcLoc);
   }
   // According to DXIL spec, the StencilRef SV can only be used by PSOut.
   case hlsl::Semantic::Kind::StencilRef: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
   case hlsl::Semantic::Kind::Barycentrics: {
@@ -2513,7 +2527,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
       }
     }
 
-    return spvBuilder.addStageBuiltinVar(type, sc, bi, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
   }
   // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
   // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
@@ -2528,15 +2542,17 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::DSIn:
     case hlsl::SigPoint::Kind::DSCPIn:
     case hlsl::SigPoint::Kind::GSVIn:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::DSOut:
       stageVar->setIsSpirvBuiltin();
-      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, srcLoc);
+      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
+                                           srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
       stageVar->setIsSpirvBuiltin();
-      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, srcLoc);
+      return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
+                                           srcLoc);
     default:
       llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
     }
@@ -2554,17 +2570,17 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     case hlsl::SigPoint::Kind::DSIn:
     case hlsl::SigPoint::Kind::DSCPIn:
     case hlsl::SigPoint::Kind::GSVIn:
-      return spvBuilder.addStageIOVar(type, sc, name.str());
+      return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise);
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::DSOut:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
-                                           srcLoc);
+                                           isPrecise, srcLoc);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
-                                           srcLoc);
+                                           isPrecise, srcLoc);
     default:
       llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
     }
@@ -2574,7 +2590,8 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // PSIn and PSOut.
   case hlsl::Semantic::Kind::Coverage: {
     stageVar->setIsSpirvBuiltin();
-    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask,
+                                         isPrecise, srcLoc);
   }
   // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
   // HSIn, DSIn, GSIn, PSIn.
@@ -2582,7 +2599,8 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // VS/HS/DS/GS/PS input.
   case hlsl::Semantic::Kind::ViewID: {
     stageVar->setIsSpirvBuiltin();
-    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex, srcLoc);
+    return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex,
+                                         isPrecise, srcLoc);
   }
     // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
     // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
@@ -2590,7 +2608,7 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   case hlsl::Semantic::Kind::InnerCoverage: {
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT,
-                                         srcLoc);
+                                         isPrecise, srcLoc);
   }
   default:
     emitError("semantic %0 unimplemented", srcLoc)
@@ -2820,7 +2838,8 @@ DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
   case spv::StorageClass::HitAttributeNV:
   case spv::StorageClass::RayPayloadNV:
   case spv::StorageClass::CallableDataNV:
-    retVal = spvBuilder.addModuleVar(type, sc, name.str());
+    retVal = spvBuilder.addModuleVar(type, sc, decl->hasAttr<HLSLPreciseAttr>(),
+                                     name.str());
     break;
 
   default:

+ 5 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -99,6 +99,11 @@ void EmitVisitor::initInstruction(SpirvInstruction *inst) {
     typeHandler.emitDecoration(getOrAssignResultId<SpirvInstruction>(inst),
                                spv::Decoration::RelaxedPrecision, {});
   }
+  // Emit NoContraction decoration (if any).
+  if (inst->isPrecise() && inst->isArithmeticInstruction()) {
+    typeHandler.emitDecoration(getOrAssignResultId<SpirvInstruction>(inst),
+                               spv::Decoration::NoContraction, {});
+  }
 
   // Initialize the current instruction for emitting.
   curInst.clear();

+ 26 - 9
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -64,9 +64,11 @@ GlPerVertex::GlPerVertex(ASTContext &context, SpirvContext &spirvContext,
                          SpirvBuilder &spirvBuilder)
     : astContext(context), spvContext(spirvContext), spvBuilder(spirvBuilder),
       inClipVar(nullptr), inCullVar(nullptr), outClipVar(nullptr),
-      outCullVar(nullptr), inArraySize(0), outArraySize(0), inClipArraySize(1),
-      outClipArraySize(1), inCullArraySize(1), outCullArraySize(1),
-      inSemanticStrs(2, ""), outSemanticStrs(2, "") {}
+      outCullVar(nullptr), inClipPrecise(false), outClipPrecise(false),
+      inCullPrecise(false), outCullPrecise(false), inArraySize(0),
+      outArraySize(0), inClipArraySize(1), outClipArraySize(1),
+      inCullArraySize(1), outCullArraySize(1), inSemanticStrs(2, ""),
+      outSemanticStrs(2, "") {}
 
 void GlPerVertex::generateVars(uint32_t inArrayLen, uint32_t outArrayLen) {
   inArraySize = inArrayLen;
@@ -74,16 +76,16 @@ void GlPerVertex::generateVars(uint32_t inArrayLen, uint32_t outArrayLen) {
 
   if (!inClipType.empty())
     inClipVar = createClipCullDistanceVar(/*asInput=*/true, /*isClip=*/true,
-                                          inClipArraySize);
+                                          inClipArraySize, inClipPrecise);
   if (!inCullType.empty())
     inCullVar = createClipCullDistanceVar(/*asInput=*/true, /*isClip=*/false,
-                                          inCullArraySize);
+                                          inCullArraySize, inCullPrecise);
   if (!outClipType.empty())
     outClipVar = createClipCullDistanceVar(/*asInput=*/false, /*isClip=*/true,
-                                           outClipArraySize);
+                                           outClipArraySize, outClipPrecise);
   if (!outCullType.empty())
     outCullVar = createClipCullDistanceVar(/*asInput=*/false, /*isClip=*/false,
-                                           outCullArraySize);
+                                           outCullArraySize, outCullPrecise);
 }
 
 llvm::SmallVector<SpirvVariable *, 2> GlPerVertex::getStageInVars() const {
@@ -124,6 +126,7 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
   llvm::StringRef semanticStr;
   const hlsl::Semantic *semantic = {};
   uint32_t semanticIndex = {};
+  bool isPrecise = decl->hasAttr<HLSLPreciseAttr>();
 
   if (!getStageVarSemantic(decl, &semanticStr, &semantic, &semanticIndex)) {
     if (baseType->isStructureType()) {
@@ -190,6 +193,18 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
     break;
   }
 
+  if (isCull) {
+    if (asInput)
+      inCullPrecise = isPrecise;
+    else
+      outCullPrecise = isPrecise;
+  } else {
+    if (asInput)
+      inClipPrecise = isPrecise;
+    else
+      outClipPrecise = isPrecise;
+  }
+
   // Remember the semantic strings provided by the developer so that we can
   // emit OpDecorate* instructions properly for them
   if (index < kSemanticStrCount) {
@@ -307,7 +322,8 @@ void GlPerVertex::calculateClipCullDistanceArraySize() {
 }
 
 SpirvVariable *GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
-                                                      uint32_t arraySize) {
+                                                      uint32_t arraySize,
+                                                      bool isPrecise) {
   QualType type = astContext.getConstantArrayType(astContext.FloatTy,
                                                   llvm::APInt(32, arraySize),
                                                   clang::ArrayType::Normal, 0);
@@ -325,7 +341,8 @@ SpirvVariable *GlPerVertex::createClipCullDistanceVar(bool asInput, bool isClip,
 
   SpirvVariable *var = spvBuilder.addStageBuiltinVar(
       type, sc,
-      isClip ? spv::BuiltIn::ClipDistance : spv::BuiltIn::CullDistance);
+      isClip ? spv::BuiltIn::ClipDistance : spv::BuiltIn::CullDistance,
+      isPrecise);
 
   const auto index = isClip ? gClipDistanceIndex : gCullDistanceIndex;
   spvBuilder.decorateHlslSemantic(var, asInput ? inSemanticStrs[index]

+ 6 - 1
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -95,7 +95,7 @@ private:
 
   /// Creates a stand-alone ClipDistance/CullDistance builtin variable.
   SpirvVariable *createClipCullDistanceVar(bool asInput, bool isClip,
-                                           uint32_t arraySize);
+                                           uint32_t arraySize, bool isPrecise);
 
   /// Creates SPIR-V instructions for reading the data starting from offset in
   /// the ClipDistance/CullDistance builtin. The data read will be transformed
@@ -135,6 +135,11 @@ private:
   SpirvVariable *inClipVar, *inCullVar;
   SpirvVariable *outClipVar, *outCullVar;
 
+  // We need to record whether the variables with 'SV_ClipDistance' or
+  // 'SV_CullDistance' have the HLSL 'precise' keyword.
+  bool inClipPrecise, outClipPrecise;
+  bool inCullPrecise, outCullPrecise;
+
   /// The array size for the input/output gl_PerVertex block member variables.
   /// HS input and output, DS input, GS input has an additional level of
   /// arrayness. The array size is stored in this variable. Zero means

+ 7 - 1
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -378,7 +378,8 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
           field->getType(), field->getName(),
           /*vkoffset*/ field->getAttr<VKOffsetAttr>(),
           /*packoffset*/ getPackOffset(field),
-          /*RegisterAssignment*/ nullptr));
+          /*RegisterAssignment*/ nullptr,
+          /*isPrecise*/ field->hasAttr<HLSLPreciseAttr>()));
     }
 
     auto loweredFields = populateLayoutInformation(fields, rule);
@@ -728,6 +729,11 @@ LowerTypeVisitor::populateLayoutInformation(
       loweredField.isRelaxedPrecision = true;
     }
 
+    // Set 'precise' information for the lowered field.
+    if (field.isPrecise) {
+      loweredField.isPrecise = true;
+    }
+
     // We only need layout information for strcutres with non-void layout rule.
     if (rule == SpirvLayoutRule::Void) {
       loweredFields.push_back(loweredField);

+ 255 - 0
tools/clang/lib/SPIRV/PreciseVisitor.cpp

@@ -0,0 +1,255 @@
+//===--- PreciseVisitor.cpp ------- Precise Visitor --------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PreciseVisitor.h"
+#include "clang/SPIRV/AstTypeProbe.h"
+#include "clang/SPIRV/SpirvFunction.h"
+#include "clang/SPIRV/SpirvType.h"
+
+#include <stack>
+
+namespace {
+
+/// \brief Returns true if the given OpAccessChain instruction is accessing a
+/// precise variable, or accessing a precise member of a structure. Returns
+/// false otherwise.
+bool isAccessingPrecise(clang::spirv::SpirvAccessChain *inst) {
+  using namespace clang::spirv;
+
+  // If the access chain base is another access chain and so on, first flatten
+  // them (from the bottom to the top). For example:
+  // %x = OpAccessChain <type> %obj %int_1 %int_2
+  // %y = OpAccessChain <type> %x   %int_3 %int_4
+  // %z = OpAccessChain <type> %y   %int_5 %int_6
+  // Should be flattened to:
+  // %z = OpAccessChain <type> %obj %int_1 %int_2 %int_3 %int_4 %int_5 %int_6
+  std::stack<SpirvInstruction *> indexes;
+  SpirvInstruction *base = inst;
+  while (auto *accessChain = llvm::dyn_cast<SpirvAccessChain>(base)) {
+    for (auto iter = accessChain->getIndexes().rbegin();
+         iter != accessChain->getIndexes().rend(); ++iter) {
+      indexes.push(*iter);
+    }
+    base = accessChain->getBase();
+
+    // If we reach a 'precise' base at any level, return true.
+    if (base->isPrecise())
+      return true;
+  }
+
+  // Start from the lowest level base (%obj in the above example), and step
+  // forward using the 'indexes'. If a 'precise' structure field is discovered
+  // at any point, return true.
+  const SpirvType *baseType = base->getResultType();
+  while (baseType && !indexes.empty()) {
+    if (auto *vecType = llvm::dyn_cast<VectorType>(baseType)) {
+      indexes.pop();
+      baseType = vecType->getElementType();
+    } else if (auto *matType = llvm::dyn_cast<MatrixType>(baseType)) {
+      indexes.pop();
+      baseType = matType->getVecType();
+    } else if (auto *arrType = llvm::dyn_cast<ArrayType>(baseType)) {
+      indexes.pop();
+      baseType = arrType->getElementType();
+    } else if (auto *raType = llvm::dyn_cast<RuntimeArrayType>(baseType)) {
+      indexes.pop();
+      baseType = raType->getElementType();
+    } else if (auto *structType = llvm::dyn_cast<StructType>(baseType)) {
+      SpirvInstruction *index = indexes.top();
+      if (auto *constInt = llvm::dyn_cast<SpirvConstantInteger>(index)) {
+        uint32_t indexValue =
+            static_cast<uint32_t>(constInt->getValue().getZExtValue());
+        auto fields = structType->getFields();
+        assert(indexValue < fields.size());
+        auto &fieldInfo = fields[indexValue];
+        if (fieldInfo.isPrecise) {
+          return true;
+        } else {
+          baseType = fieldInfo.type;
+          indexes.pop();
+        }
+      } else {
+        // Trying to index into a structure using a variable? This shouldn't be
+        // happening.
+        assert(false && "indexing into a struct with variable value");
+        return false;
+      }
+    } else if (auto *ptrType = llvm::dyn_cast<SpirvPointerType>(baseType)) {
+      // Note: no need to pop the stack here.
+      baseType = ptrType->getPointeeType();
+    } else {
+      return false;
+    }
+  }
+
+  return false;
+}
+
+} // anonymous namespace
+
+namespace clang {
+namespace spirv {
+
+bool PreciseVisitor::visit(SpirvFunction *fn, Phase phase) {
+  // Before going through the function instructions
+  if (phase == Visitor::Phase::Init) {
+    curFnRetValPrecise = fn->isPrecise();
+  }
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvReturn *inst) {
+  if (inst->hasReturnValue()) {
+    inst->getReturnValue()->setPrecise(curFnRetValPrecise);
+  }
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvVariable *var) {
+  if (var->hasInitializer())
+    var->getInitializer()->setPrecise(var->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvSelect *inst) {
+  inst->getTrueObject()->setPrecise(inst->isPrecise());
+  inst->getFalseObject()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvVectorShuffle *inst) {
+  // If the result of a vector shuffle is 'precise', the vectors from which the
+  // elements are chosen should also be 'precise'.
+  if (inst->isPrecise()) {
+    auto *vec1 = inst->getVec1();
+    auto *vec2 = inst->getVec2();
+    const auto vec1Type = vec1->getAstResultType();
+    const auto vec2Type = vec2->getAstResultType();
+    uint32_t vec1Size;
+    uint32_t vec2Size;
+    (void)isVectorType(vec1Type, nullptr, &vec1Size);
+    (void)isVectorType(vec2Type, nullptr, &vec2Size);
+    bool vec1ElemUsed = false;
+    bool vec2ElemUsed = false;
+    for (auto component : inst->getComponents()) {
+      if (component < vec1Size)
+        vec1ElemUsed = true;
+      else
+        vec2ElemUsed = true;
+    }
+
+    if (vec1ElemUsed)
+      vec1->setPrecise();
+    if (vec2ElemUsed)
+      vec2->setPrecise();
+  }
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvBitFieldExtract *inst) {
+  inst->getBase()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvBitFieldInsert *inst) {
+  inst->getBase()->setPrecise(inst->isPrecise());
+  inst->getInsert()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvAtomic *inst) {
+  if (inst->isPrecise() && inst->hasValue())
+    inst->getValue()->setPrecise();
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvCompositeConstruct *inst) {
+  if (inst->isPrecise())
+    for (auto *consituent : inst->getConstituents())
+      consituent->setPrecise();
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvCompositeExtract *inst) {
+  inst->getComposite()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvCompositeInsert *inst) {
+  inst->getComposite()->setPrecise(inst->isPrecise());
+  inst->getObject()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvLoad *inst) {
+  // If the instruction result is precise, the pointer we're loading from should
+  // also be marked as precise.
+  if (inst->isPrecise())
+    inst->getPointer()->setPrecise();
+
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvStore *inst) {
+  // If the 'pointer' to which we are storing is marked as 'precise', the object
+  // we are storing should also be marked as 'precise'.
+  // Note that the 'pointer' may either be an 'OpVariable' or it might be the
+  // result of one or more access chains (in which case we should figure out if
+  // the 'base' of the access chain is 'precise').
+  auto *ptr = inst->getPointer();
+  auto *obj = inst->getObject();
+
+  // The simple case (target is a precise variable).
+  if (ptr->isPrecise()) {
+    obj->setPrecise();
+    return true;
+  }
+
+  if (auto *accessChain = llvm::dyn_cast<SpirvAccessChain>(ptr)) {
+    if (isAccessingPrecise(accessChain)) {
+      obj->setPrecise();
+      return true;
+    }
+  }
+
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvBinaryOp *inst) {
+  bool isPrecise = inst->isPrecise();
+  inst->getOperand1()->setPrecise(isPrecise);
+  inst->getOperand2()->setPrecise(isPrecise);
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvUnaryOp *inst) {
+  inst->getOperand()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvNonUniformBinaryOp *inst) {
+  inst->getArg1()->setPrecise(inst->isPrecise());
+  inst->getArg2()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvNonUniformUnaryOp *inst) {
+  inst->getArg()->setPrecise(inst->isPrecise());
+  return true;
+}
+
+bool PreciseVisitor::visit(SpirvExtInst *inst) {
+  if (inst->isPrecise())
+    for (auto *operand : inst->getOperands())
+      operand->setPrecise();
+  return true;
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 55 - 0
tools/clang/lib/SPIRV/PreciseVisitor.h

@@ -0,0 +1,55 @@
+//===--- PreciseVisitor.h ---- Precise Visitor -------------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SPIRV_PRECISEVISITOR_H
+#define LLVM_CLANG_LIB_SPIRV_PRECISEVISITOR_H
+
+#include "clang/SPIRV/SpirvVisitor.h"
+
+namespace clang {
+namespace spirv {
+
+class PreciseVisitor : public Visitor {
+public:
+  PreciseVisitor(SpirvContext &spvCtx, const SpirvCodeGenOptions &opts)
+      : Visitor(opts, spvCtx) {}
+
+  bool visit(SpirvFunction *, Phase);
+
+  bool visit(SpirvVariable *);
+  bool visit(SpirvReturn *);
+  bool visit(SpirvSelect *);
+  bool visit(SpirvVectorShuffle *);
+  bool visit(SpirvBitFieldExtract *);
+  bool visit(SpirvBitFieldInsert *);
+  bool visit(SpirvAtomic *);
+  bool visit(SpirvCompositeConstruct *);
+  bool visit(SpirvCompositeExtract *);
+  bool visit(SpirvCompositeInsert *);
+  bool visit(SpirvLoad *);
+  bool visit(SpirvStore *);
+  bool visit(SpirvBinaryOp *);
+  bool visit(SpirvUnaryOp *);
+  bool visit(SpirvNonUniformBinaryOp *);
+  bool visit(SpirvNonUniformUnaryOp *);
+  bool visit(SpirvExtInst *);
+
+  // TODO: Support propagation of 'precise' through OpSpecConstantOp and image
+  // operations if necessary. Related instruction classes are:
+  // SpirvSpecConstantBinaryOp, SpirvSpecConstantUnaryOp
+  // SpirvImageOp, SpirvImageQuery, SpirvImageTexelPointer, SpirvSampledImage
+
+private:
+  bool curFnRetValPrecise; ///< Whether current function is 'precise'
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SPIRV_PRECISEVISITOR_H

+ 32 - 22
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -12,6 +12,7 @@
 #include "EmitVisitor.h"
 #include "LiteralTypeVisitor.h"
 #include "LowerTypeVisitor.h"
+#include "PreciseVisitor.h"
 #include "RelaxedPrecisionVisitor.h"
 #include "clang/SPIRV/AstTypeProbe.h"
 
@@ -25,11 +26,10 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
   module = new (context) SpirvModule;
 }
 
-SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
-                                           SpirvType *functionType,
-                                           SourceLocation loc,
-                                           llvm::StringRef funcName,
-                                           SpirvFunction *func) {
+SpirvFunction *
+SpirvBuilder::beginFunction(QualType returnType, SpirvType *functionType,
+                            SourceLocation loc, llvm::StringRef funcName,
+                            bool isPrecise, SpirvFunction *func) {
   assert(!function && "found nested function");
   if (func) {
     function = func;
@@ -37,20 +37,21 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
     function->setFunctionType(functionType);
     function->setSourceLocation(loc);
     function->setFunctionName(funcName);
+    function->setPrecise(isPrecise);
   } else {
     function = new (context)
-        SpirvFunction(returnType, functionType,
-                      spv::FunctionControlMask::MaskNone, loc, funcName);
+        SpirvFunction(returnType, functionType, loc, funcName, isPrecise);
   }
 
   return function;
 }
 
 SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
+                                                 bool isPrecise,
                                                  SourceLocation loc,
                                                  llvm::StringRef name) {
   assert(function && "found detached parameter");
-  auto *param = new (context) SpirvFunctionParameter(ptrType, loc);
+  auto *param = new (context) SpirvFunctionParameter(ptrType, isPrecise, loc);
   param->setStorageClass(spv::StorageClass::Function);
   param->setDebugName(name);
   function->addParameter(param);
@@ -58,11 +59,11 @@ SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
 }
 
 SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
-                                      llvm::StringRef name,
+                                      llvm::StringRef name, bool isPrecise,
                                       SpirvInstruction *init) {
   assert(function && "found detached local variable");
-  auto *var = new (context)
-      SpirvVariable(valueType, loc, spv::StorageClass::Function, init);
+  auto *var = new (context) SpirvVariable(
+      valueType, loc, spv::StorageClass::Function, isPrecise, init);
   var->setDebugName(name);
   function->addVariable(var);
   return var;
@@ -830,10 +831,10 @@ SpirvExtInstImport *SpirvBuilder::getGLSLExtInstSet(SourceLocation loc) {
 
 SpirvVariable *SpirvBuilder::addStageIOVar(QualType type,
                                            spv::StorageClass storageClass,
-                                           std::string name,
+                                           std::string name, bool isPrecise,
                                            SourceLocation loc) {
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  auto *var = new (context) SpirvVariable(type, loc, storageClass);
+  auto *var = new (context) SpirvVariable(type, loc, storageClass, isPrecise);
   var->setDebugName(name);
   module->addVariable(var);
   return var;
@@ -842,6 +843,7 @@ SpirvVariable *SpirvBuilder::addStageIOVar(QualType type,
 SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
                                                 spv::StorageClass storageClass,
                                                 spv::BuiltIn builtin,
+                                                bool isPrecise,
                                                 SourceLocation loc) {
   // If the built-in variable has already been added (via a built-in alias),
   // return the existing variable.
@@ -855,7 +857,7 @@ SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
   }
 
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  auto *var = new (context) SpirvVariable(type, loc, storageClass);
+  auto *var = new (context) SpirvVariable(type, loc, storageClass, isPrecise);
   module->addVariable(var);
 
   // Decorate with the specified Builtin
@@ -869,25 +871,29 @@ SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
   return var;
 }
 
-SpirvVariable *SpirvBuilder::addModuleVar(
-    QualType type, spv::StorageClass storageClass, llvm::StringRef name,
-    llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {
+SpirvVariable *
+SpirvBuilder::addModuleVar(QualType type, spv::StorageClass storageClass,
+                           bool isPrecise, llvm::StringRef name,
+                           llvm::Optional<SpirvInstruction *> init,
+                           SourceLocation loc) {
   assert(storageClass != spv::StorageClass::Function);
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  auto *var = new (context) SpirvVariable(
-      type, loc, storageClass, init.hasValue() ? init.getValue() : nullptr);
+  auto *var =
+      new (context) SpirvVariable(type, loc, storageClass, isPrecise,
+                                  init.hasValue() ? init.getValue() : nullptr);
   var->setDebugName(name);
   module->addVariable(var);
   return var;
 }
 
 SpirvVariable *SpirvBuilder::addModuleVar(
-    const SpirvType *type, spv::StorageClass storageClass, llvm::StringRef name,
-    llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {
+    const SpirvType *type, spv::StorageClass storageClass, bool isPrecise,
+    llvm::StringRef name, llvm::Optional<SpirvInstruction *> init,
+    SourceLocation loc) {
   assert(storageClass != spv::StorageClass::Function);
   // Note: We store the underlying type in the variable, *not* the pointer type.
   auto *var =
-      new (context) SpirvVariable(/*QualType*/ {}, loc, storageClass,
+      new (context) SpirvVariable(/*QualType*/ {}, loc, storageClass, isPrecise,
                                   init.hasValue() ? init.getValue() : nullptr);
   var->setResultType(type);
   var->setDebugName(name);
@@ -1058,6 +1064,7 @@ std::vector<uint32_t> SpirvBuilder::takeModule() {
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
   CapabilityVisitor capabilityVisitor(astContext, context, spirvOptions, *this);
   RelaxedPrecisionVisitor relaxedPrecisionVisitor(context, spirvOptions);
+  PreciseVisitor preciseVisitor(context, spirvOptions);
   EmitVisitor emitVisitor(astContext, context, spirvOptions);
 
   module->invokeVisitor(&literalTypeVisitor, true);
@@ -1071,6 +1078,9 @@ std::vector<uint32_t> SpirvBuilder::takeModule() {
   // Propagate RelaxedPrecision decorations
   module->invokeVisitor(&relaxedPrecisionVisitor);
 
+  // Propagate NoContraction decorations
+  module->invokeVisitor(&preciseVisitor, true);
+
   // Emit SPIR-V
   module->invokeVisitor(&emitVisitor);
 

+ 22 - 13
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1032,13 +1032,13 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
 
   auto *funcType = spvContext.getFunctionType(retType, paramTypes);
   spvBuilder.beginFunction(retType, funcType, decl->getLocation(), funcName,
-                           func);
+                           decl->hasAttr<HLSLPreciseAttr>(), func);
 
   if (isNonStaticMemberFn) {
-    // Remember the parameter for the this object so later we can handle
+    // Remember the parameter for the 'this' object so later we can handle
     // CXXThisExpr correctly.
-    curThis = spvBuilder.addFnParam(paramTypes[0], /*SourceLocation*/ {},
-                                    "param.this");
+    curThis = spvBuilder.addFnParam(paramTypes[0], /*isPrecise*/ false,
+                                    /*SourceLocation*/ {}, "param.this");
   }
 
   // Create all parameters.
@@ -2049,8 +2049,14 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
       const QualType varType =
           declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
       const std::string varName = "param.var." + param->getNameAsString();
-      auto *tempVar =
-          spvBuilder.addFnVar(varType, param->getLocation(), varName);
+      // Temporary "param.var.*" variables are used for OpFunctionCall purposes.
+      // 'precise' attribute on function parameters only affect computations
+      // inside the function, not the variables at the call sites. Therefore, we
+      // do not need to mark the "param.var.*" variables as precise.
+      const bool isPrecise = false;
+
+      auto *tempVar = spvBuilder.addFnVar(varType, param->getLocation(),
+                                          varName, isPrecise);
 
       vars.push_back(tempVar);
       isTempVar.push_back(true);
@@ -4804,7 +4810,7 @@ void SpirvEmitter::storeValue(SpirvInstruction *lhsPtr,
                               SpirvInstruction *rhsVal, QualType lhsValType) {
   // Defend against nullptr source or destination so errors can bubble up to the
   // user.
-  if(!lhsPtr || !rhsVal)
+  if (!lhsPtr || !rhsVal)
     return;
 
   if (const auto *refType = lhsValType->getAs<ReferenceType>())
@@ -5165,9 +5171,9 @@ void SpirvEmitter::initOnce(QualType varType, std::string varName,
   varName = "init.done." + varName;
 
   // Create a file/module visible variable to hold the initialization state.
-  SpirvVariable *initDoneVar =
-      spvBuilder.addModuleVar(astContext.BoolTy, spv::StorageClass::Private,
-                              varName, spvBuilder.getConstantBool(false));
+  SpirvVariable *initDoneVar = spvBuilder.addModuleVar(
+      astContext.BoolTy, spv::StorageClass::Private, /*isPrecise*/ false,
+      varName, spvBuilder.getConstantBool(false));
 
   auto *condition = spvBuilder.createLoad(astContext.BoolTy, initDoneVar);
 
@@ -9605,7 +9611,8 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
     const auto paramType = param->getType();
     std::string tempVarName = "param.var." + param->getNameAsString();
     auto *tempVar =
-        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName);
+        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName,
+                            param->hasAttr<HLSLPreciseAttr>());
 
     SpirvVariable *curStageVar = nullptr;
 
@@ -9802,7 +9809,8 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     const auto paramType = param->getType();
     std::string tempVarName = "param.var." + param->getNameAsString();
     auto *tempVar =
-        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName);
+        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName,
+                            param->hasAttr<HLSLPreciseAttr>());
 
     params.push_back(tempVar);
 
@@ -9982,7 +9990,8 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
         const QualType type = param->getType();
         std::string tempVarName = "param.var." + param->getNameAsString();
         auto *tempVar =
-            spvBuilder.addFnVar(type, param->getLocation(), tempVarName);
+            spvBuilder.addFnVar(type, param->getLocation(), tempVarName,
+                                param->hasAttr<HLSLPreciseAttr>());
         SpirvInstruction *loadedValue = nullptr;
         declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
         spvBuilder.createStore(tempVar, loadedValue);

+ 4 - 4
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -16,11 +16,11 @@ namespace clang {
 namespace spirv {
 
 SpirvFunction::SpirvFunction(QualType returnType, SpirvType *functionType,
-                             spv::FunctionControlMask control,
-                             SourceLocation loc, llvm::StringRef name)
+                             SourceLocation loc, llvm::StringRef name,
+                             bool isPrecise)
     : functionId(0), astReturnType(returnType), returnType(nullptr),
-      fnType(functionType), relaxedPrecision(false), containsAlias(false),
-      rvalue(false), functionControl(control), functionLoc(loc),
+      fnType(functionType), relaxedPrecision(false), precise(isPrecise),
+      containsAlias(false), rvalue(false), functionLoc(loc),
       functionName(name) {}
 
 bool SpirvFunction::invokeVisitor(Visitor *visitor, bool reverseOrder) {

+ 42 - 3
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -91,7 +91,42 @@ SpirvInstruction::SpirvInstruction(Kind k, spv::Op op, QualType astType,
       debugName(), resultType(nullptr), resultTypeId(0),
       layoutRule(SpirvLayoutRule::Void), containsAlias(false),
       storageClass(spv::StorageClass::Function), isRValue_(false),
-      isRelaxedPrecision_(false), isNonUniform_(false) {}
+      isRelaxedPrecision_(false), isNonUniform_(false), isPrecise_(false) {}
+
+bool SpirvInstruction::isArithmeticInstruction() const {
+  switch (opcode) {
+  case spv::Op::OpSNegate:
+  case spv::Op::OpFNegate:
+  case spv::Op::OpIAdd:
+  case spv::Op::OpFAdd:
+  case spv::Op::OpISub:
+  case spv::Op::OpFSub:
+  case spv::Op::OpIMul:
+  case spv::Op::OpFMul:
+  case spv::Op::OpUDiv:
+  case spv::Op::OpSDiv:
+  case spv::Op::OpFDiv:
+  case spv::Op::OpUMod:
+  case spv::Op::OpSRem:
+  case spv::Op::OpSMod:
+  case spv::Op::OpFRem:
+  case spv::Op::OpFMod:
+  case spv::Op::OpVectorTimesScalar:
+  case spv::Op::OpMatrixTimesScalar:
+  case spv::Op::OpVectorTimesMatrix:
+  case spv::Op::OpMatrixTimesVector:
+  case spv::Op::OpMatrixTimesMatrix:
+  case spv::Op::OpOuterProduct:
+  case spv::Op::OpDot:
+  case spv::Op::OpIAddCarry:
+  case spv::Op::OpISubBorrow:
+  case spv::Op::OpUMulExtended:
+  case spv::Op::OpSMulExtended:
+    return true;
+  default:
+    return false;
+  }
+}
 
 SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(), loc),
@@ -198,17 +233,21 @@ spv::Op SpirvDecoration::getDecorateOpcode(
 }
 
 SpirvVariable::SpirvVariable(QualType resultType, SourceLocation loc,
-                             spv::StorageClass sc,
+                             spv::StorageClass sc, bool precise,
                              SpirvInstruction *initializerInst)
     : SpirvInstruction(IK_Variable, spv::Op::OpVariable, resultType, loc),
       initializer(initializerInst) {
   setStorageClass(sc);
+  setPrecise(precise);
 }
 
 SpirvFunctionParameter::SpirvFunctionParameter(QualType resultType,
+                                               bool isPrecise,
                                                SourceLocation loc)
     : SpirvInstruction(IK_FunctionParameter, spv::Op::OpFunctionParameter,
-                       resultType, loc) {}
+                       resultType, loc) {
+  setPrecise(isPrecise);
+}
 
 SpirvMerge::SpirvMerge(Kind kind, spv::Op op, SourceLocation loc,
                        SpirvBasicBlock *mergeLabel)

+ 3 - 1
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -215,7 +215,9 @@ operator==(const StructType::FieldInfo &that) const {
          (!isRowMajor.hasValue() ||
           isRowMajor.getValue() == that.isRowMajor.getValue()) &&
          // Both should have the same precision
-         isRelaxedPrecision == that.isRelaxedPrecision;
+         isRelaxedPrecision == that.isRelaxedPrecision &&
+         // Both fields should be precise or not precise
+         isPrecise == that.isPrecise;
 }
 
 bool StructType::operator==(const StructType &that) const {

+ 110 - 0
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.hlsl

@@ -0,0 +1,110 @@
+// Run: %dxc -T ps_6_0 -E main
+
+float func(float e, float f, float g, float h);
+float func2(float e, float f, float g, float h);
+precise float func3(float e, float f, float g, float h);
+float func4(float i, float j, precise out float k);
+
+// The purpose of this to make sure the first NoContraction decoration is on a_mul_b.
+// CHECK:      OpName %bb_entry_3
+// CHECK-NEXT: OpDecorate [[a_mul_b:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[c_mul_d:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[r_plus_s:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[aw_mul_bw:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[cw_mul_dw:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[awbw_plus_cwdw:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func2_e_mul_f:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func2_g_mul_h:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func2_ef_plus_gh:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func3_e_mul_f:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func3_g_mul_h:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func3_ef_plus_gh:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func4_i_mul_i:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[func4_ii_plus_j:%\d+]] NoContraction
+
+void main() {
+  float4 a, b, c, d;
+  precise float4 v; 
+
+// CHECK:      [[a_mul_b]] = OpFMul %v3float {{%\d+}} {{%\d+}}
+// CHECK-NEXT:               OpStore %r [[a_mul_b]]
+  float3 r = float3((float3)a * (float3)b); // precise, used to compute v.xyz
+
+
+// CHECK:      [[c_mul_d]] = OpFMul %v3float {{%\d+}} {{%\d+}}
+// CHECK-NEXT:               OpStore %s [[c_mul_d]]
+  float3 s = float3((float3)c * (float3)d); // precise, used to compute v.xyz
+
+// CHECK:                     OpLoad %v3float %r
+// CHECK-NEXT:                OpLoad %v3float %s
+// CHECK-NEXT: [[r_plus_s]] = OpFAdd %v3float {{%\d+}} {{%\d+}}
+  v.xyz = r + s; // precise
+  
+// CHECK:                           OpAccessChain %_ptr_Function_float %a %int_3
+// CHECK-NEXT:                      OpLoad %float
+// CHECK-NEXT:                      OpAccessChain %_ptr_Function_float %b %int_3
+// CHECK-NEXT:                      OpLoad %float
+// CHECK-NEXT:      [[aw_mul_bw]] = OpFMul %float
+// CHECK-NEXT:                      OpAccessChain %_ptr_Function_float %c %int_3
+// CHECK-NEXT:                      OpLoad %float
+// CHECK-NEXT:                      OpAccessChain %_ptr_Function_float %d %int_3
+// CHECK-NEXT:                      OpLoad %float
+// CHECK-NEXT:      [[cw_mul_dw]] = OpFMul %float
+// CHECK-NEXT: [[awbw_plus_cwdw]] = OpFAdd %float
+  v.w = (a.w * b.w) + (c.w * d.w);  // precise
+
+
+  v.x = func(a.x, b.x, c.x, d.x);   // values computed in func() are NOT precise
+  
+  // Even though v.x is precise, values computed inside func2 are not forced to
+  // be precise. Meaning, precise-ness does not cross function boundary.
+  v.x = func2(a.x, b.x, c.x, d.x);
+  
+  // Even though v.x is precise, values computed inside func4 are not forced to
+  // be precise. Meaning, precise-ness does not cross function boundary.
+  v.x = func3(a.x, b.x, c.x, d.x);
+  
+  func4(a.x * b.x, c.x * d.x, v.x);
+}
+
+float func(float e, float f, float g, float h) {
+  return (e*f) + (g*h); // no constraint on order or operator consistency
+}
+
+// CHECK: %func2 = OpFunction %float
+float func2(float e, float f, float g, float h) {
+// CHECK:                             OpLoad %float %e_0
+// CHECK-NEXT:                        OpLoad %float %f_0
+// CHECK-NEXT:    [[func2_e_mul_f]] = OpFMul %float
+// CHECK-NEXT:                        OpLoad %float %g_0
+// CHECK-NEXT:                        OpLoad %float %h_0
+// CHECK-NEXT:    [[func2_g_mul_h]] = OpFMul %float
+// CHECK-NEXT: [[func2_ef_plus_gh]] = OpFAdd %float
+  precise float result = (e*f) + (g*h); // ensures same precision for the two multiplies
+  return result;
+}
+
+// CHECK: %func3 = OpFunction %float
+precise float func3(float e, float f, float g, float h) {
+// CHECK:                             OpLoad %float %e_1
+// CHECK-NEXT:                        OpLoad %float %f_1
+// CHECK-NEXT:    [[func3_e_mul_f]] = OpFMul %float
+// CHECK-NEXT:                        OpLoad %float %g_1
+// CHECK-NEXT:                        OpLoad %float %h_1
+// CHECK-NEXT:    [[func3_g_mul_h]] = OpFMul %float
+// CHECK-NEXT: [[func3_ef_plus_gh]] = OpFAdd %float %162 %165
+  float result = (e*f) + (g*h); // precise because it's the function return value.
+  return result;
+}
+
+// CHECK: %func4 = OpFunction %float
+float func4(float i, float j, precise out float k) {
+// CEHCK:                            OpLoad %float %i
+// CEHCK-NEXT:                       OpLoad %float %i
+// CEHCK-NEXT:   [[func4_i_mul_i]] = OpFMul %float
+// CEHCK-NEXT:                       OpLoad %float %j
+// CEHCK-NEXT: [[func4_ii_plus_j]] = OpFAdd %float
+  k = i * i + j; // precise, due to <k> declaration
+  return 1.0;
+}
+

+ 116 - 0
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.stage-vars.hlsl

@@ -0,0 +1,116 @@
+// Run: %dxc -T vs_6_0 -E main -fspv-reflect
+
+// CHECK:      OpDecorate [[aa_1:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[aa_plus_b_1:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[aa_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[aa_plus_b_2:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[ee:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ee_plus_f:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[cc_1:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[cc_plus_d_1:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[cc_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[cc_plus_d_2:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[cxcy_1:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[cxcy_plus_dz_1:%\d+]] NoContraction
+
+// CHECK-NOT: OpDecorate [[cxcy_2]] NoContraction
+// CHECK-NOT: OpDecorate [[cxcy_plus_dz_2]] NoContraction
+
+// CHECK-NOT: OpDecorate [[cxcy_3]] NoContraction
+// CHECK-NOT: OpDecorate [[cxcy_plus_dz_3]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[aa_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[aa_plus_b_3:%\d+]] NoContraction
+
+struct InnerInnerStruct {
+  precise float4   position : SV_Position;      // -> BuiltIn Position in gl_Pervertex
+};
+
+struct InnerStruct {
+  float2           clipdis1 : SV_ClipDistance1; // -> BuiltIn ClipDistance in gl_PerVertex
+  InnerInnerStruct s;
+};
+
+struct VSOut {
+  float4           color    : COLOR;            // -> Output variable
+  InnerStruct s;
+};
+
+[[vk::builtin("PointSize")]]
+float main(out VSOut  vsOut,
+           out   precise float4 coord    : TEXCOORD,         // -> Input & output variable
+           out   precise float3 clipdis0 : SV_ClipDistance0, // -> BuiltIn ClipDistance in gl_PerVertex
+           out   precise float  culldis5 : SV_CullDistance5, // -> BuiltIn CullDistance in gl_PerVertex
+           out           float  culldis3 : SV_CullDistance3, // -> BuiltIn CullDistance in gl_PerVertex
+           out           float  clipdis6 : SV_ClipDistance6, // -> BuiltIn ClipDistance in gl_PerVertex
+           in    precise float4 inPos    : SV_Position,      // -> Input variable
+           in    precise float2 inClip   : SV_ClipDistance,  // -> Input variable
+           in    precise float3 inCull   : SV_CullDistance0  // -> Input variable
+         ) : PSize {
+  vsOut    = (VSOut)0;
+  float4 a, b;
+  float3 c, d;
+  float2 e, f;
+    
+// Output variable. coord is precise.
+//
+// CHECK:        [[aa_1]] = OpFMul %v4float
+// CHECK: [[aa_plus_b_1]] = OpFAdd %v4float
+  coord = a * a + b;
+
+// Input variable for position is precise.
+//
+// CHECK:        [[aa_2]] = OpFMul %v4float
+// CHECK: [[aa_plus_b_2]] = OpFAdd %v4float
+  inPos = a * a + b;
+
+// Input ClipDistance variable. inClip is precise.
+//
+// CHECK:        [[ee]] = OpFMul %v2float
+// CHECK: [[ee_plus_f]] = OpFAdd %v2float
+  inClip = e * e + f;
+  
+// Input CullDistance variable. inCull is precise.
+//
+// CHECK:        [[cc_1]] = OpFMul %v3float
+// CHECK: [[cc_plus_d_1]] = OpFAdd %v3float
+  inCull = c * c + d;
+  
+// Output ClipDistance builtin. clipdis0 is precise.
+//
+// CHECK:        [[cc_2]] = OpFMul %v3float
+// CHECK: [[cc_plus_d_2]] = OpFAdd %v3float
+  clipdis0 = c * c + d;
+  
+// Output CullDistance builtin. culldis5 is precise.
+//
+// CHECK:         [[cxcy_1]] = OpFMul %float
+// CHECK: [[cxcy_plus_dz_1]] = OpFAdd %float
+  culldis5 = c.x * c.y + d.z;
+  
+// Output CullDistance builtin. culldis3 is NOT precise.
+//
+// CHECK:         [[cxcy_2:%\d+]] = OpFMul %float
+// CHECK: [[cxcy_plus_dz_2:%\d+]] = OpFAdd %float
+  culldis3 = c.x * c.y + d.z;
+  
+// Output CullDistance builtin. clipdis6 is NOT precise.
+//
+// CHECK:         [[cxcy_3:%\d+]] = OpFMul %float
+// CHECK: [[cxcy_plus_dz_3:%\d+]] = OpFAdd %float
+  clipdis6 = c.x * c.y + d.z;
+  
+// Position builtin is precise.
+//
+// CHECK:        [[aa_3]] = OpFMul %v4float
+// CHECK: [[aa_plus_b_3]] = OpFAdd %v4float
+  vsOut.s.s.position = a * a + b;
+  
+  return inPos.x + inClip.x + inCull.x;
+}
+

+ 174 - 0
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.struct.hlsl

@@ -0,0 +1,174 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S {
+            float4 a;
+    precise float4 ap;
+               int b[5];
+       precise int bp[5];
+            int2x3 c[6][7][8];
+    precise int2x3 cp[6][7][8];
+          float3x4 d;
+  precise float3x4 dp;
+};
+
+struct T {
+  precise S sub1; // all members of sub1 should be precise.
+  S sub2; // only some members of sub2 are precise.
+};
+
+
+// CHECK:      OpName %w "w"
+// CHECK-NOT:  OpDecorate [[x_mul_x_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[xx_plus_y_1]] NoContraction
+// CHECK-NEXT: OpDecorate [[x_mul_x_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[xx_plus_y_2:%\d+]] NoContraction
+
+// CHECK-NOT:  OpDecorate [[z2_mul_z3_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[z2z3_plus_z4_1]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2_mul_z3_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2z3_plus_z4_2:%\d+]] NoContraction
+
+// CHECK-NOT:  OpDecorate [[uu_row0_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[uu_row1_1]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row0_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row1_2:%\d+]] NoContraction
+
+// CHECK-NOT:  OpDecorate [[ww_row0_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[ww_row1_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[ww_row2_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[ww_plus_w_row0_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[ww_plus_w_row1_1]] NoContraction
+// CHECK-NOT:  OpDecorate [[ww_plus_w_row2_1]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row0_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row1_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row2_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row0_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row1_2:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row2_2:%\d+]] NoContraction
+
+// CHECK-NEXT: OpDecorate [[x_mul_x_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[xx_plus_y_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[x_mul_x_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[xx_plus_y_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2_mul_z3_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2z3_plus_z4_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2_mul_z3_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[z2z3_plus_z4_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row0_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row1_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row0_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[uu_row1_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row0_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row1_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row2_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row0_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row1_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row2_3:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row0_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row1_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_row2_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row0_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row1_4:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[ww_plus_w_row2_4:%\d+]] NoContraction
+
+void main() {
+  T t;
+  float4 x,y;
+  int z[5];
+  int2x3 u;
+  float3x4 w;
+
+
+// 'a' is NOT precise.
+//
+// CHECK:   [[x_mul_x_1:%\d+]] = OpFMul %v4float
+// CHECK: [[xx_plus_y_1:%\d+]] = OpFAdd %v4float
+  t.sub2.a = x * x + y;
+
+// 'ap' is precise.
+//
+// CHECK:   [[x_mul_x_2]] = OpFMul %v4float
+// CHECK: [[xx_plus_y_2]] = OpFAdd %v4float
+  t.sub2.ap = x * x + y;
+
+// 'b' is NOT precise.
+//
+// CHECK:    [[z2_mul_z3_1:%\d+]] = OpIMul %int
+// CHECK: [[z2z3_plus_z4_1:%\d+]] = OpIAdd %int
+  t.sub2.b[1] = z[2] * z[3] + z[4];
+
+// 'bp' is precise.
+//
+// CHECK:    [[z2_mul_z3_2]] = OpIMul %int
+// CHECK: [[z2z3_plus_z4_2]] = OpIAdd %int
+  t.sub2.bp[1] = z[2] * z[3] + z[4];
+
+// 'c' is NOT precise.
+//
+// CHECK: [[uu_row0_1:%\d+]] = OpIMul %v3int
+// CHECK: [[uu_row1_1:%\d+]] = OpIMul %v3int
+  t.sub2.c[0][1][2] = u * u;
+
+// 'cp' is precise.
+//
+// CHECK: [[uu_row0_2]] = OpIMul %v3int
+// CHECK: [[uu_row1_2]] = OpIMul %v3int
+  t.sub2.cp[0][1][2] = u * u;
+
+// 'd' is NOT precise.
+//
+// CHECK:        [[ww_row0_1:%\d+]] = OpFMul %v4float
+// CHECK:        [[ww_row1_1:%\d+]] = OpFMul %v4float
+// CHECK:        [[ww_row2_1:%\d+]] = OpFMul %v4float
+// CHECK: [[ww_plus_w_row0_1:%\d+]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row1_1:%\d+]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row2_1:%\d+]] = OpFAdd %v4float
+  t.sub2.d = w * w + w;
+
+// 'dp' is precise.
+//
+// CHECK:        [[ww_row0_2]] = OpFMul %v4float
+// CHECK:        [[ww_row1_2]] = OpFMul %v4float
+// CHECK:        [[ww_row2_2]] = OpFMul %v4float
+// CHECK: [[ww_plus_w_row0_2]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row1_2]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row2_2]] = OpFAdd %v4float
+  t.sub2.dp = w * w + w;
+
+// *ALL* members of sub1 are precise. So this operation should be precise.
+//
+//
+// CHECK:   [[x_mul_x_3]] = OpFMul %v4float
+// CHECK: [[xx_plus_y_3]] = OpFAdd %v4float
+  t.sub1.a = x * x + y;
+// CHECK:   [[x_mul_x_4]] = OpFMul %v4float
+// CHECK: [[xx_plus_y_4]] = OpFAdd %v4float
+  t.sub1.ap = x * x + y;
+// CHECK:    [[z2_mul_z3_3]] = OpIMul %int
+// CHECK: [[z2z3_plus_z4_3]] = OpIAdd %int
+  t.sub1.b[1] = z[2] * z[3] + z[4];
+// CHECK:    [[z2_mul_z3_4]] = OpIMul %int
+// CHECK: [[z2z3_plus_z4_4]] = OpIAdd %int
+  t.sub1.bp[1] = z[2] * z[3] + z[4];
+// CHECK: [[uu_row0_3]] = OpIMul %v3int
+// CHECK: [[uu_row1_3]] = OpIMul %v3int
+  t.sub1.c[0][1][2] = u * u;
+// CHECK: [[uu_row0_4]] = OpIMul %v3int
+// CHECK: [[uu_row1_4]] = OpIMul %v3int
+  t.sub1.cp[0][1][2] = u * u;
+// CHECK:        [[ww_row0_3]] = OpFMul %v4float
+// CHECK:        [[ww_row1_3]] = OpFMul %v4float
+// CHECK:        [[ww_row2_3]] = OpFMul %v4float
+// CHECK: [[ww_plus_w_row0_3]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row1_3]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row2_3]] = OpFAdd %v4float
+  t.sub1.d = w * w + w;
+// CHECK:        [[ww_row0_4]] = OpFMul %v4float
+// CHECK:        [[ww_row1_4]] = OpFMul %v4float
+// CHECK:        [[ww_row2_4]] = OpFMul %v4float
+// CHECK: [[ww_plus_w_row0_4]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row1_4]] = OpFAdd %v4float
+// CHECK: [[ww_plus_w_row2_4]] = OpFAdd %v4float
+  t.sub1.dp = w * w + w;
+}
+

+ 91 - 0
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.variable-reuse.hlsl

@@ -0,0 +1,91 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// The purpose of this test is to make sure non-precise computations are *not*
+// decorated with NoContraction.
+//
+// To this end, we will perform the same computation twice, once when it
+// affects a precise variable, and once when it doesn't.
+
+void foo(float p) { p = p + 1; }
+
+// CHECK:      OpName %bb_entry_0 "bb.entry"
+// CHECK-NEXT: OpDecorate [[first_b_plus_c:%\d+]] NoContraction
+// CHECK-NOT:  OpDecorate [[first_a_mul_b]] NoContraction
+// CHECK-NOT:  OpDecorate [[ax_mul_bx]] NoContraction
+// CHECK-NEXT: OpDecorate [[second_a_mul_b:%\d+]] NoContraction
+// CHECK-NOT:  OpDecorate [[second_a_plus_b]] NoContraction
+// CHECK-NEXT: OpDecorate [[first_d_plus_e:%\d+]] NoContraction
+// CHECK-NEXT: OpDecorate [[c_mul_d:%\d+]] NoContraction
+// CHECK-NOT:  OpDecorate [[second_d_plus_e]] NoContraction
+// CHECK-NEXT: OpDecorate [[r_plus_s:%\d+]] NoContraction
+
+void main() {
+  float4 a, b, c, d, e;
+  precise float4 v; 
+  float3 r, s, u;
+
+// This can change "a" which can then change "r" which can then change "v". Precise.
+//
+// CHECK:                           OpLoad %v4float %b
+// CHECK-NEXT:                      OpLoad %v4float %c
+// CHECK-NEXT: [[first_b_plus_c]] = OpFAdd %v4float
+// CHECK-NEXT:                      OpStore %a %29
+  a = b + c;
+  
+// Even though this looks like the statement on line 52:
+// This changes "u", which does not affect "v" in any way. Not Precise.
+//
+// CHECK:      [[first_a_mul_b:%\d+]] = OpFMul %v3float
+// CHECK-NEXT:                          OpStore %u
+  u = float3((float3)a * (float3)b);
+  
+// Does not affect the value of "v". Not Precise.
+//
+// CHECK:      [[ax_mul_bx:%\d+]] = OpFMul %float
+// CHECK-NEXT:                      OpStore %param_var_p
+  foo(a.x * b.x);
+
+// This changes "r" which will later change "v". Precise.
+//
+// CHECK:      [[second_a_mul_b]] = OpFMul %v3float
+// CHECK-NEXT:                      OpStore %r %58
+  r = float3((float3)a * (float3)b);
+
+// Even though this looks identical to "a = b + c" above:
+// This can change the value of "a", BUT, this change will not affect "v". Not Precise.
+//
+// CHECK:      [[second_a_plus_b:%\d+]] = OpFAdd %v4float
+// CHECK-NEXT:                            OpStore %a %61
+  a = b + c;
+
+// This can change "c" which can then change "s" which can then change "v". Precise.
+//
+// CHECK:                           OpLoad %v4float %d
+// CHECK-NEXT:                      OpLoad %v4float %e
+// CHECK-NEXT: [[first_d_plus_e]] = OpFAdd %v4float
+// CHECK-NEXT:                      OpStore %c
+  c = d + e;
+
+// This changes "s" which will later change "v". Precise.
+//
+// CHECK:      [[c_mul_d]] = OpFMul %v3float
+// CHECK-NEXT:               OpStore %s %75
+  s = float3((float3)c * (float3)d);
+
+// Even though this looks identical to "c = d + e" above:
+// This can change the value of "c", BUT, this change will not affect "v". Not Precise.
+//
+// CHECK:                                 OpLoad %v4float %d
+// CHECK-NEXT:                            OpLoad %v4float %e
+// CHECK-NEXT: [[second_d_plus_e:%\d+]] = OpFAdd %v4float
+// CHECK-NEXT:                            OpStore %c
+  c = d + e;
+
+// Precise because "v" is precise.
+// CHECK:                OpLoad %v3float %r
+// CHECK:                OpLoad %v3float %s
+// CHECK: [[r_plus_s]] = OpFAdd %v3float
+// CHECK:                OpStore %v
+  v.xyz = r + s;
+}
+

+ 15 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1834,6 +1834,7 @@ TEST_F(FileTest, RayTracingNVLibrary) {
   runFileTest("raytracing.nv.library.hlsl");
 }
 
+// For RelaxedPrecision decorations
 TEST_F(FileTest, DecorationRelaxedPrecisionBasic) {
   runFileTest("decoration.relaxed-precision.basic.hlsl");
 }
@@ -1844,4 +1845,18 @@ TEST_F(FileTest, DecorationRelaxedPrecisionImage) {
   runFileTest("decoration.relaxed-precision.image.hlsl");
 }
 
+// For NoContraction decorations
+TEST_F(FileTest, DecorationNoContraction) {
+  runFileTest("decoration.no-contraction.hlsl");
+}
+TEST_F(FileTest, DecorationNoContractionVariableReuse) {
+  runFileTest("decoration.no-contraction.variable-reuse.hlsl");
+}
+TEST_F(FileTest, DecorationNoContractionStruct) {
+  runFileTest("decoration.no-contraction.struct.hlsl");
+}
+TEST_F(FileTest, DecorationNoContractionStageVars) {
+  runFileTest("decoration.no-contraction.stage-vars.hlsl");
+}
+
 } // namespace

+ 11 - 11
tools/clang/unittests/SPIRV/SpirvTypeTest.cpp

@@ -33,8 +33,8 @@ TEST_F(SpirvTypeTest, IntType) {
   EXPECT_TRUE(llvm::isa<IntegerType>(uint32));
   EXPECT_TRUE(llvm::isa<NumericalType>(sint16));
   EXPECT_TRUE(llvm::isa<NumericalType>(uint32));
-  EXPECT_EQ(16, sint16.getBitwidth());
-  EXPECT_EQ(32, uint32.getBitwidth());
+  EXPECT_EQ(16u, sint16.getBitwidth());
+  EXPECT_EQ(32u, uint32.getBitwidth());
   EXPECT_EQ(true, sint16.isSignedInt());
   EXPECT_EQ(false, uint32.isSignedInt());
 }
@@ -43,7 +43,7 @@ TEST_F(SpirvTypeTest, FloatType) {
   FloatType f16(16);
   EXPECT_TRUE(llvm::isa<FloatType>(f16));
   EXPECT_TRUE(llvm::isa<NumericalType>(f16));
-  EXPECT_EQ(16, f16.getBitwidth());
+  EXPECT_EQ(16u, f16.getBitwidth());
 }
 
 TEST_F(SpirvTypeTest, VectorType) {
@@ -51,7 +51,7 @@ TEST_F(SpirvTypeTest, VectorType) {
   VectorType float3(&f16, 3);
   EXPECT_TRUE(llvm::isa<VectorType>(float3));
   EXPECT_EQ(&f16, float3.getElementType());
-  EXPECT_EQ(3, float3.getElementCount());
+  EXPECT_EQ(3u, float3.getElementCount());
 }
 
 TEST_F(SpirvTypeTest, MatrixType) {
@@ -61,9 +61,9 @@ TEST_F(SpirvTypeTest, MatrixType) {
 
   EXPECT_TRUE(llvm::isa<MatrixType>(mat2x3));
   EXPECT_EQ(&f16, float3.getElementType());
-  EXPECT_EQ(2, mat2x3.getVecCount());
-  EXPECT_EQ(2, mat2x3.numCols());
-  EXPECT_EQ(3, mat2x3.numRows());
+  EXPECT_EQ(2u, mat2x3.getVecCount());
+  EXPECT_EQ(2u, mat2x3.numCols());
+  EXPECT_EQ(3u, mat2x3.numRows());
 }
 
 TEST_F(SpirvTypeTest, ImageType) {
@@ -108,9 +108,9 @@ TEST_F(SpirvTypeTest, ArrayType) {
   ArrayType arr5(&f16, 5, 2);
   EXPECT_TRUE(llvm::isa<ArrayType>(arr5));
   EXPECT_EQ(arr5.getElementType(), &f16);
-  EXPECT_EQ(arr5.getElementCount(), 5);
+  EXPECT_EQ(arr5.getElementCount(), 5u);
   EXPECT_TRUE(arr5.getStride().hasValue());
-  EXPECT_EQ(arr5.getStride().getValue(), 2);
+  EXPECT_EQ(arr5.getStride().getValue(), 2u);
 }
 
 TEST_F(SpirvTypeTest, RuntimeArrayType) {
@@ -119,7 +119,7 @@ TEST_F(SpirvTypeTest, RuntimeArrayType) {
   EXPECT_TRUE(llvm::isa<RuntimeArrayType>(ra));
   EXPECT_EQ(ra.getElementType(), &f16);
   EXPECT_TRUE(ra.getStride().hasValue());
-  EXPECT_EQ(ra.getStride().getValue(), 2);
+  EXPECT_EQ(ra.getStride().getValue(), 2u);
 }
 
 TEST_F(SpirvTypeTest, StructType) {
@@ -138,7 +138,7 @@ TEST_F(SpirvTypeTest, StructType) {
   EXPECT_EQ(s.getStructName(), "some_struct");
 
   const auto &fields = s.getFields();
-  EXPECT_EQ(2, fields.size());
+  EXPECT_EQ(2u, fields.size());
   EXPECT_EQ(fields[0], field0);
   EXPECT_EQ(fields[1], field1);
   EXPECT_TRUE(s.isReadOnly());