Sfoglia il codice sorgente

Merge remote-tracking branch 'origin/master' into fix-combine-dim

Ehsan Nasiri 7 anni fa
parent
commit
2d530bd3ea
33 ha cambiato i file con 451 aggiunte e 51 eliminazioni
  1. 1 1
      external/SPIRV-Headers
  2. 1 1
      external/SPIRV-Tools
  3. 1 1
      external/googletest
  4. 2 0
      include/dxc/HLSL/DxilUtil.h
  5. 4 3
      include/dxc/HLSL/HLModule.h
  6. 2 1
      include/dxc/Support/HLSLOptions.h
  7. 1 0
      include/dxc/Support/WinAdapter.h
  8. 2 0
      include/dxc/Support/WinFunctions.h
  9. 7 3
      lib/DxcSupport/HLSLOptions.cpp
  10. 17 0
      lib/DxcSupport/WinFunctions.cpp
  11. 1 0
      lib/HLSL/DxilTypeSystem.cpp
  12. 13 0
      lib/HLSL/DxilUtil.cpp
  13. 124 10
      lib/HLSL/HLOperationLower.cpp
  14. 1 1
      lib/HLSL/HLSignatureLower.cpp
  15. 2 1
      tools/clang/include/clang/Basic/LangOptions.h
  16. 6 5
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  17. 1 1
      tools/clang/lib/Parse/ParseDecl.cpp
  18. 43 15
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  19. 3 0
      tools/clang/lib/SPIRV/SPIRVEmitter.h
  20. 20 0
      tools/clang/test/CodeGenHLSL/quick-test/lit-function.hlsl
  21. 35 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-check-count01.hlsl
  22. 18 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-check-count02.hlsl
  23. 38 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-correctness.hlsl
  24. 27 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-criteria01.hlsl
  25. 17 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-criteria02.hlsl
  26. 14 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-lit-types.hlsl
  27. 11 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-one-as-power.hlsl
  28. 11 0
      tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-zero-as-power.hlsl
  29. 15 2
      tools/clang/test/CodeGenSPIRV/cf.switch.opswitch.hlsl
  30. 5 2
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-count-bits.hlsl
  31. 3 1
      tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-count-bits.hlsl
  32. 3 2
      tools/clang/tools/dxcompiler/dxcompilerobj.cpp
  33. 2 1
      tools/clang/tools/libclang/dxcrewriteunused.cpp

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit dcf23bdabacc3c54b83b1f9367e7a8adb27f8d87
+Subproject commit d5b2e1255f706ce1f88812217e9a554f299848af

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit 9fbcce4ca17de7b2d8f6b322bcd1d43a7d6adc29
+Subproject commit 4b4bd4c53aaa020f7e349aede394d42476b7e3aa

+ 1 - 1
external/googletest

@@ -1 +1 @@
-Subproject commit d25268a55f6f6f38c65a7d1b7b119e33a46d1688
+Subproject commit 440527a61e1c91188195f7de212c63c77e8f0a45

+ 2 - 0
include/dxc/HLSL/DxilUtil.h

@@ -14,6 +14,7 @@
 #include <string>
 #include <string>
 #include <memory>
 #include <memory>
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Constants.h"
 
 
 namespace llvm {
 namespace llvm {
 class Type;
 class Type;
@@ -92,6 +93,7 @@ namespace dxilutil {
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
   // Returns true if type contains HLSL Object type (resource)
   // Returns true if type contains HLSL Object type (resource)
   bool ContainsHLSLObjectType(llvm::Type *Ty);
   bool ContainsHLSLObjectType(llvm::Type *Ty);
+  bool IsSplat(llvm::ConstantDataVector *cdv);
 }
 }
 
 
 }
 }

+ 4 - 3
include/dxc/HLSL/HLModule.h

@@ -49,7 +49,7 @@ class RootSignatureHandle;
 struct HLOptions {
 struct HLOptions {
   HLOptions()
   HLOptions()
       : bDefaultRowMajor(false), bIEEEStrict(false), bDisableOptimizations(false),
       : bDefaultRowMajor(false), bIEEEStrict(false), bDisableOptimizations(false),
-        bLegacyCBufferLoad(false), PackingStrategy(0), bBackCompatMode(0), unused(0) {
+        bLegacyCBufferLoad(false), PackingStrategy(0), bDX9CompatMode(0), bFXCCompatMode(0), unused(0) {
   }
   }
   uint32_t GetHLOptionsRaw() const;
   uint32_t GetHLOptionsRaw() const;
   void SetHLOptionsRaw(uint32_t data);
   void SetHLOptionsRaw(uint32_t data);
@@ -61,8 +61,9 @@ struct HLOptions {
   unsigned PackingStrategy         : 2;
   unsigned PackingStrategy         : 2;
   static_assert((unsigned)DXIL::PackingStrategy::Invalid < 4, "otherwise 2 bits is not enough to store PackingStrategy");
   static_assert((unsigned)DXIL::PackingStrategy::Invalid < 4, "otherwise 2 bits is not enough to store PackingStrategy");
   unsigned bUseMinPrecision        : 1;
   unsigned bUseMinPrecision        : 1;
-  unsigned bBackCompatMode         : 1;
-  unsigned unused                  : 23;
+  unsigned bDX9CompatMode          : 1;
+  unsigned bFXCCompatMode          : 1;
+  unsigned unused                  : 22;
 };
 };
 
 
 typedef std::unordered_map<const llvm::Function *, std::unique_ptr<DxilFunctionProps>> DxilFunctionPropsMap;
 typedef std::unordered_map<const llvm::Function *, std::unique_ptr<DxilFunctionProps>> DxilFunctionPropsMap;

+ 2 - 1
include/dxc/Support/HLSLOptions.h

@@ -132,7 +132,8 @@ public:
   bool AvoidFlowControl = false;     // OPT_Gfa
   bool AvoidFlowControl = false;     // OPT_Gfa
   bool PreferFlowControl = false;    // OPT_Gfp
   bool PreferFlowControl = false;    // OPT_Gfp
   bool EnableStrictMode = false;     // OPT_Ges
   bool EnableStrictMode = false;     // OPT_Ges
-  bool EnableBackCompatMode = false;     // OPT_Gec
+  bool EnableDX9CompatMode = false;     // OPT_Gec
+  bool EnableFXCCompatMode = false;     // internal flag
   unsigned long HLSLVersion = 0; // OPT_hlsl_version (2015-2018)
   unsigned long HLSLVersion = 0; // OPT_hlsl_version (2015-2018)
   bool Enable16BitTypes = false; // OPT_enable_16bit_types
   bool Enable16BitTypes = false; // OPT_enable_16bit_types
   bool OptDump = false; // OPT_ODump - dump optimizer commands
   bool OptDump = false; // OPT_ODump - dump optimizer commands

+ 1 - 0
include/dxc/Support/WinAdapter.h

@@ -172,6 +172,7 @@
 #define _atoi64 atoll
 #define _atoi64 atoll
 #define sprintf_s snprintf
 #define sprintf_s snprintf
 #define _strdup strdup
 #define _strdup strdup
+#define _strnicmp strnicmp
 
 
 #define vsprintf_s vsprintf
 #define vsprintf_s vsprintf
 #define strcat_s strcat
 #define strcat_s strcat

+ 2 - 0
include/dxc/Support/WinFunctions.h

@@ -26,6 +26,8 @@ HRESULT UIntAdd(UINT uAugend, UINT uAddend, UINT *puResult);
 HRESULT IntToUInt(int in, UINT *out);
 HRESULT IntToUInt(int in, UINT *out);
 HRESULT SizeTToInt(size_t in, INT *out);
 HRESULT SizeTToInt(size_t in, INT *out);
 HRESULT UInt32Mult(UINT a, UINT b, UINT *out);
 HRESULT UInt32Mult(UINT a, UINT b, UINT *out);
+
+int strnicmp(const char *str1, const char *str2, size_t count);
 int _stricmp(const char *str1, const char *str2);
 int _stricmp(const char *str1, const char *str2);
 int _wcsicmp(const wchar_t *str1, const wchar_t *str2);
 int _wcsicmp(const wchar_t *str1, const wchar_t *str2);
 int _wcsnicmp(const wchar_t *str1, const wchar_t *str2, size_t n);
 int _wcsnicmp(const wchar_t *str1, const wchar_t *str2, size_t n);

+ 7 - 3
lib/DxcSupport/HLSLOptions.cpp

@@ -363,10 +363,10 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     }
     }
   }
   }
 
 
-  opts.EnableBackCompatMode = Args.hasFlag(OPT_Gec, OPT_INVALID, false);
+  opts.EnableDX9CompatMode = Args.hasFlag(OPT_Gec, OPT_INVALID, false);
   llvm::StringRef ver = Args.getLastArgValue(OPT_hlsl_version);
   llvm::StringRef ver = Args.getLastArgValue(OPT_hlsl_version);
   if (ver.empty()) {
   if (ver.empty()) {
-    if (opts.EnableBackCompatMode)
+    if (opts.EnableDX9CompatMode)
       opts.HLSLVersion = 2016; // Default to max supported version with /Gec flag
       opts.HLSLVersion = 2016; // Default to max supported version with /Gec flag
     else
     else
       opts.HLSLVersion = 2018; // Default to latest version
       opts.HLSLVersion = 2018; // Default to latest version
@@ -393,11 +393,15 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     return 1;
     return 1;
   }
   }
 
 
-  if (opts.EnableBackCompatMode && opts.HLSLVersion > 2016) {
+  if (opts.EnableDX9CompatMode && opts.HLSLVersion > 2016) {
     errors << "/Gec is not supported with HLSLVersion " << opts.HLSLVersion;
     errors << "/Gec is not supported with HLSLVersion " << opts.HLSLVersion;
     return 1;
     return 1;
   }
   }
 
 
+  if (opts.HLSLVersion <= 2016) {
+    opts.EnableFXCCompatMode = true;
+  }
+
   // AssemblyCodeHex not supported (Fx)
   // AssemblyCodeHex not supported (Fx)
   // OutputLibrary not supported (Fl)
   // OutputLibrary not supported (Fl)
   opts.AssemblyCode = Args.getLastArgValue(OPT_Fc);
   opts.AssemblyCode = Args.getLastArgValue(OPT_Fc);

+ 17 - 0
lib/DxcSupport/WinFunctions.cpp

@@ -98,6 +98,23 @@ HRESULT UInt32Mult(UINT a, UINT b, UINT *out) {
   return S_OK;
   return S_OK;
 }
 }
 
 
+int strnicmp(const char *str1, const char *str2, size_t count) {
+  size_t i = 0;
+  for (; i < count && str1[i] && str2[i]; ++i) {
+    int d = std::tolower(str1[i]) - std::tolower(str2[i]);
+    if (d != 0)
+      return d;
+  }
+
+  if (i == count) {
+    // All 'count' characters matched.
+    return 0;
+  }
+
+  // str1 or str2 reached NULL before 'count' characters were compared.
+  return str1[i] - str2[i];
+}
+
 int _stricmp(const char *str1, const char *str2) {
 int _stricmp(const char *str1, const char *str2) {
   size_t i = 0;
   size_t i = 0;
   for (; str1[i] && str2[i]; ++i) {
   for (; str1[i] && str2[i]; ++i) {

+ 1 - 0
lib/HLSL/DxilTypeSystem.cpp

@@ -11,6 +11,7 @@
 #include "dxc/HLSL/DxilModule.h"
 #include "dxc/HLSL/DxilModule.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/Support/Global.h"
 #include "dxc/Support/Global.h"
+#include "dxc/Support/WinFunctions.h"
 
 
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LLVMContext.h"

+ 13 - 0
lib/HLSL/DxilUtil.cpp

@@ -394,6 +394,19 @@ bool ContainsHLSLObjectType(llvm::Type *Ty) {
   return false;
   return false;
 }
 }
 
 
+// Based on the implementation available in LLVM's trunk:
+// http://llvm.org/doxygen/Constants_8cpp_source.html#l02734
+bool IsSplat(llvm::ConstantDataVector *cdv) {
+  const char *Base = cdv->getRawDataValues().data();
+
+  // Compare elements 1+ to the 0'th element.
+  unsigned EltSize = cdv->getElementByteSize();
+  for (unsigned i = 1, e = cdv->getNumElements(); i != e; ++i)
+    if (memcmp(Base, Base + i * EltSize, EltSize))
+      return false;
+
+  return true;
+}
 
 
 }
 }
 }
 }

+ 124 - 10
lib/HLSL/HLOperationLower.cpp

@@ -27,6 +27,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Module.h"
+#include "llvm/ADT/APSInt.h"
 
 
 using namespace llvm;
 using namespace llvm;
 using namespace hlsl;
 using namespace hlsl;
@@ -609,6 +610,123 @@ Value *TranslateD3DColorToUByte4(CallInst *CI, IntrinsicOp IOP,
   return Builder.CreateBitCast(byte4, CI->getType());
   return Builder.CreateBitCast(byte4, CI->getType());
 }
 }
 
 
+// Returns true if pow can be implemented using Fxc's mul-only code gen pattern.
+// Fxc uses the below rules when choosing mul-only code gen pattern to implement pow function.
+// Rule 1: Applicable only to power values in the range [INT32_MIN, INT32_MAX]
+// Rule 2: The maximum number of mul ops needed shouldn't exceed (2n+1) or (n+1) based on whether the power
+//         is a positive or a negative value. Here "n" is the number of scalar elements in power.
+// Rule 3: Power must be an exact value.
+// +----------+---------------------+------------------+
+// | BaseType | IsExponentPositive  | MaxMulOpsAllowed |
+// +----------+---------------------+------------------+
+// | float4x4 | True                |               33 |
+// | float4x4 | False               |               17 |
+// | float4x2 | True                |               17 |
+// | float4x2 | False               |                9 |
+// | float2x4 | True                |               17 |
+// | float2x4 | False               |                9 |
+// | float4   | True                |                9 |
+// | float4   | False               |                5 |
+// | float2   | True                |                5 |
+// | float2   | False               |                3 |
+// | float    | True                |                3 |
+// | float    | False               |                2 |
+// +----------+---------------------+------------------+
+
+bool CanUseFxcMulOnlyPatternForPow(IRBuilder<>& Builder, Value *x, Value *pow, int32_t& powI) {
+  // Applicable only when power is a literal.
+  if (!isa<ConstantDataVector>(pow) && !isa<ConstantFP>(pow)) {
+    return false;
+  }
+
+  // Only apply this code gen on splat values.
+  if (ConstantDataVector *cdv = dyn_cast<ConstantDataVector>(pow)) {
+    if (!hlsl::dxilutil::IsSplat(cdv)) {
+      return false;
+    }
+  }
+
+  APFloat powAPF = isa<ConstantDataVector>(pow) ?
+    cast<ConstantDataVector>(pow)->getElementAsAPFloat(0) : // should be a splat value
+    cast<ConstantFP>(pow)->getValueAPF();
+  APSInt powAPS(32, false);
+  bool isExact = false;
+  // Try converting float value of power to integer and also check if the float value is exact.
+  APFloat::opStatus status = powAPF.convertToInteger(powAPS, APFloat::rmTowardZero, &isExact);
+  if (status == APFloat::opStatus::opOK && isExact) {
+    powI = powAPS.getExtValue();
+    uint32_t powU = abs(powI);
+    int setBitCount = 0;
+    int maxBitSetPos = -1;
+    for (int i = 0; i < 32; i++) {
+      if ((powU >> i) & 1) {
+        setBitCount++;
+        maxBitSetPos = i;
+      }
+    }
+
+    DXASSERT(maxBitSetPos <= 30, "msb should always be zero.");
+    unsigned numElem = isa<ConstantDataVector>(pow) ? x->getType()->getVectorNumElements() : 1;
+    int mulOpThreshold = powI < 0 ? numElem + 1 : 2 * numElem + 1;
+    int mulOpNeeded = maxBitSetPos + setBitCount - 1;
+    return mulOpNeeded <= mulOpThreshold;
+  }
+
+  return false;
+}
+
+Value *TranslatePowUsingFxcMulOnlyPattern(IRBuilder<>& Builder, Value *x, const int32_t y) {
+  uint32_t absY = abs(y);
+  // If y is zero then always return 1.
+  if (absY == 0) {
+    return ConstantFP::get(x->getType(), 1);
+  }
+
+  int lastSetPos = -1;
+  Value *result = nullptr;
+  Value *mul = nullptr;
+  for (int i = 0; i < 32; i++) {
+    if ((absY >> i) & 1) {
+      for (int j = i; j > lastSetPos; j--) {
+        if (!mul) {
+          mul = x;
+        }
+        else {
+          mul = Builder.CreateFMul(mul, mul);
+        }
+      }
+
+      result = (result == nullptr) ? mul : Builder.CreateFMul(result, mul);
+      lastSetPos = i;
+    }
+  }
+
+  // Compute reciprocal for negative power values.
+  if (y < 0) {
+    Value* constOne = ConstantFP::get(x->getType(), 1);
+    result = Builder.CreateFDiv(constOne, result);
+  }
+
+  return result;
+}
+
+Value *TranslatePowImpl(hlsl::OP *hlslOP, IRBuilder<>& Builder, Value *x, Value *y, bool isFXCCompatMode = false) {
+  // As applicable implement pow using only mul ops as done by Fxc.
+  int32_t p = 0;
+  if (isFXCCompatMode && CanUseFxcMulOnlyPatternForPow(Builder, x, y, p)) {
+    return TranslatePowUsingFxcMulOnlyPattern(Builder, x, p);
+  }
+
+  // Default to log-mul-exp pattern if previous scenarios don't apply.
+  // t = log(x);
+  Value *logX =
+    TrivialDxilUnaryOperation(DXIL::OpCode::Log, x, hlslOP, Builder);
+  // t = y * t;
+  Value *mulY = Builder.CreateFMul(logX, y);
+  // pow = exp(t);
+  return TrivialDxilUnaryOperation(DXIL::OpCode::Exp, mulY, hlslOP, Builder);
+}
+
 Value *TranslateAddUint64(CallInst *CI, IntrinsicOp IOP,
 Value *TranslateAddUint64(CallInst *CI, IntrinsicOp IOP,
                                  OP::OpCode opcode,
                                  OP::OpCode opcode,
                                  HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
                                  HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
@@ -1501,11 +1619,12 @@ Value *TranslateLit(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   Value *nlCmp = Builder.CreateFCmpOLT(n_dot_l, zeroConst);
   Value *nlCmp = Builder.CreateFCmpOLT(n_dot_l, zeroConst);
   Value *diffuse = Builder.CreateSelect(nlCmp, zeroConst, n_dot_l);
   Value *diffuse = Builder.CreateSelect(nlCmp, zeroConst, n_dot_l);
   Result = Builder.CreateInsertElement(Result, diffuse, 1);
   Result = Builder.CreateInsertElement(Result, diffuse, 1);
-  // specular = ((n_dot_l < 0) || (n_dot_h < 0)) ? 0: (n_dot_h * m).
+  // specular = ((n_dot_l < 0) || (n_dot_h < 0)) ? 0: (n_dot_h ^ m).
   Value *nhCmp = Builder.CreateFCmpOLT(n_dot_h, zeroConst);
   Value *nhCmp = Builder.CreateFCmpOLT(n_dot_h, zeroConst);
   Value *specCond = Builder.CreateOr(nlCmp, nhCmp);
   Value *specCond = Builder.CreateOr(nlCmp, nhCmp);
-  Value *nhMulM = Builder.CreateFMul(n_dot_h, m);
-  Value *spec = Builder.CreateSelect(specCond, zeroConst, nhMulM);
+  bool isFXCCompatMode = CI->getModule()->GetHLModule().GetHLOptions().bFXCCompatMode;
+  Value *nhPowM = TranslatePowImpl(&helper.hlslOP, Builder, n_dot_h, m, isFXCCompatMode);
+  Value *spec = Builder.CreateSelect(specCond, zeroConst, nhPowM);
   Result = Builder.CreateInsertElement(Result, spec, 2);
   Result = Builder.CreateInsertElement(Result, spec, 2);
   return Result;
   return Result;
 }
 }
@@ -2118,14 +2237,9 @@ Value *TranslatePow(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   hlsl::OP *hlslOP = &helper.hlslOP;
   hlsl::OP *hlslOP = &helper.hlslOP;
   Value *x = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *x = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *y = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
   Value *y = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  bool isFXCCompatMode = CI->getModule()->GetHLModule().GetHLOptions().bFXCCompatMode;
   IRBuilder<> Builder(CI);
   IRBuilder<> Builder(CI);
-  // t = log(x);
-  Value *logX =
-      TrivialDxilUnaryOperation(DXIL::OpCode::Log, x, hlslOP, Builder);
-  // t = y * t;
-  Value *mulY = Builder.CreateFMul(logX, y);
-  // pow = exp(t);
-  return TrivialDxilUnaryOperation(DXIL::OpCode::Exp, mulY, hlslOP, Builder);
+  return TranslatePowImpl(hlslOP,Builder,x,y,isFXCCompatMode);
 }
 }
 
 
 Value *TranslateFaceforward(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,
 Value *TranslateFaceforward(CallInst *CI, IntrinsicOp IOP, OP::OpCode op,

+ 1 - 1
lib/HLSL/HLSignatureLower.cpp

@@ -231,7 +231,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
   }
   }
 
 
   //  back-compat mode - remap obsolete semantics
   //  back-compat mode - remap obsolete semantics
-  if (HLM.GetHLOptions().bBackCompatMode && paramAnnotation.HasSemanticString()) {
+  if (HLM.GetHLOptions().bDX9CompatMode && paramAnnotation.HasSemanticString()) {
     hlsl::RemapObsoleteSemantic(paramAnnotation, sigPoint->GetKind(), HLM.GetCtx());
     hlsl::RemapObsoleteSemantic(paramAnnotation, sigPoint->GetKind(), HLM.GetCtx());
   }
   }
 
 

+ 2 - 1
tools/clang/include/clang/Basic/LangOptions.h

@@ -156,7 +156,8 @@ public:
   unsigned RootSigMinor;
   unsigned RootSigMinor;
   bool IsHLSLLibrary;
   bool IsHLSLLibrary;
   bool UseMinPrecision; // use min precision, not native precision.
   bool UseMinPrecision; // use min precision, not native precision.
-  bool EnableBackCompatMode;
+  bool EnableDX9CompatMode;
+  bool EnableFXCCompatMode;
   // HLSL Change Ends
   // HLSL Change Ends
 
 
   bool SPIRV = false;  // SPIRV Change
   bool SPIRV = false;  // SPIRV Change

+ 6 - 5
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -385,7 +385,8 @@ CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
   opts.PackingStrategy = CGM.getCodeGenOpts().HLSLSignaturePackingStrategy;
   opts.PackingStrategy = CGM.getCodeGenOpts().HLSLSignaturePackingStrategy;
 
 
   opts.bUseMinPrecision = CGM.getLangOpts().UseMinPrecision;
   opts.bUseMinPrecision = CGM.getLangOpts().UseMinPrecision;
-  opts.bBackCompatMode = CGM.getLangOpts().EnableBackCompatMode;
+  opts.bDX9CompatMode = CGM.getLangOpts().EnableDX9CompatMode;
+  opts.bFXCCompatMode = CGM.getLangOpts().EnableFXCCompatMode;
 
 
   m_pHLModule->SetHLOptions(opts);
   m_pHLModule->SetHLOptions(opts);
   m_pHLModule->SetAutoBindingSpace(CGM.getCodeGenOpts().HLSLDefaultSpace);
   m_pHLModule->SetAutoBindingSpace(CGM.getCodeGenOpts().HLSLDefaultSpace);
@@ -1559,7 +1560,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation);
   SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation);
   retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out);
   retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out);
   if (isEntry) {
   if (isEntry) {
-    if (CGM.getLangOpts().EnableBackCompatMode && retTyAnnotation.HasSemanticString()) {
+    if (CGM.getLangOpts().EnableDX9CompatMode && retTyAnnotation.HasSemanticString()) {
       RemapObsoleteSemantic(retTyAnnotation, /*isPatchConstantFunction*/ false);
       RemapObsoleteSemantic(retTyAnnotation, /*isPatchConstantFunction*/ false);
     }
     }
     CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation,
     CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation,
@@ -1840,7 +1841,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
 
     paramAnnotation.SetParamInputQual(dxilInputQ);
     paramAnnotation.SetParamInputQual(dxilInputQ);
     if (isEntry) {
     if (isEntry) {
-      if (CGM.getLangOpts().EnableBackCompatMode && paramAnnotation.HasSemanticString()) {
+      if (CGM.getLangOpts().EnableDX9CompatMode && paramAnnotation.HasSemanticString()) {
         RemapObsoleteSemantic(paramAnnotation, /*isPatchConstantFunction*/ false);
         RemapObsoleteSemantic(paramAnnotation, /*isPatchConstantFunction*/ false);
       }
       }
       CheckParameterAnnotation(paramSemanticLoc, paramAnnotation,
       CheckParameterAnnotation(paramSemanticLoc, paramAnnotation,
@@ -1941,7 +1942,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 }
 }
 
 
 void CGMSHLSLRuntime::RemapObsoleteSemantic(DxilParameterAnnotation &paramInfo, bool isPatchConstantFunction) {
 void CGMSHLSLRuntime::RemapObsoleteSemantic(DxilParameterAnnotation &paramInfo, bool isPatchConstantFunction) {
-  DXASSERT(CGM.getLangOpts().EnableBackCompatMode, "should be used only in back-compat mode");
+  DXASSERT(CGM.getLangOpts().EnableDX9CompatMode, "should be used only in back-compat mode");
 
 
   const ShaderModel *SM = m_pHLModule->GetShaderModel();
   const ShaderModel *SM = m_pHLModule->GetShaderModel();
   DXIL::SigPointKind sigPointKind = SigPointFromInputQual(paramInfo.GetParamInputQual(), SM->GetKind(), isPatchConstantFunction);
   DXIL::SigPointKind sigPointKind = SigPointFromInputQual(paramInfo.GetParamInputQual(), SM->GetKind(), isPatchConstantFunction);
@@ -4577,7 +4578,7 @@ void CGMSHLSLRuntime::FinishCodeGen() {
     // In back-compat mode (with /Gec flag) create a static global for each const global
     // In back-compat mode (with /Gec flag) create a static global for each const global
     // to allow writing to it.
     // to allow writing to it.
     // TODO: Verfiy the behavior of static globals in hull shader
     // TODO: Verfiy the behavior of static globals in hull shader
-    if(CGM.getLangOpts().EnableBackCompatMode && CGM.getLangOpts().HLSLVersion <= 2016)
+    if(CGM.getLangOpts().EnableDX9CompatMode && CGM.getLangOpts().HLSLVersion <= 2016)
       CreateWriteEnabledStaticGlobals(m_pHLModule->GetModule(), m_pHLModule->GetEntryFunction());
       CreateWriteEnabledStaticGlobals(m_pHLModule->GetModule(), m_pHLModule->GetEntryFunction());
     if (m_pHLModule->GetShaderModel()->IsHS()) {
     if (m_pHLModule->GetShaderModel()->IsHS()) {
       SetPatchConstantFunction(Entry);
       SetPatchConstantFunction(Entry);

+ 1 - 1
tools/clang/lib/Parse/ParseDecl.cpp

@@ -2177,7 +2177,7 @@ Parser::DeclGroupPtrTy Parser::ParseDeclGroup(ParsingDeclSpec &DS,
   // global variable can be inside a global structure as a static member.
   // global variable can be inside a global structure as a static member.
   // Check if the global is a static member and skip global const pass.
   // Check if the global is a static member and skip global const pass.
   // in backcompat mode, the check for global const is deferred to later stage in CGMSHLSLRuntime::FinishCodeGen()
   // in backcompat mode, the check for global const is deferred to later stage in CGMSHLSLRuntime::FinishCodeGen()
-  bool CheckGlobalConst = getLangOpts().HLSL && getLangOpts().EnableBackCompatMode && getLangOpts().HLSLVersion <= 2016 ? false : true;
+  bool CheckGlobalConst = getLangOpts().HLSL && getLangOpts().EnableDX9CompatMode && getLangOpts().HLSLVersion <= 2016 ? false : true;
   if (NestedNameSpecifier *nameSpecifier = D.getCXXScopeSpec().getScopeRep()) {
   if (NestedNameSpecifier *nameSpecifier = D.getCXXScopeSpec().getScopeRep()) {
     if (nameSpecifier->getKind() == NestedNameSpecifier::SpecifierKind::TypeSpec) {
     if (nameSpecifier->getKind() == NestedNameSpecifier::SpecifierKind::TypeSpec) {
       const Type *type = D.getCXXScopeSpec().getScopeRep()->getAsType();
       const Type *type = D.getCXXScopeSpec().getScopeRep()->getAsType();

+ 43 - 15
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2241,8 +2241,8 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
     if (const uint32_t valueId = tryToEvaluateAsConst(expr))
     if (const uint32_t valueId = tryToEvaluateAsConst(expr))
       return SpirvEvalInfo(valueId).setConstant().setRValue();
       return SpirvEvalInfo(valueId).setConstant().setRValue();
 
 
-    const auto valueId =
-        castToInt(doExpr(subExpr), subExprType, toType, subExpr->getExprLoc());
+    const auto valueId = castToInt(loadIfGLValue(subExpr), subExprType, toType,
+                                   subExpr->getExprLoc());
     return SpirvEvalInfo(valueId).setRValue();
     return SpirvEvalInfo(valueId).setRValue();
   }
   }
   case CastKind::CK_FloatingCast:
   case CastKind::CK_FloatingCast:
@@ -6640,9 +6640,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
     retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
     break;
     break;
   case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
   case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
-    retVal = processWaveReductionOrPrefix(
-        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
-        spv::GroupOperation::Reduce);
+    retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
     break;
     break;
   case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
   case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
   case hlsl::IntrinsicOp::IOP_WaveActiveSum:
   case hlsl::IntrinsicOp::IOP_WaveActiveSum:
@@ -6670,9 +6668,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
         spv::GroupOperation::ExclusiveScan);
         spv::GroupOperation::ExclusiveScan);
   } break;
   } break;
   case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
   case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
-    retVal = processWaveReductionOrPrefix(
-        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
-        spv::GroupOperation::ExclusiveScan);
+    retVal = processWaveCountBits(callExpr, spv::GroupOperation::ExclusiveScan);
     break;
     break;
   case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
   case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
   case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
   case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
@@ -7106,7 +7102,8 @@ uint32_t SPIRVEmitter::processWaveQuery(const CallExpr *callExpr,
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
                                   callExpr->getExprLoc());
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
-  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
   const uint32_t retType =
   const uint32_t retType =
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
   return theBuilder.createGroupNonUniformOp(opcode, retType, subgroupScope);
   return theBuilder.createGroupNonUniformOp(opcode, retType, subgroupScope);
@@ -7123,7 +7120,8 @@ uint32_t SPIRVEmitter::processWaveVote(const CallExpr *callExpr,
                                   callExpr->getExprLoc());
                                   callExpr->getExprLoc());
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   const uint32_t predicate = doExpr(callExpr->getArg(0));
   const uint32_t predicate = doExpr(callExpr->getArg(0));
-  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
   const uint32_t retType =
   const uint32_t retType =
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
   return theBuilder.createGroupNonUniformUnaryOp(opcode, retType, subgroupScope,
   return theBuilder.createGroupNonUniformUnaryOp(opcode, retType, subgroupScope,
@@ -7199,11 +7197,39 @@ spv::Op SPIRVEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
   return spv::Op::OpNop;
   return spv::Op::OpNop;
 }
 }
 
 
+uint32_t SPIRVEmitter::processWaveCountBits(const CallExpr *callExpr,
+                                            spv::GroupOperation groupOp) {
+  // Signatures:
+  // uint WaveActiveCountBits(bool bBit)
+  // uint WavePrefixCountBits(Bool bBit)
+  assert(callExpr->getNumArgs() == 1);
+
+  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
+                                  callExpr->getExprLoc());
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(
+      spv::Op::OpGroupNonUniformBallotBitCount));
+
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
+
+  const uint32_t u32Type = theBuilder.getUint32Type();
+  const uint32_t v4u32Type = theBuilder.getVecType(u32Type, 4);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+
+  const uint32_t ballot = theBuilder.createGroupNonUniformUnaryOp(
+      spv::Op::OpGroupNonUniformBallot, v4u32Type, subgroupScope, predicate);
+
+  return theBuilder.createGroupNonUniformUnaryOp(
+      spv::Op::OpGroupNonUniformBallotBitCount, retType, subgroupScope, ballot,
+      llvm::Optional<spv::GroupOperation>(groupOp));
+}
+
 uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
 uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
     const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
     const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
   // Signatures:
   // Signatures:
   // bool WaveActiveAllEqual( <type> expr )
   // bool WaveActiveAllEqual( <type> expr )
-  // uint WaveActiveCountBits( bool bBit )
   // <type> WaveActiveSum( <type> expr )
   // <type> WaveActiveSum( <type> expr )
   // <type> WaveActiveProduct( <type> expr )
   // <type> WaveActiveProduct( <type> expr )
   // <int_type> WaveActiveBitAnd( <int_type> expr )
   // <int_type> WaveActiveBitAnd( <int_type> expr )
@@ -7212,7 +7238,6 @@ uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
   // <type> WaveActiveMin( <type> expr)
   // <type> WaveActiveMin( <type> expr)
   // <type> WaveActiveMax( <type> expr)
   // <type> WaveActiveMax( <type> expr)
   //
   //
-  // uint WavePrefixCountBits(Bool bBit)
   // <type> WavePrefixProduct(<type> value)
   // <type> WavePrefixProduct(<type> value)
   // <type> WavePrefixSum(<type> value)
   // <type> WavePrefixSum(<type> value)
   assert(callExpr->getNumArgs() == 1);
   assert(callExpr->getNumArgs() == 1);
@@ -7220,7 +7245,8 @@ uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
                                   callExpr->getExprLoc());
                                   callExpr->getExprLoc());
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   const uint32_t predicate = doExpr(callExpr->getArg(0));
   const uint32_t predicate = doExpr(callExpr->getArg(0));
-  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
   const uint32_t retType =
   const uint32_t retType =
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
   return theBuilder.createGroupNonUniformUnaryOp(
   return theBuilder.createGroupNonUniformUnaryOp(
@@ -7238,7 +7264,8 @@ uint32_t SPIRVEmitter::processWaveBroadcast(const CallExpr *callExpr) {
                                   callExpr->getExprLoc());
                                   callExpr->getExprLoc());
   theBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
   theBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
   const uint32_t value = doExpr(callExpr->getArg(0));
   const uint32_t value = doExpr(callExpr->getArg(0));
-  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
   const uint32_t retType =
   const uint32_t retType =
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
   if (numArgs == 2)
   if (numArgs == 2)
@@ -7264,7 +7291,8 @@ uint32_t SPIRVEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
   theBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
   theBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
 
 
   const uint32_t value = doExpr(callExpr->getArg(0));
   const uint32_t value = doExpr(callExpr->getArg(0));
-  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t subgroupScope =
+      theBuilder.getConstantInt32(static_cast<int32_t>(spv::Scope::Subgroup));
   const uint32_t retType =
   const uint32_t retType =
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
       typeTranslator.translateType(callExpr->getCallReturnType(astContext));
 
 

+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -467,6 +467,9 @@ private:
   /// Processes SM6.0 wave vote intrinsic calls.
   /// Processes SM6.0 wave vote intrinsic calls.
   uint32_t processWaveVote(const CallExpr *, spv::Op opcode);
   uint32_t processWaveVote(const CallExpr *, spv::Op opcode);
 
 
+  /// Processes SM6.0 wave active/prefix count bits.
+  uint32_t processWaveCountBits(const CallExpr *, spv::GroupOperation groupOp);
+
   /// Processes SM6.0 wave reduction or scan/prefix intrinsic calls.
   /// Processes SM6.0 wave reduction or scan/prefix intrinsic calls.
   uint32_t processWaveReductionOrPrefix(const CallExpr *, spv::Op op,
   uint32_t processWaveReductionOrPrefix(const CallExpr *, spv::Op op,
                                         spv::GroupOperation groupOp);
                                         spv::GroupOperation groupOp);

+ 20 - 0
tools/clang/test/CodeGenHLSL/quick-test/lit-function.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -T ps_6_0 -E main  %s | FileCheck %s
+
+// Verify lit function defined as lit(ambient, diffuse, specular, 1) where:
+// ambient = 1.
+// diffuse = ((n l) < 0) ? 0 : n l.
+// specular = ((n l) < 0) || ((n h) < 0) ? 0 : ((n h) ^ m).
+
+// CHECK: fcmp
+// CHECK: select
+// CHECK: fcmp
+// CHECK: or
+// CHECK: Log
+// CHECK: fmul
+// CHECK: Exp
+// CHECK: select
+
+float4 main(float a : A, float b : B, float c : C) : SV_Target
+{
+  return lit(a, b, c);
+}

+ 35 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-check-count01.hlsl

@@ -0,0 +1,35 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fdiv
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+
+float4 main (float x1 : A, float4x4 x2 : B, float2 x3 : C, float4 x4 : D) : SV_Target
+{
+    float p1 = 8.0;
+    float4x4 p2 =         {57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0};
+    float2 p3 = float2(-5.0,-5.0);
+    float4 p4 = float4(17.0,17.0,17.0,17.0);
+
+    return float4(pow(x1, p1), pow(x2, p2)[0][0], pow(x3, p3)[0], pow(x4, p4)[0]);
+}

+ 18 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-check-count02.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+// CHECK: fmul
+
+float2 main (float4 x : A) : SV_Target
+{
+    float2 y = float2(11.0,11.0);
+    return pow(x, y);
+}

+ 38 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-correctness.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+
+// Verify the mul-only pattern implemented to support Fxc compatability.
+
+// 2.0^8.0.
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 2.560000e+02)
+
+// 2.0^57.0 = 144115188075855872 (0x4380000000000000)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 0x4380000000000000)
+
+// 2.0^-5.0
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 3.125000e-02)
+
+//2.0^17.0
+// call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 1.310720e+05)
+
+float4 main () : SV_Target
+{
+    float x1 = 2.0;
+    float p1 = 8.0;
+
+    float4x4 x2 = {2.0, 2.0, 2.0, 2.0,
+                           2.0, 2.0, 2.0, 2.0,
+                           2.0, 2.0, 2.0, 2.0,
+                           2.0, 2.0, 2.0, 2.0};
+    float4x4 p2 = {57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0,
+                           57.0, 57.0, 57.0, 57.0};
+
+    float2 x3 = float2(2.0,2.0);
+    float2 p3 = float2(-5.0,-5.0);
+
+    float4 x4 = float4(2.0,2.0,2.0,2.0);
+    float4 p4 = float4(17.0,17.0,17.0,17.0);
+
+    return float4(pow(x1, p1), pow(x2, p2)[0][0], pow(x3, p3)[0], pow(x4, p4)[0]);
+}

+ 27 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-criteria01.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+// dxc should use log-mul-exp pattern to implement all scenarios listed below.
+
+// CHECK: Log
+// CHECK: Exp
+// CHECK: Log
+// CHECK: Exp
+// CHECK: Log
+// CHECK: Exp
+// CHECK: Log
+// CHECK: Exp
+// CHECK: Log
+// CHECK: Exp
+
+float main (float4x4 a : A, float b : B, float4 c: C) : SV_Target
+{
+    float4x4 p1 = {2.0, 2.0, 3.0, 2.0,
+                  2.0, 2.0, 2.0, 2.0,
+                  2.0, 2.0, 2.0, 2.0,
+                  2.0, 2.0, -1.0, 2.0,}; // not a splat vector
+    float4 p2 = {2.33, 2.33, 2.33, 2.33}; // a splat vector but not exact
+    float p3 = 2.001; // not an exact value
+    float p4 = 4294967296.0; // value greater than int max
+    float p5 = 7; // exceeds the mulop threshold criteria for float
+
+    return pow(a,p1)[0][0] + pow(b,p2)[0] + pow(a,p3)[0][0] + pow(a,p4)[0][0] + pow(c,p4)[0] + pow(b,p5);
+}

+ 17 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-criteria02.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+// dxc should use log-mul-exp pattern to implement all scenarios listed below.
+
+// CHECK-NOT: Log
+// CHECK-NOT: Exp
+
+float main (float4x4 a : A, float b : B, float4 c: C) : SV_Target
+{
+    float4x4 p1 = {2.0, 2.0, 2.0, 2.0,
+                  2.0, 2.0, 2.0, 2.0,
+                  2.0, 2.0, 2.0, 2.0,
+                  2.0, 2.0, 2.0, 2.0,}; // a splat
+    float4 p2 = {9, 9, 9, 9}; // another splat
+    float p3 = 8; // meets the threshold criteria
+
+    return pow(a,p1)[0][0] + pow(b,p2)[0] + pow(a,p3)[0][0];
+}

+ 14 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-lit-types.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+// check that different float literals are being considered for mul-only code gen for pow.
+// CHECK-NOT: Log
+// CHECK-NOT: Exp
+
+float main ( float a : A, float4x4 b: B, float4 c: C, float2 d: D) : SV_Target
+{
+    return pow(a, 8.0f) + 
+           pow(d, 14.0h)[0] +
+           pow(c, 384.0H)[0] +
+           pow(c, -32.0F)[0] +
+           pow(b, -131072.0L)[0][0] +
+           pow(b, 1073741824.0L)[0][0];
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-one-as-power.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %{{[a-z0-9]+.*[a-z0-9]*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %{{[a-z0-9]+.*[a-z0-9]*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %{{[a-z0-9]+.*[a-z0-9]*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %{{[a-z0-9]+.*[a-z0-9]*}})
+
+float4 main ( float a : A, float2 b : B, float4 c: C, float4x4 d: D) : SV_Target
+{
+    return float4(pow(a, 1), pow(b, float2(1.00,1.00))[0], pow(c, float4(1.00,1.00,1.00,1.00))[2], pow(d, 1.00)[1][2]);
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/quick-test/pow-mulonly-zero-as-power.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -HV 2016 -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 1.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 1.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 1.000000e+00)
+
+float4 main ( float a : A, float2 b : B, float4 c: C, float4x4 d: D) : SV_Target
+{
+    return float4(pow(a, 0), pow(b, float2(0.00,0.00))[0], pow(c, -0.00)[2], pow(d, 0.00)[1][2]);
+}

+ 15 - 2
tools/clang/test/CodeGenSPIRV/cf.switch.opswitch.hlsl

@@ -342,6 +342,19 @@ void main() {
 // CHECK-NEXT: %switch_merge_8 = OpLabel
 // CHECK-NEXT: %switch_merge_8 = OpLabel
   }
   }
 
 
-// CHECK-NEXT: OpReturn
-// CHECK-NEXT: OpFunctionEnd
+
+  //////////////////////////////////////////////////////////////////
+  // Using float as selector results in multiple casts in the AST //
+  //////////////////////////////////////////////////////////////////
+  float sel;
+// CHECK:      [[floatSelector:%\d+]] = OpLoad %float %sel
+// CHECK-NEXT:           [[sel:%\d+]] = OpConvertFToS %int [[floatSelector]]
+// CHECK-NEXT:                          OpSelectionMerge %switch_merge_9 None
+// CHECK-NEXT:                          OpSwitch [[sel]] %switch_merge_9 0 %switch_0_0
+  switch (sel) {
+  case 0:
+    result = 0;
+    break;
+  }
+
 }
 }

+ 5 - 2
tools/clang/test/CodeGenSPIRV/sm6.wave-active-count-bits.hlsl

@@ -7,6 +7,7 @@ struct S {
 };
 };
 
 
 RWStructuredBuffer<S> values;
 RWStructuredBuffer<S> values;
+RWStructuredBuffer<S> results;
 
 
 // CHECK: OpCapability GroupNonUniformBallot
 // CHECK: OpCapability GroupNonUniformBallot
 
 
@@ -14,7 +15,9 @@ RWStructuredBuffer<S> values;
 void main(uint3 id: SV_DispatchThreadID) {
 void main(uint3 id: SV_DispatchThreadID) {
     uint x = id.x;
     uint x = id.x;
 
 
-// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 Reduce {{%\d+}}
-    values[x].val = WaveActiveCountBits(values[x].val == 0);
+// CHECK:         [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_0
+// CHECK-NEXT: [[ballot:%\d+]] = OpGroupNonUniformBallot %v4uint %int_3 [[cmp]]
+// CHECK:             {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 Reduce [[ballot]]
+    results[x].val = WaveActiveCountBits(values[x].val == 0);
 }
 }
 
 

+ 3 - 1
tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-count-bits.hlsl

@@ -14,6 +14,8 @@ RWStructuredBuffer<S> values;
 void main(uint3 id: SV_DispatchThreadID) {
 void main(uint3 id: SV_DispatchThreadID) {
     uint x = id.x;
     uint x = id.x;
 
 
-// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 ExclusiveScan {{%\d+}}
+// CHECK:         [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_0
+// CHECK-NEXT: [[ballot:%\d+]] = OpGroupNonUniformBallot %v4uint %int_3 [[cmp]]
+// CHECK:             {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 ExclusiveScan [[ballot]]
     values[x].val = WavePrefixCountBits(values[x].val == 0);
     values[x].val = WavePrefixCountBits(values[x].val == 0);
 }
 }

+ 3 - 2
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -800,7 +800,7 @@ public:
     compiler.createSourceManager(compiler.getFileManager());
     compiler.createSourceManager(compiler.getFileManager());
     compiler.setTarget(
     compiler.setTarget(
         TargetInfo::CreateTargetInfo(compiler.getDiagnostics(), targetOptions));
         TargetInfo::CreateTargetInfo(compiler.getDiagnostics(), targetOptions));
-    if (Opts.EnableBackCompatMode) {
+    if (Opts.EnableDX9CompatMode) {
       auto const ID = compiler.getDiagnostics().getCustomDiagID(clang::DiagnosticsEngine::Warning, "/Gec flag is a deprecated functionality.");
       auto const ID = compiler.getDiagnostics().getCustomDiagID(clang::DiagnosticsEngine::Warning, "/Gec flag is a deprecated functionality.");
       compiler.getDiagnostics().Report(ID);
       compiler.getDiagnostics().Report(ID);
     }
     }
@@ -855,7 +855,8 @@ public:
     compiler.getLangOpts().RootSigMajor = 1;
     compiler.getLangOpts().RootSigMajor = 1;
     compiler.getLangOpts().RootSigMinor = rootSigMinor;
     compiler.getLangOpts().RootSigMinor = rootSigMinor;
     compiler.getLangOpts().HLSLVersion = (unsigned) Opts.HLSLVersion;
     compiler.getLangOpts().HLSLVersion = (unsigned) Opts.HLSLVersion;
-    compiler.getLangOpts().EnableBackCompatMode = Opts.EnableBackCompatMode;
+    compiler.getLangOpts().EnableDX9CompatMode = Opts.EnableDX9CompatMode;
+    compiler.getLangOpts().EnableFXCCompatMode = Opts.EnableFXCCompatMode;
 
 
     compiler.getLangOpts().UseMinPrecision = !Opts.Enable16BitTypes;
     compiler.getLangOpts().UseMinPrecision = !Opts.Enable16BitTypes;
 
 

+ 2 - 1
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -129,7 +129,8 @@ void SetupCompilerForRewrite(CompilerInstance &compiler,
   compiler.getDiagnostics().setIgnoreAllWarnings(!opts.OutputWarnings);
   compiler.getDiagnostics().setIgnoreAllWarnings(!opts.OutputWarnings);
   compiler.getLangOpts().HLSLVersion = (unsigned)opts.HLSLVersion;
   compiler.getLangOpts().HLSLVersion = (unsigned)opts.HLSLVersion;
   compiler.getLangOpts().UseMinPrecision = !opts.Enable16BitTypes;
   compiler.getLangOpts().UseMinPrecision = !opts.Enable16BitTypes;
-  compiler.getLangOpts().EnableBackCompatMode = opts.EnableBackCompatMode;
+  compiler.getLangOpts().EnableDX9CompatMode = opts.EnableDX9CompatMode;
+  compiler.getLangOpts().EnableFXCCompatMode = opts.EnableFXCCompatMode;
 
 
   PreprocessorOptions &PPOpts = compiler.getPreprocessorOpts();
   PreprocessorOptions &PPOpts = compiler.getPreprocessorOpts();
   if (rewrite != nullptr) {
   if (rewrite != nullptr) {