Преглед изворни кода

Add initial support for extension intrinsics and defines (#5)

This commit adds initial support for extension intrinsics and defines, including:

-Support for recognizing hlsl extensions as new intrinsic functions.
-Support for requesting lowering of extensions.
-Support for preserving semantic defines.
-Support for validating semantic defines.

This commit adds support for hlsl extensions in the form of additional
intrinsic functions and semantic defines.

We now allow a dxcompiler instance to register that it can handle
additional intrinsic functions beyond the standard hlsl intrinsics. These
new intrinsics are called extensions.  For each extension, the compiler
supplies a lowering strategy that is used by the dxcompiler to translate
from the source level intrinsic to a high-level dxil intrinsic.

We initially support two lowering strategies: replicate and pack.
The replicate strategy will scalarize the vector elements in the call
and replicate the function call once for each element in the vector.
The pack strategy changes the vector arguments into literal struct
arguments.

We also now include support for "semantic defines". A semantic define is
a source level define whose value is preserved in the final dxil as
metadata. The source level define can come as either a #define in
the hlsl source or a /D define from the command line.

We provide a hook to validate that a semantic define has a legal value
(for example to ensure that the value is an integer). Validation failures
can produce warnings and errors that will be emitted through the standard
clang diagnostic mechanism.

This code was originally written by marcelolr and modified by dmpots to
support packed lowering of intrinsics and validation of semantic defines.
David Peixotto пре 8 година
родитељ
комит
9451f2c0b3
29 измењених фајлова са 1746 додато и 93 уклоњено
  1. 3 1
      include/dxc/HLSL/DxilGenerationPass.h
  2. 3 1
      include/dxc/HLSL/HLOperationLower.h
  3. 85 0
      include/dxc/HLSL/HLOperationLowerExtension.h
  4. 11 2
      include/dxc/HLSL/HLOperations.h
  5. 70 0
      include/dxc/HLSL/HLSLExtensionsCodegenHelper.h
  6. 133 0
      include/dxc/Support/DxcLangExtensionsHelper.h
  7. 25 1
      include/dxc/dxcapi.internal.h
  8. 5 0
      include/llvm/Transforms/IPO/PassManagerBuilder.h
  9. 1 0
      lib/HLSL/CMakeLists.txt
  10. 11 4
      lib/HLSL/DxilGenerationPass.cpp
  11. 38 3
      lib/HLSL/HLOperationLower.cpp
  12. 527 0
      lib/HLSL/HLOperationLowerExtension.cpp
  13. 82 40
      lib/HLSL/HLOperations.cpp
  14. 4 4
      lib/Transforms/IPO/PassManagerBuilder.cpp
  15. 2 2
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
  16. 1 0
      tools/clang/include/clang/AST/HlslTypes.h
  17. 1 1
      tools/clang/include/clang/Basic/Attr.td
  18. 3 0
      tools/clang/include/clang/Frontend/CodeGenOptions.h
  19. 2 0
      tools/clang/include/clang/Frontend/CompilerInstance.h
  20. 11 1
      tools/clang/lib/AST/ASTContextHLSL.cpp
  21. 1 0
      tools/clang/lib/CodeGen/BackendUtil.cpp
  22. 34 3
      tools/clang/lib/CodeGen/CGHLSLMS.cpp
  23. 22 7
      tools/clang/lib/Sema/SemaHLSL.cpp
  24. 71 1
      tools/clang/tools/dxcompiler/dxcompilerobj.cpp
  25. 37 11
      tools/clang/tools/libclang/dxcrewriteunused.cpp
  26. 1 0
      tools/clang/unittests/HLSL/CMakeLists.txt
  27. 2 0
      tools/clang/unittests/HLSL/DxcTestUtils.h
  28. 538 0
      tools/clang/unittests/HLSL/ExtensionTest.cpp
  29. 22 11
      tools/clang/unittests/HLSL/ValidationTest.cpp

+ 3 - 1
include/dxc/HLSL/DxilGenerationPass.h

@@ -29,6 +29,8 @@ public:
   virtual void Analyze(llvm::Function *F) = 0;
   virtual bool IsWaveSensitive(llvm::Instruction *op) = 0;
 };
+
+class HLSLExtensionsCodegenHelper;
 }
 
 namespace llvm {
@@ -36,7 +38,7 @@ namespace llvm {
 /// \brief Create and return a pass that tranform the module into a DXIL module
 /// Note that this pass is designed for use with the legacy pass manager.
 ModulePass *createDxilCondenseResourcesPass();
-ModulePass *createDxilGenerationPass(bool NotOptimized);
+ModulePass *createDxilGenerationPass(bool NotOptimized, hlsl::HLSLExtensionsCodegenHelper *extensionsHelper);
 ModulePass *createHLEmitMetadataPass();
 ModulePass *createHLEnsureMetadataPass();
 ModulePass *createDxilEmitMetadataPass();

+ 3 - 1
include/dxc/HLSL/HLOperationLower.h

@@ -20,8 +20,10 @@ class Function;
 namespace hlsl {
 class HLModule;
 class DxilResourceBase;
+class HLSLExtensionsCodegenHelper;
 
 void TranslateBuiltinOperations(
     HLModule &HLM,
-    std::unordered_map<llvm::Instruction *, llvm::Value *> &handleMap);
+    std::unordered_map<llvm::Instruction *, llvm::Value *> &handleMap,
+    HLSLExtensionsCodegenHelper *extCodegenHelper);
 }

+ 85 - 0
include/dxc/HLSL/HLOperationLowerExtension.h

@@ -0,0 +1,85 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLOperationLowerExtension.h                                                //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// Licensed under the MIT license. See COPYRIGHT in the project root for     //
+// full license information.                                                 //
+//                                                                           //
+// Functions to lower HL operations coming from HLSL extensions to DXIL      //
+// operations.                                                               //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+#pragma once
+
+#include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
+#include "llvm/ADT/StringRef.h"
+#include <string>
+
+namespace llvm {
+  class Value;
+  class CallInst;
+  class Function;
+  class StringRef;
+}
+
+namespace hlsl {
+  // Lowers HLSL extensions from HL operation to DXIL operation.
+  class ExtensionLowering {
+  public:
+    // Strategy used for lowering extensions.
+    enum class Strategy {
+      Unknown,        // Do not know how to lower. This is an error condition.
+      NoTranslation,  // Propagate the call arguments as is down to dxil.
+      Replicate,      // Scalarize the vector arguments and replicate the call.
+      Pack,           // Convert the vector arguments into structs.
+    };
+
+    // Create the lowering using the given strategy and custom codegen helper.
+    ExtensionLowering(llvm::StringRef strategy, HLSLExtensionsCodegenHelper *helper);
+    ExtensionLowering(Strategy strategy, HLSLExtensionsCodegenHelper *helper);
+
+    // Translate the HL op call to a DXIL op call.
+    // Returns a new value if translation was successful.
+    // Returns nullptr if translation failed or made no changes.
+    llvm::Value *Translate(llvm::CallInst *CI);
+    
+    // Translate the strategy string to an enum. The strategy string is
+    // added as a custom attribute on the high level extension function.
+    // It is translated as follows:
+    //  "r" -> Replicate
+    //  "n" -> NoTranslation
+    //  "p" -> Pack
+    static Strategy GetStrategy(llvm::StringRef strategy);
+
+    // Translate the strategy enum into a name. This is the inverse of the
+    // GetStrategy() function.
+    static llvm::StringRef GetStrategyName(Strategy strategy);
+
+    // Get the name that will be used for the extension function call after
+    // lowering.
+    std::string GetExtensionName(llvm::CallInst *CI);
+
+  private:
+    Strategy m_strategy;
+    HLSLExtensionsCodegenHelper *m_helper;
+
+    llvm::Value *Unknown(llvm::CallInst *CI);
+    llvm::Value *NoTranslation(llvm::CallInst *CI);
+    llvm::Value *Replicate(llvm::CallInst *CI);
+    llvm::Value *Pack(llvm::CallInst *CI);
+
+    // Translate the HL call by replicating the call for each vector element.
+    //
+    // For example,
+    //
+    //    <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
+    //    ==>
+    //    %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
+    //    %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
+    //    <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
+    //    <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
+    //
+    // You can then RAUW %r with %r.v.2. The RAUW is not done by the translate function.
+    static llvm::Value *TranslateReplicating(llvm::CallInst *CI, llvm::Function *ReplicatedFunction);
+  };
+}

+ 11 - 2
include/dxc/HLSL/HLOperations.h

@@ -26,6 +26,7 @@ namespace hlsl {
 
 enum class HLOpcodeGroup {
   NotHL,
+  HLExtIntrinsic,
   HLIntrinsic,
   HLCast,
   HLInit,
@@ -109,11 +110,13 @@ enum class HLMatLoadStoreOpcode {
 
 extern const char * const HLPrefix;
 
-bool IsHLOp(llvm::Function *F);
-HLOpcodeGroup GetHLOpcodeGroupByAttr(llvm::Function *F);
+HLOpcodeGroup GetHLOpcodeGroup(llvm::Function *F);
 HLOpcodeGroup GetHLOpcodeGroupByName(llvm::Function *F);
+llvm::StringRef GetHLOpcodeGroupNameByAttr(llvm::Function *F);
+llvm::StringRef GetHLLowerStrategy(llvm::Function *F);
 unsigned  GetHLOpcode(llvm::CallInst *CI);
 unsigned  GetRowMajorOpcode(HLOpcodeGroup group, unsigned opcode);
+void SetHLLowerStrategy(llvm::Function *F, llvm::StringRef S);
 
 // For intrinsic opcode.
 bool HasUnsignedOpcode(unsigned opcode);
@@ -323,6 +326,12 @@ const unsigned kWaveAllEqualValueOpIdx = 1;
 llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
                                       llvm::FunctionType *funcTy,
                                       HLOpcodeGroup group, unsigned opcode);
+llvm::Function *GetOrCreateHLFunction(llvm::Module &M,
+                                      llvm::FunctionType *funcTy,
+                                      HLOpcodeGroup group,
+                                      llvm::StringRef *groupName,
+                                      llvm::StringRef *fnName,
+                                      unsigned opcode);
 
 llvm::Function *GetOrCreateHLFunctionWithBody(llvm::Module &M,
                                               llvm::FunctionType *funcTy,

+ 70 - 0
include/dxc/HLSL/HLSLExtensionsCodegenHelper.h

@@ -0,0 +1,70 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLSLExtensionsCodegenHelper.h                                             //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// Licensed under the MIT license. See COPYRIGHT in the project root for     //
+// full license information.                                                 //
+//                                                                           //
+// Codegen support for hlsl extensions.                                      //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+#include <vector>
+#include <string>
+
+namespace llvm {
+class CallInst;
+class Value;
+class Module;
+}
+
+namespace hlsl {
+
+// Provide DXIL codegen support for private HLSL extensions.
+// The HLSL extension mechanism has hooks for two cases:
+//
+//  1. You can mark certain defines as "semantic" defines which
+//     will be preserved as metadata in the final DXIL.
+//  2. You can add new HLSL intrinsic functions.
+//
+// This class provides an interface for generating the DXIL bitcode
+// needed for the two types of extensions above.
+//  
+class HLSLExtensionsCodegenHelper {
+public:
+  // Used to indicate a semantic define was used incorrectly.
+  // Since semantic defines have semantic meaning it is possible
+  // that a programmer can use them incorrectly. This class provides
+  // a way to signal the error to the programmer. Semantic define
+  // errors will be propagated as errors to the clang frontend.
+  class SemanticDefineError {
+  public:
+    enum class Level { Warning, Error };
+    SemanticDefineError(unsigned location, Level level, const std::string &message)
+    :  m_location(location)
+    ,  m_level(level)
+    ,  m_message(message)
+    { }
+
+    unsigned Location() const { return m_location; }
+    bool IsWarning() const { return m_level == Level::Warning; }
+    const std::string &Message() const { return m_message; }
+
+  private:
+    unsigned m_location; // Use an encoded clang::SourceLocation to avoid a clang include dependency.
+    Level m_level;
+    std::string m_message;
+  };
+  typedef std::vector<SemanticDefineError> SemanticDefineErrorList;
+
+  // Write semantic defines as metadata in the module.
+  virtual SemanticDefineErrorList WriteSemanticDefines(llvm::Module *M) = 0;
+
+  // Get the name to use for the dxil intrinsic function.
+  virtual std::string GetIntrinsicName(unsigned opcode) = 0;
+
+  // Virtual destructor.
+  virtual ~HLSLExtensionsCodegenHelper() {};
+};
+}

+ 133 - 0
include/dxc/Support/DxcLangExtensionsHelper.h

@@ -13,6 +13,17 @@
 #define __DXCLANGEXTENSIONSHELPER_H__
 
 #include "dxc/Support/Unicode.h"
+#include "dxc/Support/FileIOHelper.h"
+#include <vector>
+
+namespace llvm {
+class raw_string_ostream;
+class CallInst;
+class Value;
+}
+namespace clang {
+class CompilerInstance;
+}
 
 namespace hlsl {
 
@@ -22,6 +33,8 @@ private:
   llvm::SmallVector<std::string, 2> m_semanticDefineExclusions;
   llvm::SmallVector<std::string, 2> m_defines;
   llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2> m_intrinsicTables;
+  CComPtr<IDxcSemanticDefineValidator> m_semanticDefineValidator;
+  std::string m_semanticDefineMetaDataName;
 
   HRESULT STDMETHODCALLTYPE RegisterIntoVector(LPCWSTR name, llvm::SmallVector<std::string, 2>& here)
   {
@@ -42,6 +55,7 @@ public:
   const llvm::SmallVector<std::string, 2>& GetSemanticDefineExclusions() const { return m_semanticDefineExclusions; }
   const llvm::SmallVector<std::string, 2>& GetDefines() const { return m_defines; }
   llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2>& GetIntrinsicTables(){ return m_intrinsicTables; }
+  const std::string &GetSemanticDefineMetadataName() { return m_semanticDefineMetaDataName; }
 
   HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name)
   {
@@ -78,6 +92,95 @@ public:
     CATCH_CPP_RETURN_HRESULT();
   }
 
+  // Set the validator used to validate semantic defines.
+  // Only one validator stored and used to run validation.
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) {
+    if (pValidator == nullptr)
+      return E_POINTER;
+
+    m_semanticDefineValidator = pValidator;
+    return S_OK;
+  }
+
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) {
+    try {
+      m_semanticDefineMetaDataName = name;
+      return S_OK;
+    }
+    CATCH_CPP_RETURN_HRESULT();
+  }
+
+  // Get the name of the dxil intrinsic function.
+  std::string GetIntrinsicName(UINT opcode) {
+    LPCSTR pName = "";
+    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
+      if (SUCCEEDED(table->GetIntrinsicName(opcode, &pName))) {
+        return pName;
+      }
+    }
+
+      return "";
+  }
+
+  // Result of validating a semantic define.
+  // Stores any warning or error messages produced by the validator.
+  // Successful validation means that there are no warning or error messages.
+  struct SemanticDefineValidationResult {
+    std::string Warning;
+    std::string Error;
+
+    bool HasError() { return Error.size() > 0; }
+    bool HasWarning() { return Warning.size() > 0; }
+
+    static SemanticDefineValidationResult Success() {
+      return SemanticDefineValidationResult();
+    }
+  };
+
+  // Use the contained semantic define validator to validate the given semantic define.
+  SemanticDefineValidationResult ValidateSemanticDefine(const std::string &name, const std::string &value) {
+    if (!m_semanticDefineValidator)
+      return SemanticDefineValidationResult::Success();
+
+    // Blobs for getting result from validator. Strings for returning results to caller.
+    CComPtr<IDxcBlobEncoding> pError;
+    CComPtr<IDxcBlobEncoding> pWarning;
+    std::string error;
+    std::string warning;
+
+    // Run semantic define validator.
+    HRESULT result = m_semanticDefineValidator->GetSemanticDefineWarningsAndErrors(name.c_str(), value.c_str(), &pWarning, &pError);
+
+
+    if (FAILED(result)) {
+      // Failure indicates it was not able to even run validation so
+      // we cannot say whether the define is invalid or not. Return a
+      // generic error message about failure to run the validator.
+      error = "failed to run semantic define validator for: ";
+      error.append(name); error.append("="); error.append(value);
+      return SemanticDefineValidationResult{ warning, error };
+    }
+
+    // Define a  little function to convert encoded blob into a string.
+    auto GetErrorAsString = [&name](const CComPtr<IDxcBlobEncoding> &pBlobString) -> std::string {
+      CComPtr<IDxcBlobEncoding> pUTF8BlobStr;
+      if (SUCCEEDED(hlsl::DxcGetBlobAsUtf8(pBlobString, &pUTF8BlobStr)))
+        return std::string(static_cast<char*>(pUTF8BlobStr->GetBufferPointer()), pUTF8BlobStr->GetBufferSize());
+      else
+        return std::string("invalid semantic define " + name);
+    };
+
+    // Check to see if any warnings or errors were produced.
+    if (pError && pError->GetBufferSize()) {
+      error = GetErrorAsString(pError);
+    }
+    if (pWarning && pWarning->GetBufferSize()) {
+      warning = GetErrorAsString(pWarning);
+    }
+
+    return SemanticDefineValidationResult{ warning, error };
+  }
+
   __override void SetupSema(clang::Sema &S) {
     clang::ExternalASTSource *astSource = S.getASTContext().getExternalSource();
     if (clang::ExternalSemaSource *externalSema =
@@ -93,6 +196,14 @@ public:
       PPOpts.addMacroDef(llvm::StringRef(define.c_str()));
     }
   }
+
+  __override DxcLangExtensionsHelper *GetDxcLangExtensionsHelper() {
+    return this;
+  }
+ 
+  DxcLangExtensionsHelper()
+  : m_semanticDefineMetaDataName("hlsl.semdefs")
+  {}
 };
 
 // Use this macro to embed an implementation that will delegate to a field.
@@ -110,6 +221,28 @@ public:
   __override HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) { \
     return (_helper_field_).RegisterDefine(name); \
   } \
+  __override HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) { \
+    return (_helper_field_).SetSemanticDefineValidator(pValidator); \
+  } \
+  __override HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) { \
+    return (_helper_field_).SetSemanticDefineMetaDataName(name); \
+  } \
+
+// A parsed semantic define is a semantic define that has actually been
+// parsed by the compiler. It has a name (required), a value (could be
+// the empty string), and a location. We use an encoded clang::SourceLocation
+// for the location to avoid a clang include dependency.
+struct ParsedSemanticDefine{
+  std::string Name;
+  std::string Value;
+  unsigned Location;
+};
+typedef std::vector<ParsedSemanticDefine> ParsedSemanticDefineList;
+
+// Return the collection of semantic defines parsed by the compiler instance.
+ParsedSemanticDefineList
+  CollectSemanticDefinesParsedByCompiler(clang::CompilerInstance &compiler,
+                                         _In_ DxcLangExtensionsHelper *helper);
 
 } // namespace hlsl
 

+ 25 - 1
include/dxc/dxcapi.internal.h

@@ -113,7 +113,6 @@ struct HLSL_INTRINSIC {
 
 ///////////////////////////////////////////////////////////////////////////////
 // Interfaces.
-
 struct __declspec(uuid("f0d4da3f-f863-4660-b8b4-dfd94ded6215"))
 IDxcIntrinsicTable : public IUnknown
 {
@@ -123,6 +122,25 @@ public:
     LPCWSTR typeName, LPCWSTR functionName,
     const HLSL_INTRINSIC** pIntrinsic,
     _Inout_ UINT64* pLookupCookie) = 0;
+
+  // Get the lowering strategy for an hlsl extension intrinsic.
+  virtual HRESULT STDMETHODCALLTYPE GetLoweringStrategy(UINT opcode, LPCSTR *pStrategy) = 0;
+  
+  // Callback to support custom naming of hlsl extension intrinsic functions in dxil.
+  // Return the empty string to get the default intrinsic name, which is the mangled
+  // name of the high level intrinsic function.
+  //
+  // Overloaded intrinsics are supported by use of an overload place holder in the
+  // name. The string "$o" in the name will be replaced by the return type of the
+  // intrinsic.
+  virtual HRESULT STDMETHODCALLTYPE GetIntrinsicName(UINT opcode, LPCSTR *pName) = 0;
+};
+
+struct __declspec(uuid("1d063e4f-515a-4d57-a12a-431f6a44cfb9"))
+IDxcSemanticDefineValidator : public IUnknown
+{
+public:
+  virtual HRESULT STDMETHODCALLTYPE GetSemanticDefineWarningsAndErrors(LPCSTR pName, LPCSTR pValue, IDxcBlobEncoding **ppWarningBlob, IDxcBlobEncoding **ppErrorBlob) = 0;
 };
 
 struct __declspec(uuid("282a56b4-3f56-4360-98c7-9ea04a752272"))
@@ -140,6 +158,12 @@ public:
   virtual HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) = 0;
   /// <summary>Registers a table of built-in intrinsics.</summary>
   virtual HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable* pTable) = 0;
+  /// <summary>Sets an (optional) validator for parsed semantic defines.</summary>
+  /// This provides a hook to check that the semantic defines present in the source
+  /// contain valid data. One validator is used to validate all parsed semantic defines.
+  virtual HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) = 0;
+  /// <summary>Sets the name for the root metadata node used in DXIL to hold the semantic defines.</summary>
+  virtual HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) = 0;
 };
 
 struct __declspec(uuid("454b764f-3549-475b-958c-a7a6fcd05fbc"))

+ 5 - 0
include/llvm/Transforms/IPO/PassManagerBuilder.h

@@ -16,6 +16,10 @@
 
 #include <vector>
 
+namespace hlsl {
+  class HLSLExtensionsCodegenHelper;
+}
+
 namespace llvm {
 class Pass;
 class TargetLibraryInfoImpl;
@@ -122,6 +126,7 @@ public:
   bool MergeFunctions;
   bool PrepareForLTO;
   bool HLSLHighLevel = false; // HLSL Change
+  hlsl::HLSLExtensionsCodegenHelper *HLSLExtensionsCodeGen = nullptr; // HLSL Change
 
 private:
   /// ExtensionList - This is list of all of the extensions that are registered.

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -28,6 +28,7 @@ add_llvm_library(LLVMHLSL
   HLModule.cpp
   HLOperations.cpp
   HLOperationLower.cpp
+  HLOperationLowerExtension.cpp
   HLResource.cpp
   ReducibilityAnalysis.cpp
   WaveSensitivityAnalysis.cpp

+ 11 - 4
lib/HLSL/DxilGenerationPass.cpp

@@ -319,14 +319,19 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {
 class DxilGenerationPass : public ModulePass {
   HLModule *m_pHLModule;
   bool m_HasDbgInfo;
+  HLSLExtensionsCodegenHelper *m_extensionsCodegenHelper;
 
 public:
   static char ID; // Pass identification, replacement for typeid
   explicit DxilGenerationPass(bool NoOpt = false)
-      : ModulePass(ID), m_pHLModule(nullptr), NotOptimized(NoOpt) {}
+      : ModulePass(ID), m_pHLModule(nullptr), NotOptimized(NoOpt), m_extensionsCodegenHelper(nullptr) {}
 
   const char *getPassName() const override { return "DXIL Generator"; }
 
+  void SetExtensionsHelper(HLSLExtensionsCodegenHelper *helper) {
+    m_extensionsCodegenHelper = helper;
+  }
+
   bool runOnModule(Module &M) override {
     m_pHLModule = &M.GetOrCreateHLModule();
     const ShaderModel *SM = m_pHLModule->GetShaderModel();
@@ -2217,7 +2222,7 @@ void DxilGenerationPass::GenerateDxilOperations(
       func->eraseFromParent();
   }
 
-  TranslateBuiltinOperations(*m_pHLModule, handleMap);
+  TranslateBuiltinOperations(*m_pHLModule, handleMap, m_extensionsCodegenHelper);
 
   if (pSM->IsGS())
     GenerateStreamOutputOperations();
@@ -2292,8 +2297,10 @@ void DxilGenerationPass::TranslatePreciseAttribute() {
 
 char DxilGenerationPass::ID = 0;
 
-ModulePass *llvm::createDxilGenerationPass(bool NotOptimized) {
-  return new DxilGenerationPass(NotOptimized);
+ModulePass *llvm::createDxilGenerationPass(bool NotOptimized, hlsl::HLSLExtensionsCodegenHelper *extensionsHelper) {
+  DxilGenerationPass *dxilPass = new DxilGenerationPass(NotOptimized);
+  dxilPass->SetExtensionsHelper(extensionsHelper);
+  return dxilPass;
 }
 
 INITIALIZE_PASS(DxilGenerationPass, "dxilgen", "HLSL DXIL Generation", false, false)

+ 38 - 3
lib/HLSL/HLOperationLower.cpp

@@ -14,6 +14,7 @@
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLOperationLower.h"
+#include "dxc/HLSL/HLOperationLowerExtension.h"
 #include "dxc/HLSL/HLOperations.h"
 #include "dxc/HlslIntrinsicOp.h"
 
@@ -5681,7 +5682,6 @@ void TranslateSubscriptOperation(Function *F, HLOperationLowerHelper &helper,  H
 void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
                                hlsl::HLOpcodeGroup group, HLObjectOperationLowerHelper *pObjHelper) {
   if (group == HLOpcodeGroup::HLIntrinsic) {
-
     // map to dxil operations
     for (auto U = F->user_begin(); U != F->user_end();) {
       Value *User = *(U++);
@@ -5723,11 +5723,40 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
   }
 }
 
+static void TranslateHLExtension(Function *F, HLSLExtensionsCodegenHelper *helper) {
+  // Find all calls to the function F.
+  // Store the calls in a vector for now to be replaced the loop below.
+  // We use a two step "find then replace" to avoid removing uses while
+  // iterating.
+  SmallVector<CallInst *, 8> CallsToReplace;
+  for (User *U : F->users()) {
+    if (CallInst *CI = dyn_cast<CallInst>(U)) {
+      CallsToReplace.push_back(CI);
+    }
+  }
+
+  // Get the lowering strategy to use for this intrinsic.
+  llvm::StringRef LowerStrategy = GetHLLowerStrategy(F);
+  ExtensionLowering lower(LowerStrategy, helper);
+
+  // Replace all calls that were successfully translated.
+  for (CallInst *CI : CallsToReplace) {
+      Value *Result = lower.Translate(CI);
+      if (Result && Result != CI) {
+        CI->replaceAllUsesWith(Result);
+        CI->eraseFromParent();
+      }
+  }
+}
+
+
 namespace hlsl {
 
 void TranslateBuiltinOperations(
     HLModule &HLM,
-    std::unordered_map<llvm::Instruction *, llvm::Value *> &handleMap) {
+    std::unordered_map<llvm::Instruction *, llvm::Value *> &handleMap,
+    HLSLExtensionsCodegenHelper *extCodegenHelper
+  ) {
   HLOperationLowerHelper helper(HLM);
 
   HLObjectOperationLowerHelper objHelper = {handleMap, HLM};
@@ -5738,8 +5767,14 @@ void TranslateBuiltinOperations(
     if (!F->isDeclaration()) {
       continue;
     }
-    hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
+    hlsl::HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
     if (group == HLOpcodeGroup::NotHL) {
+      // Nothing to do.
+      continue;
+    }
+    if (group == HLOpcodeGroup::HLExtIntrinsic) {
+      // TODO: consider handling extensions to object methods
+      TranslateHLExtension(F, extCodegenHelper);
       continue;
     }
     TranslateHLBuiltinOperation(F, helper, group, &objHelper);

+ 527 - 0
lib/HLSL/HLOperationLowerExtension.cpp

@@ -0,0 +1,527 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLOperationLowerExtension.cpp                                             //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// Licensed under the MIT license. See COPYRIGHT in the project root for     //
+// full license information.                                                 //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLOperationLowerExtension.h"
+
+#include "dxc/HLSL/DxilModule.h"
+#include "dxc/HLSL/DxilOperations.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLModule.h"
+#include "dxc/HLSL/HLOperationLower.h"
+#include "dxc/HLSL/HLOperations.h"
+#include "dxc/HlslIntrinsicOp.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/raw_os_ostream.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+// Decode a lowering-strategy string. Only the first character is examined;
+// an empty or unrecognized string maps to Strategy::Unknown.
+ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
+  if (strategy.empty())
+    return Strategy::Unknown;
+
+  const char key = strategy[0];
+  if (key == 'n')
+    return Strategy::NoTranslation;
+  if (key == 'r')
+    return Strategy::Replicate;
+  if (key == 'p')
+    return Strategy::Pack;
+  return Strategy::Unknown;
+}
+
+// Return the one-letter name of a strategy ("n", "r" or "p"), or "?" for any
+// other value.
+llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
+  if (strategy == Strategy::NoTranslation)
+    return "n";
+  if (strategy == Strategy::Replicate)
+    return "r";
+  if (strategy == Strategy::Pack)
+    return "p";
+  return "?";
+}
+
+// Construct a lowering object for an already-decoded strategy.
+ExtensionLowering::ExtensionLowering(Strategy strategy,
+                                     HLSLExtensionsCodegenHelper *helper)
+    : m_strategy(strategy), m_helper(helper) {}
+
+// Construct from a strategy string; decoding is delegated to GetStrategy.
+ExtensionLowering::ExtensionLowering(StringRef strategy,
+                                     HLSLExtensionsCodegenHelper *helper)
+    : ExtensionLowering(GetStrategy(strategy), helper) {}
+
+// Dispatch the call to the lowering routine selected by m_strategy.
+// Any strategy other than the three known ones is routed to Unknown().
+llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
+  if (m_strategy == Strategy::NoTranslation)
+    return NoTranslation(CI);
+  if (m_strategy == Strategy::Replicate)
+    return Replicate(CI);
+  if (m_strategy == Strategy::Pack)
+    return Pack(CI);
+  return Unknown(CI);
+}
+
+// Fallback used by Translate for an unrecognized strategy. Trips an assert
+// when assertions are enabled; otherwise returns nullptr, which callers treat
+// as "call not translated".
+llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
+  assert(false && "unknown translation strategy");
+  return nullptr;
+}
+
+// Interface to describe how to translate types from HL-dxil to dxil.
+// Implementations decide the return type and each argument type of the
+// lowered function that replaces an extension intrinsic call.
+class FunctionTypeTranslator {
+public:
+  // Polymorphic interface: give it a virtual destructor so destroying a
+  // translator through a base pointer is well defined.
+  virtual ~FunctionTypeTranslator() {}
+
+  // Return type of the lowered function for the call CI. The translators in
+  // this file return nullptr to signal that the call cannot be lowered.
+  virtual Type *TranslateReturnType(CallInst *CI) = 0;
+
+  // Type of a lowered argument whose original type is OrigArgType.
+  virtual Type *TranslateArgumentType(Type *OrigArgType) = 0;
+};
+
+// Creates the declaration of the lowered (low-level dxil) function that
+// replaces an extension call, translating its types with a
+// FunctionTypeTranslator.
+class FunctionTranslator {
+public:
+  // Get (or declare) the lowered function for CI using TypeTranslator.
+  // Returns nullptr when the translator rejects the call's return type.
+  template <typename TypeTranslator>
+  static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
+    TypeTranslator types;
+    return FunctionTranslator(types, lower).GetLoweredFunction(CI);
+  }
+
+private:
+  FunctionTypeTranslator &m_typeTranslator;
+  ExtensionLowering &m_lower;
+
+  FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
+      : m_typeTranslator(typeTranslator), m_lower(lower) {}
+
+  Function *GetLoweredFunction(CallInst *CI) {
+    // Get the return type of the lowered function; nullptr means "reject".
+    Type *LoweredRetTy = m_typeTranslator.TranslateReturnType(CI);
+    if (LoweredRetTy == nullptr)
+      return nullptr;
+
+    // Get the function type of the lowered function.
+    FunctionType *LoweredFnTy = GetFunctionType(CI, LoweredRetTy);
+    if (LoweredFnTy == nullptr)
+      return nullptr;
+
+    // Declare (or find) the lowered function in the call's module.
+    AttributeSet attrs = GetAttributeSet(CI);
+    std::string loweredName = m_lower.GetExtensionName(CI);
+    Module *M = CI->getModule();
+    return cast<Function>(M->getOrInsertFunction(loweredName, LoweredFnTy, attrs));
+  }
+
+  // Build the lowered function type: RetTy plus one translated parameter
+  // type per argument of the original call.
+  FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
+    const unsigned NumArgs = CI->getNumArgOperands();
+    SmallVector<Type *, 10> LoweredParams;
+    LoweredParams.reserve(NumArgs);
+    for (unsigned ArgNo = 0; ArgNo != NumArgs; ++ArgNo) {
+      Type *SrcTy = CI->getArgOperand(ArgNo)->getType();
+      LoweredParams.push_back(m_typeTranslator.TranslateArgumentType(SrcTy));
+    }
+
+    return FunctionType::get(RetTy, LoweredParams, /*isVarArg=*/false);
+  }
+
+  // Carry the memory-behavior attributes of the original callee over to the
+  // lowered function.
+  AttributeSet GetAttributeSet(CallInst *CI) {
+    Function *Callee = CI->getCalledFunction();
+    const Attribute::AttrKind Preserved[] = {
+        Attribute::AttrKind::ReadOnly,
+        Attribute::AttrKind::ReadNone,
+        Attribute::AttrKind::ArgMemOnly,
+    };
+
+    AttributeSet attrs;
+    for (Attribute::AttrKind Kind : Preserved) {
+      if (Callee->hasFnAttribute(Kind))
+        attrs = attrs.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, Kind);
+    }
+    return attrs;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// NoTranslation Lowering.
+// Identity type translation: the lowered function keeps the original return
+// type and argument types.
+class NoTranslationTypeTranslator : public FunctionTypeTranslator {
+  Type *TranslateReturnType(CallInst *CI) override { return CI->getType(); }
+
+  Type *TranslateArgumentType(Type *OrigArgType) override { return OrigArgType; }
+};
+
+// "No translation" lowering: call the lowered function with the same return
+// and argument types as the original call, forwarding every argument
+// unchanged. Returns the new call, or nullptr when the lowered function could
+// not be obtained.
+llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
+  Function *NoTranslationFunction = FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
+  if (!NoTranslationFunction)
+    return nullptr;
+
+  // The builder is positioned at the original call.
+  IRBuilder<> builder(CI);
+  SmallVector<Value *, 8> args(CI->arg_operands().begin(), CI->arg_operands().end());
+  return builder.CreateCall(NoTranslationFunction, args);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Replicated Lowering.
+// Sentinel returned by GetReplicatedVectorSize when a call has no vector
+// arguments or its vector arguments do not share one element count.
+enum {
+  NO_COMMON_VECTOR_SIZE = 0xFFFFFFFF,
+};
+// Find the vector size that will be used for replication: the element count
+// shared by every vector argument of the call. The call is replicated once
+// per element. Returns NO_COMMON_VECTOR_SIZE when the call has no vector
+// arguments or when their element counts disagree.
+static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
+  unsigned sharedSize = NO_COMMON_VECTOR_SIZE;
+  for (Value *arg : CI->arg_operands()) {
+    Type *argTy = arg->getType();
+    if (!argTy->isVectorTy())
+      continue;
+
+    const unsigned numElts = argTy->getVectorNumElements();
+    const bool haveSize = sharedSize != NO_COMMON_VECTOR_SIZE;
+    if (haveSize && sharedSize != numElts) {
+      // Inconsistent vector sizes; need a different strategy.
+      return NO_COMMON_VECTOR_SIZE;
+    }
+    sharedSize = numElts;
+  }
+
+  return sharedSize;
+}
+
+// Type translation for the replicate strategy: vectors are replaced by their
+// element type because the call is issued once per vector element.
+class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
+  // The replicated function returns void or one element of the original
+  // vector result. Returns nullptr (reject) when the call has no common
+  // vector size, when the result is neither void nor a vector, or when the
+  // result vector's element count differs from the replication count.
+  Type *TranslateReturnType(CallInst *CI) override {
+    const unsigned replicaCount = GetReplicatedVectorSize(CI);
+    if (replicaCount == NO_COMMON_VECTOR_SIZE)
+      return nullptr;
+
+    Type *origRetTy = CI->getType();
+    if (origRetTy->isVoidTy())
+      return origRetTy;
+    if (!origRetTy->isVectorTy())
+      return nullptr;
+    if (origRetTy->getVectorNumElements() != replicaCount)
+      return nullptr;
+    return origRetTy->getVectorElementType();
+  }
+
+  // Vector arguments are passed one element at a time; all other arguments
+  // keep their type.
+  Type *TranslateArgumentType(Type *OrigArgType) override {
+    return OrigArgType->isVectorTy() ? OrigArgType->getVectorElementType()
+                                     : OrigArgType;
+  }
+};
+
+// Implements the replicate lowering for one call: the call is re-issued once
+// per vector element with each vector argument replaced by that element, and
+// the scalar results are gathered back into a vector.
+class ReplicateCall {
+public:
+  // CI is the original call; ReplicatedFunction is the scalarized function
+  // called once per element. The replication count is recomputed from CI and
+  // must not be NO_COMMON_VECTOR_SIZE.
+  ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
+    : m_CI(CI)
+    , m_ReplicatedFunction(ReplicatedFunction)
+    , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
+    , m_ScalarizeArgIdx()
+    , m_Args(CI->getNumArgOperands())
+    , m_ReplicatedCalls(m_numReplicatedCalls)
+    , m_Builder(CI)
+  {
+    assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
+  }
+
+  // Emit the replicated calls and return the value that replaces the
+  // original call.
+  Value *Generate() {
+    CollectReplicatedArguments();
+    CreateReplicatedCalls();
+    Value *retVal = GetReturnValue();
+    return retVal;
+  }
+
+private:
+  CallInst *m_CI;                              // Original call being lowered.
+  Function &m_ReplicatedFunction;              // Scalarized callee.
+  unsigned m_numReplicatedCalls;               // One call per vector element.
+  SmallVector<unsigned, 10> m_ScalarizeArgIdx; // Positions of vector arguments.
+  SmallVector<Value *, 10> m_Args;             // Argument list reused for every call.
+  SmallVector<Value *, 10> m_ReplicatedCalls;  // Result of each replicated call.
+  IRBuilder<> m_Builder;                       // Positioned at the original call.
+
+  // Collect replicated arguments.
+  // For non-vector arguments we can add them to the args list directly.
+  // These args will be shared by each replicated call. For the vector
+  // arguments we remember the position it will go in the argument list.
+  // We will fill in the vector args below when we replicate the call
+  // (once for each vector lane).
+  void CollectReplicatedArguments() {
+    for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
+      Type *Ty = m_CI->getArgOperand(i)->getType();
+      if (Ty->isVectorTy()) {
+        m_ScalarizeArgIdx.push_back(i);
+      }
+      else {
+        m_Args[i] = m_CI->getArgOperand(i);
+      }
+    }
+  }
+
+  // Create replicated calls.
+  // Replicate the call once for each element of the replicated vector size.
+  void CreateReplicatedCalls() {
+    for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
+      // Overwrite the vector argument slots with the element for this lane.
+      for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
+        unsigned argIdx = m_ScalarizeArgIdx[i];
+        Value *arg = m_CI->getArgOperand(argIdx);
+        m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
+      }
+      Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
+      m_ReplicatedCalls[vecIdx] = EltOP;
+    }
+  }
+
+  // Get the final replicated value.
+  // If the function is a void type then return (arbitrarily) the last call.
+  // We do not return nullptr because that indicates a failure to replicate.
+  // If the function is a vector type then aggregate all of the replicated
+  // call values into a new vector.
+  Value *GetReturnValue() {
+    if (m_CI->getType()->isVoidTy())
+      return m_ReplicatedCalls.back();
+
+      Value *retVal = llvm::UndefValue::get(m_CI->getType());
+      for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
+        retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);
+
+    return retVal;
+  }
+};
+
+// Replicate CI across ReplicatedFunction. Returns nullptr when no replicated
+// function is available, otherwise the value that replaces the call.
+Value *ExtensionLowering::TranslateReplicating(CallInst *CI, Function *ReplicatedFunction) {
+  return ReplicatedFunction
+             ? ReplicateCall(CI, *ReplicatedFunction).Generate()
+             : nullptr;
+}
+
+// Replicate strategy entry point: obtain the scalarized function for CI and
+// replicate the call over it.
+Value *ExtensionLowering::Replicate(CallInst *CI) {
+  return TranslateReplicating(
+      CI, FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this));
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Packed Lowering.
+// Implements the pack lowering for one call: each vector argument is repacked
+// into a literal struct with one field per element, the packed function is
+// called, and a struct result that stands in for a vector return value is
+// unpacked back into a vector.
+class PackCall {
+public:
+  // CI is the original call; PackedFunction is the lowered function whose
+  // vector parameters and result have been replaced by structs.
+  PackCall(CallInst *CI, Function &PackedFunction)
+    : m_CI(CI)
+    , m_packedFunction(PackedFunction)
+    , m_builder(CI)
+  {}
+
+  // Emit the packed call and return the value that replaces the original
+  // call's result.
+  Value *Generate() {
+    SmallVector<Value *, 10> args;
+    PackArgs(args);
+    Value *result = CreateCall(args);
+    return UnpackResult(result);
+  }
+
+  // Map a vector type <N x T> to the literal struct type { T, T, ... } with
+  // N fields. Also used by the packed type translator.
+  static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
+    assert(vecTy->isVectorTy());
+    Type *elementTy = vecTy->getVectorElementType();
+    unsigned numElements = vecTy->getVectorNumElements();
+    SmallVector<Type *, 4> elements;
+    for (unsigned i = 0; i < numElements; ++i)
+      elements.push_back(elementTy);
+
+    return StructType::get(vecTy->getContext(), elements);
+  }
+
+private:
+  CallInst *m_CI;
+  Function &m_packedFunction;
+  IRBuilder<> m_builder;
+
+  // Build the argument list for the packed call: vector arguments are
+  // converted to structs, everything else is forwarded unchanged.
+  void PackArgs(SmallVectorImpl<Value*> &args) {
+    args.clear();
+    for (Value *arg : m_CI->arg_operands()) {
+      if (arg->getType()->isVectorTy())
+        arg = PackVectorIntoStruct(m_builder, arg);
+      args.push_back(arg);
+    }
+  }
+
+  Value *CreateCall(const SmallVectorImpl<Value*> &args) {
+    return m_builder.CreateCall(&m_packedFunction, args);
+  }
+
+  // Convert the packed struct result back into a vector, but only when the
+  // original call returned a vector. A struct produced by an intrinsic that
+  // itself returns a struct is passed through untouched, so the replacement
+  // value keeps the type of the original call.
+  Value *UnpackResult(Value *result) {
+    if (m_CI->getType()->isVectorTy() && result->getType()->isStructTy()) {
+      result = PackStructIntoVector(m_builder, result);
+    }
+    return result;
+  }
+
+  // Inverse of ConvertVectorTypeToStructType. Uses the first field as the
+  // element type, so the struct must be non-empty.
+  static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
+    assert(structTy->isStructTy());
+    return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
+  }
+
+  // Copy each vector element into the matching field of a new struct value.
+  static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
+    StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
+    Value *packed = UndefValue::get(structTy);
+
+    unsigned numElements = structTy->getStructNumElements();
+    for (unsigned i = 0; i < numElements; ++i) {
+      Value *element = builder.CreateExtractElement(vec, i);
+      packed = builder.CreateInsertValue(packed, element, { i });
+    }
+
+    return packed;
+  }
+
+  // Copy each struct field into the matching element of a new vector value.
+  static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
+    Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
+    Value *packed = UndefValue::get(vecTy);
+
+    unsigned numElements = vecTy->getVectorNumElements();
+    for (unsigned i = 0; i < numElements; ++i) {
+      Value *element = builder.CreateExtractValue(strukt, i);
+      packed = builder.CreateInsertElement(packed, element, i);
+    }
+
+    return packed;
+  }
+};
+
+// Type translation for the pack strategy: vector types become literal structs
+// (see PackCall::ConvertVectorTypeToStructType); other types are unchanged.
+class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
+  Type *TranslateReturnType(CallInst *CI) override {
+    return TranslateIfVector(CI->getType());
+  }
+
+  Type *TranslateArgumentType(Type *OrigArgType) override {
+    return TranslateIfVector(OrigArgType);
+  }
+
+  // Shared by the return and argument paths.
+  Type *TranslateIfVector(Type *ty) {
+    if (!ty->isVectorTy())
+      return ty;
+    return PackCall::ConvertVectorTypeToStructType(ty);
+  }
+};
+
+// Pack strategy entry point: obtain the packed function for CI and emit the
+// packed call. Returns nullptr when the packed function is unavailable.
+Value *ExtensionLowering::Pack(CallInst *CI) {
+  if (Function *PackedFn = FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this))
+    return PackCall(CI, *PackedFn).Generate();
+  return nullptr;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Computing Extension Names.
+
+// Compute the name to use for the intrinsic function call once it is lowered to dxil.
+// First checks to see if we have a custom name from the codegen helper and if not
+// chooses a default name based on the lowering strategy.
+class ExtensionName {
+public:
+  // helper may be null, in which case the default name is always used.
+  ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
+    : m_CI(CI)
+    , m_strategy(strategy)
+    , m_helper(helper)
+  {}
+
+  // Return the custom name when the helper provides a non-empty one,
+  // otherwise the default "<called function name>.<strategy name>".
+  std::string Get() {
+    std::string name;
+    if (m_helper)
+      name = GetCustomExtensionName(m_CI, *m_helper);
+
+    if (!HasCustomExtensionName(name))
+      name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));
+
+    return name;
+  }
+
+private:
+  CallInst *m_CI;
+  ExtensionLowering::Strategy m_strategy;
+  HLSLExtensionsCodegenHelper *m_helper;
+
+  // Ask the helper for the name registered for this call's opcode and expand
+  // the overload marker in it, if present.
+  static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
+    unsigned opcode = GetHLOpcode(CI);
+    std::string name = helper.GetIntrinsicName(opcode);
+    ReplaceOverloadMarkerWithTypeName(name, CI);
+
+    return name;
+  }
+
+  static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
+    return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
+  }
+
+  // An empty name means no custom name was provided. Taken by const
+  // reference so the string is not copied just to test its length.
+  static bool HasCustomExtensionName(const std::string &name) {
+    return !name.empty();
+  }
+
+  // Choose the (return value or argument) type that determines the overload type
+  // for the intrinsic call.
+  // For now we take the return type as the overload. If the return is void we
+  // take the first (non-opcode) argument as the overload type. We could extend the
+  // $o syntax in the extension name to explicitly specify the overload slot (e.g.
+  // $o:3 would say the overload type is determined by parameter 3).
+  static Type *SelectOverloadSlot(CallInst *CI) {
+    Type *ty = CI->getType();
+    if (ty->isVoidTy()) {
+      if (CI->getNumArgOperands() > 1)
+        ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
+    }
+
+    return ty;
+  }
+
+  // The overload type is the scalar type: vectors contribute their element type.
+  static Type *GetOverloadType(CallInst *CI) {
+    Type *ty = SelectOverloadSlot(CI);
+    if (ty->isVectorTy())
+      ty = ty->getVectorElementType();
+
+    return ty;
+  }
+
+  // Render a type using Type::print.
+  static std::string GetTypeName(Type *ty) {
+    std::string typeName;
+    llvm::raw_string_ostream os(typeName);
+    ty->print(os);
+    os.flush();
+    return typeName;
+  }
+
+  static std::string GetOverloadTypeName(CallInst *CI) {
+    Type *ty = GetOverloadType(CI);
+    return GetTypeName(ty);
+  }
+
+  // Find the first occurrence of the overload marker $o and replace it with
+  // the overload type name. Names without the marker are left unchanged.
+  static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
+    const char *OverloadMarker = "$o";
+    const size_t OverloadMarkerLength = 2;
+
+    size_t pos = functionName.find(OverloadMarker);
+    if (pos != std::string::npos) {
+      std::string typeName = GetOverloadTypeName(CI);
+      functionName.replace(pos, OverloadMarkerLength, typeName);
+    }
+  }
+};
+
+// Name of the lowered function for CI under this object's strategy and helper.
+std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
+  return ExtensionName(CI, m_strategy, m_helper).Get();
+}

+ 82 - 40
lib/HLSL/HLOperations.cpp

@@ -25,9 +25,12 @@ namespace hlsl {
 
 const char HLPrefixStr [] = "dx.hl";
 const char * const HLPrefix = HLPrefixStr;
+static const char HLLowerStrategyStr[] = "dx.hlls";
+static const char * const HLLowerStrategy = HLLowerStrategyStr;
 
 static StringRef HLOpcodeGroupNames[]{
     "notHLDXIL",   // NotHL,
+    "<ext>",       // HLExtIntrinsic - should always refer through extension
     "op",          // HLIntrinsic,
     "cast",        // HLCast,
     "init",        // HLInit,
@@ -41,6 +44,7 @@ static StringRef HLOpcodeGroupNames[]{
 
 static StringRef HLOpcodeGroupFullNames[]{
     "notHLDXIL",       // NotHL,
+    "<ext>",           // HLExtIntrinsic - should aways refer through extension
     "dx.hl.op",        // HLIntrinsic,
     "dx.hl.cast",      // HLCast,
     "dx.hl.init",      // HLInit,
@@ -51,58 +55,66 @@ static StringRef HLOpcodeGroupFullNames[]{
     "dx.hl.select",    // HLSelect,
     "numOfHLDXIL",     // NumOfHLOps
 };
-// Check HLOperation by function name.
-bool IsHLOpByName(Function *F) {
-  return F->getName().startswith(HLPrefix);
-}
-// Check HLOperation by attribute.
-bool IsHLOp(llvm::Function *F) {
-  return F->hasFnAttribute(hlsl::HLPrefix);
-}
 
+// Map a group name to its HLOpcodeGroup by looking at its first character
+// (and the second one for names starting with 's'). Returns NotHL for an
+// empty name or an unrecognized first character.
 static HLOpcodeGroup GetHLOpcodeGroupInternal(StringRef group) {
-  switch (group[0]) {
-  case 'o': // op
-    return HLOpcodeGroup::HLIntrinsic;
-  case 'c': // cast
-    return HLOpcodeGroup::HLCast;
-  case 'i': // init
-    return HLOpcodeGroup::HLInit;
-  case 'b': // binaryOp
-    return HLOpcodeGroup::HLBinOp;
-  case 'u': // unaryOp
-    return HLOpcodeGroup::HLUnOp;
-  case 's': // subscript and select
-    switch (group[1]) {
-    case 'u':
-      return HLOpcodeGroup::HLSubscript;
-    case 'e':
-      return HLOpcodeGroup::HLSelect;
+  if (!group.empty()) {
+    switch (group[0]) {
+    case 'o': // op
+      return HLOpcodeGroup::HLIntrinsic;
+    case 'c': // cast
+      return HLOpcodeGroup::HLCast;
+    case 'i': // init
+      return HLOpcodeGroup::HLInit;
+    case 'b': // binaryOp
+      return HLOpcodeGroup::HLBinOp;
+    case 'u': // unaryOp
+      return HLOpcodeGroup::HLUnOp;
+    case 's': // subscript and select
+      // NOTE(review): group[1] is read without checking that the name has a
+      // second character.
+      switch (group[1]) {
+      case 'u':
+        return HLOpcodeGroup::HLSubscript;
+      case 'e':
+        return HLOpcodeGroup::HLSelect;
+      }
+      // NOTE(review): no break/return here, so an 's' name whose second
+      // character is neither 'u' nor 'e' falls through to the 'm' case and
+      // is reported as HLMatLoadStore - confirm this is intended.
+    case 'm': // matldst
+      return HLOpcodeGroup::HLMatLoadStore;
     }
-  case 'm': // matldst
-    return HLOpcodeGroup::HLMatLoadStore;
-  default:
-    return HLOpcodeGroup::NotHL;
   }
+  return HLOpcodeGroup::NotHL;
 }
 // GetHLOpGroup by function name.
 HLOpcodeGroup GetHLOpcodeGroupByName(Function *F) {
   StringRef name = F->getName();
 
-  if (!name.startswith(HLPrefix))
+  if (!name.startswith(HLPrefix)) {
+    // This could be an external intrinsic, but this function
+    // won't recognize those as such. Use GetHLOpcodeGroup
+    // to make that distinction.
     return HLOpcodeGroup::NotHL;
+  }
 
+  // sizeof counts the terminating NUL of HLPrefixStr, so prefixSize is one
+  // more than the prefix length - presumably to also skip the '.' that
+  // follows the prefix (e.g. "dx.hl.op"); confirm.
   const unsigned prefixSize = sizeof(HLPrefixStr);
 
   StringRef group = name.substr(prefixSize);
   return GetHLOpcodeGroupInternal(group);
 }
-// GetHLOpGroup by function attribute.
-HLOpcodeGroup GetHLOpcodeGroupByAttr(llvm::Function *F) {
+
+// Determine the opcode group of F. The group attribute is consulted first:
+// a recognized value gives a built-in group and any other non-empty value
+// marks an extension intrinsic. Without the attribute, fall back to decoding
+// the function name.
+HLOpcodeGroup GetHLOpcodeGroup(llvm::Function *F) {
+  const llvm::StringRef attrGroup = GetHLOpcodeGroupNameByAttr(F);
+  const HLOpcodeGroup fromAttr = GetHLOpcodeGroupInternal(attrGroup);
+  if (fromAttr != HLOpcodeGroup::NotHL)
+    return fromAttr;
+  if (!attrGroup.empty())
+    return HLOpcodeGroup::HLExtIntrinsic;
+  return GetHLOpcodeGroupByName(F);
+}
+
+// Return the string value of F's function attribute named hlsl::HLPrefix,
+// i.e. the group name stored on the function.
+llvm::StringRef GetHLOpcodeGroupNameByAttr(llvm::Function *F) {
   Attribute groupAttr = F->getFnAttribute(hlsl::HLPrefix);
   StringRef group = groupAttr.getValueAsString();
-
-  return GetHLOpcodeGroupInternal(group);
+  return group;
 }
 
 StringRef GetHLOpcodeGroupName(HLOpcodeGroup op) {
@@ -240,7 +252,18 @@ llvm::StringRef GetHLOpcodeName(HLMatLoadStoreOpcode Op) {
   llvm_unreachable("invalid matrix load store operator");
 }
 
+// Read the lowering strategy stored on F as the string value of its
+// HLLowerStrategy function attribute.
+StringRef GetHLLowerStrategy(Function *F) {
+  return F->getFnAttribute(HLLowerStrategy).getValueAsString();
+}
+
+// Record the lowering strategy S on F as the value of its HLLowerStrategy
+// function attribute; GetHLLowerStrategy reads it back.
+void SetHLLowerStrategy(Function *F, StringRef S) {
+  F->addFnAttr(HLLowerStrategy, S);
+}
+
 std::string GetHLFullName(HLOpcodeGroup op, unsigned opcode) {
+  assert(op != HLOpcodeGroup::HLExtIntrinsic && "else table name should be used");
   std::string opName = GetHLOpcodeGroupFullName(op).str() + ".";
 
   switch (op) {
@@ -380,16 +403,35 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
   }
 }
 
-Function *GetOrCreateHLFunction(Module &M,
-                                       FunctionType *funcTy,
-                                       HLOpcodeGroup group, unsigned opcode) {
-  std::string operatorName = GetHLFullName(group, opcode);
-  std::string mangledName = operatorName + ".";
+// Convenience overload for callers that have no group name or function name:
+// forwards to the six-argument overload with both passed as nullptr. That
+// overload asserts on null names for HLExtIntrinsic, so this form is not for
+// extension intrinsics.
+Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
+                                HLOpcodeGroup group, unsigned opcode) {
+  return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);
+}
+
+Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
+                                HLOpcodeGroup group, llvm::StringRef *groupName,
+                                llvm::StringRef *fnName, unsigned opcode) {
+  std::string mangledName;
   raw_string_ostream mangledNameStr(mangledName);
-  funcTy->print(mangledNameStr);
+  if (group == HLOpcodeGroup::HLExtIntrinsic) {
+    assert(groupName && "else intrinsic should have been rejected");
+    assert(fnName && "else intrinsic should have been rejected");
+    mangledNameStr << *groupName;
+    mangledNameStr << '.';
+    mangledNameStr << *fnName;
+  }
+  else {
+    mangledNameStr << GetHLFullName(group, opcode);
+    mangledNameStr << '.';
+    funcTy->print(mangledNameStr);
+  }
+
   mangledNameStr.flush();
 
   Function *F = cast<Function>(M.getOrInsertFunction(mangledName, funcTy));
+  if (group == HLOpcodeGroup::HLExtIntrinsic) {
+    F->addFnAttr(hlsl::HLPrefix, *groupName);
+  }
 
   SetHLFunctionAttribute(F, group, opcode);
 

+ 4 - 4
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -186,7 +186,7 @@ void PassManagerBuilder::populateFunctionPassManager(
 }
 
 // HLSL Change Starts
-static void addHLSLPasses(bool HLSLHighLevel, bool NoOpt, legacy::PassManagerBase &MPM) {
+static void addHLSLPasses(bool HLSLHighLevel, bool NoOpt, hlsl::HLSLExtensionsCodegenHelper *ExtHelper, legacy::PassManagerBase &MPM) {
   // Don't do any lowering if we're targeting high-level.
   if (HLSLHighLevel) {
     MPM.add(createHLEmitMetadataPass());
@@ -208,7 +208,7 @@ static void addHLSLPasses(bool HLSLHighLevel, bool NoOpt, legacy::PassManagerBas
   // Change dynamic indexing vector to array.
   MPM.add(createDynamicIndexingVectorToArrayPass(NoOpt));
 
-  MPM.add(createDxilGenerationPass(NoOpt));
+  MPM.add(createDxilGenerationPass(NoOpt, ExtHelper));
 
   MPM.add(createSimplifyInstPass());
 
@@ -257,7 +257,7 @@ void PassManagerBuilder::populateModulePassManager(
 
     addExtensionsToPM(EP_EnabledOnOptLevel0, MPM);
     // HLSL Change Begins.
-    addHLSLPasses(HLSLHighLevel, true/*NoOpt*/, MPM); // HLSL Change
+    addHLSLPasses(HLSLHighLevel, true/*NoOpt*/, HLSLExtensionsCodeGen, MPM); // HLSL Change
     if (!HLSLHighLevel) {
       MPM.add(createMultiDimArrayToOneDimArrayPass());// HLSL Change
       MPM.add(createDxilCondenseResourcesPass()); // HLSL Change
@@ -277,7 +277,7 @@ void PassManagerBuilder::populateModulePassManager(
     delete Inliner;
     Inliner = nullptr;
   }
-  addHLSLPasses(HLSLHighLevel, false/*NoOpt*/, MPM); // HLSL Change
+  addHLSLPasses(HLSLHighLevel, false/*NoOpt*/, HLSLExtensionsCodeGen, MPM); // HLSL Change
   // HLSL Change Ends
 
   // Add LibraryInfo if we have some.

+ 2 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3403,9 +3403,9 @@ public:
 
     std::deque<Function *> WorkList;
     for (Function &F : M.functions()) {
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(&F);
+      HLOpcodeGroup group = GetHLOpcodeGroup(&F);
       // Skip HL operations.
-      if (group != HLOpcodeGroup::NotHL)
+      if (group != HLOpcodeGroup::NotHL || group == HLOpcodeGroup::HLExtIntrinsic)
         continue;
 
       if (F.isDeclaration()) {

+ 1 - 0
tools/clang/include/clang/AST/HlslTypes.h

@@ -379,6 +379,7 @@ clang::QualType GetHLSLVecElementType(clang::QualType type);
 bool IsIntrinsicOp(const clang::FunctionDecl *FD);
 bool GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
                     llvm::StringRef &group);
+bool GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S);
 
 /// <summary>Adds a function declaration to the specified class record.</summary>
 /// <param name="context">ASTContext that owns declarations.</param>

+ 1 - 1
tools/clang/include/clang/Basic/Attr.td

@@ -705,7 +705,7 @@ def HLSLMaxVertexCount: InheritableAttr {
 }
 def HLSLIntrinsic: InheritableAttr {
   let Spellings = [CXX11<"", "intrinsic", 2015>];
-  let Args = [StringArgument<"group">, IntArgument<"opcode">];
+  let Args = [StringArgument<"group">, StringArgument<"lowering">, IntArgument<"opcode">];
   let Documentation = [Undocumented];
 }
 def HLSLPrecise : InheritableAttr {

+ 3 - 0
tools/clang/include/clang/Frontend/CodeGenOptions.h

@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "dxc/HLSL/HLSLExtensionsCodegenHelper.h" // HLSL change
 
 namespace clang {
 
@@ -192,6 +193,8 @@ public:
   std::vector<std::string> HLSLDefines;
   /// Arguments passed in from command line
   std::vector<std::string> HLSLArguments;
+  /// Helper for generating llvm bitcode for hlsl extensions.
+  std::shared_ptr<hlsl::HLSLExtensionsCodegenHelper> HLSLExtensionsCodegen;
   // HLSL Change Ends
   /// Regular expression to select optimizations for which we should enable
   /// optimization remarks. Transformation passes whose name matches this

+ 2 - 0
tools/clang/include/clang/Frontend/CompilerInstance.h

@@ -30,10 +30,12 @@
 
 // HLSL Change Starts
 namespace hlsl {
+  class DxcLangExtensionsHelper;
   class DxcLangExtensionsHelperApply {
   public:
     virtual void SetupSema(clang::Sema &S) = 0;
     virtual void SetupPreprocessorOptions(clang::PreprocessorOptions &PPOpts) = 0;
+    // Accessor for a DxcLangExtensionsHelper.
+    // NOTE(review): ownership and nullability of the returned pointer are not
+    // visible here - confirm against the implementers.
+    virtual DxcLangExtensionsHelper *GetDxcLangExtensionsHelper() = 0;
   };
 }
 // HLSL Change Ends

+ 11 - 1
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -238,7 +238,7 @@ void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorT
 
 static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
   StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
-  D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
+  // The argument after the group is the attribute's lowering string; it is
+  // left empty for this built-in subscript.
+  D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
 }
 
 /// <summary>Adds up-front support for HLSL vector types (just the template declaration).</summary>
@@ -743,6 +743,16 @@ bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
   return true;
 }
 
+// Fetch the lowering string of FD's HLSLIntrinsicAttr into S.
+// Returns false, leaving S untouched, when FD is null or has no such attribute.
+bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
+  const HLSLIntrinsicAttr *A = FD ? FD->getAttr<HLSLIntrinsicAttr>() : nullptr;
+  if (!A)
+    return false;
+
+  S = A->getLowering();
+  return true;
+}
+
 /// <summary>Parses a column or row digit.</summary>
 static
 bool TryParseColOrRowChar(const char digit, _Out_ int* count) {

+ 1 - 0
tools/clang/lib/CodeGen/BackendUtil.cpp

@@ -321,6 +321,7 @@ void EmitAssemblyHelper::CreatePasses() {
   PMBuilder.SLPVectorize = CodeGenOpts.VectorizeSLP;
   PMBuilder.LoopVectorize = CodeGenOpts.VectorizeLoop;
   PMBuilder.HLSLHighLevel = CodeGenOpts.HLSLHighLevel; // HLSL Change
+  PMBuilder.HLSLExtensionsCodeGen = CodeGenOpts.HLSLExtensionsCodegen.get(); // HLSL Change
 
   PMBuilder.DisableUnitAtATime = !CodeGenOpts.UnitAtATime;
   PMBuilder.DisableUnrollLoops = !CodeGenOpts.UnrollLoops;

+ 34 - 3
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -35,6 +35,7 @@
 #include "clang/Parse/ParseHLSL.h"      // root sig would be in Parser if part of lang
 #include "dxc/Support/WinIncludes.h"    // stream support
 #include "dxc/dxcapi.h"                 // stream support
+#include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -987,6 +988,11 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       } break;
       }
     }
+
+    StringRef lower;
+    if (hlsl::GetIntrinsicLowering(FD, lower))
+      hlsl::SetHLLowerStrategy(F, lower);
+
     // Don't need to add FunctionQual for intrinsic function.
     return;
   }
@@ -2756,7 +2762,13 @@ static Function *CreateOpFunction(llvm::Module &M, Function *F,
       opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
       break;
     }
-  } else {
+  }
+  else if (group == HLOpcodeGroup::HLExtIntrinsic) {
+    llvm::StringRef fnName = F->getName();
+    llvm::StringRef groupName = GetHLOpcodeGroupNameByAttr(F);
+    opFunc = GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode);
+  }
+  else {
     opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
   }
 
@@ -2793,7 +2805,7 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
     }
   }
 
-  HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByAttr(F);
+  HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
 
   if (group == HLOpcodeGroup::HLSubscript &&
       opcode == static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)) {
@@ -2867,6 +2879,9 @@ static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
       llvm::FunctionType::get(RetTy, paramTyList, false);
 
   Function *opFunc = CreateOpFunction(M, F, funcTy, group, opcode);
+  StringRef lower = hlsl::GetHLLowerStrategy(F);
+  if (!lower.empty())
+    hlsl::SetHLLowerStrategy(opFunc, lower);
 
   for (auto user = F->user_begin(); user != F->user_end();) {
     // User must be a call.
@@ -3651,6 +3666,22 @@ void CGMSHLSLRuntime::FinishCodeGen() {
 
   // Do simple transform to make later lower pass easier.
   SimpleTransformForHLDXIR(m_pHLModule->GetModule());
+
+  // Add semantic defines for extensions if any are available.
+  if (CGM.getCodeGenOpts().HLSLExtensionsCodegen) {
+    HLSLExtensionsCodegenHelper::SemanticDefineErrorList errors =
+      CGM.getCodeGenOpts().HLSLExtensionsCodegen->WriteSemanticDefines(m_pHLModule->GetModule());
+
+    DiagnosticsEngine &Diags = CGM.getDiags();
+    for (const HLSLExtensionsCodegenHelper::SemanticDefineError& error : errors) {
+      DiagnosticsEngine::Level level = DiagnosticsEngine::Error;
+      if (error.IsWarning())
+        level = DiagnosticsEngine::Warning;
+      unsigned DiagID = Diags.getCustomDiagID(level, "%0");
+      Diags.Report(SourceLocation::getFromRawEncoding(error.Location()), DiagID) << error.Message();
+    }
+  }
+
 }
 
 RValue CGMSHLSLRuntime::EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF,
@@ -3666,7 +3697,7 @@ RValue CGMSHLSLRuntime::EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF,
   if (RV.isScalar() && RV.getScalarVal() != nullptr) {
     if (CallInst *CI = dyn_cast<CallInst>(RV.getScalarVal())) {
       Function *F = CI->getCalledFunction();
-      HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByAttr(F);
+      HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
       if (group == HLOpcodeGroup::HLIntrinsic) {
         bool allOperandImm = true;
         for (auto &operand : CI->arg_operands()) {

+ 22 - 7
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -1472,7 +1472,7 @@ static bool IsBuiltinTable(LPCSTR tableName) {
 }
 
 static void AddHLSLIntrinsicAttr(FunctionDecl *FD, ASTContext &context,
-                              LPCSTR tableName,
+                              LPCSTR tableName, LPCSTR lowering,
                               const HLSL_INTRINSIC *pIntrinsic) {
   unsigned opcode = (unsigned)pIntrinsic->Op;
   if (HasUnsignedOpcode(opcode)) {
@@ -1494,7 +1494,7 @@ static void AddHLSLIntrinsicAttr(FunctionDecl *FD, ASTContext &context,
       opcode = hlsl::GetUnsignedOpcode(opcode);
     }
   }
-  FD->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, tableName, opcode));
+  FD->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, tableName, lowering, opcode));
   if (pIntrinsic->bReadNone)
     FD->addAttr(ConstAttr::CreateImplicit(context));
   if (pIntrinsic->bReadOnly)
@@ -1504,7 +1504,7 @@ static void AddHLSLIntrinsicAttr(FunctionDecl *FD, ASTContext &context,
 static
 FunctionDecl *AddHLSLIntrinsicFunction(
     ASTContext &context, _In_ NamespaceDecl *NS,
-    LPCSTR tableName,
+    LPCSTR tableName, LPCSTR lowering,
     _In_ const HLSL_INTRINSIC *pIntrinsic,
     _In_count_(functionArgTypeCount) QualType *functionArgQualTypes,
     _In_range_(0, g_MaxIntrinsicParamCount - 1) size_t functionArgTypeCount) {
@@ -1551,7 +1551,7 @@ FunctionDecl *AddHLSLIntrinsicFunction(
   // put under hlsl namespace
   functionDecl->setDeclContext(NS);
   // Add intrinsic attribute
-  AddHLSLIntrinsicAttr(functionDecl, context, tableName, pIntrinsic);
+  AddHLSLIntrinsicAttr(functionDecl, context, tableName, lowering, pIntrinsic);
 
   ParmVarDecl *paramDecls[g_MaxIntrinsicParamCount];
   for (size_t i = 1; i < functionArgTypeCount; i++) {
@@ -2140,6 +2140,15 @@ public:
     return tableName;
   }
 
+  // Queries the table at _tableIndex for the lowering strategy registered for
+  // the current intrinsic's opcode. Returns nullptr when the table call fails.
+  LPCSTR GetLoweringStrategy()
+  {
+    LPCSTR strategy = nullptr;
+    const HRESULT hr = _tables[_tableIndex]->GetLoweringStrategy(_tableIntrinsic->Op, &strategy);
+    return FAILED(hr) ? nullptr : strategy;
+  }
+
   IntrinsicTableDefIter& operator++()
   {
     MoveToNext();
@@ -2186,6 +2195,11 @@ public:
     return (_current != _end) ? kBuiltinIntrinsicTableName : _tableIter.GetTableName();
   }
 
+  // Lowering strategy for the intrinsic the iterator currently refers to.
+  // While _current has not reached _end the strategy is the empty string;
+  // afterwards the question is forwarded to the table iterator.
+  LPCSTR GetLoweringStrategy()
+  {
+    if (_current == _end)
+      return _tableIter.GetLoweringStrategy();
+    return "";
+  }
+
   IntrinsicDefIter& operator++()
   {
     if (_current != _end) {
@@ -2206,7 +2220,7 @@ public:
 
 static void AddHLSLSubscriptAttr(Decl *D, ASTContext &context, HLSubscriptOpcode opcode) {
   StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
-  D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, static_cast<unsigned>(opcode)));
+  D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(opcode)));
 }
 
 class HLSLExternalSource : public ExternalSemaSource {
@@ -3350,6 +3364,7 @@ public:
       // of the types we need.
       const HLSL_INTRINSIC* pIntrinsic = *cursor;
       LPCSTR tableName = cursor.GetTableName();
+      LPCSTR lowering = cursor.GetLoweringStrategy();
       DXASSERT(
         pIntrinsic->uNumArgs <= g_MaxIntrinsicParamCount + 1,
         "otherwise g_MaxIntrinsicParamCount needs to be updated for wider signatures");
@@ -3369,7 +3384,7 @@ public:
       if (insertedNewValue)
       {
         DXASSERT(tableName, "otherwise IDxcIntrinsicTable::GetTableName() failed");
-        intrinsicFuncDecl = AddHLSLIntrinsicFunction(*m_context, m_hlslNSDecl, tableName, pIntrinsic, functionArgTypes, functionArgTypeCount);
+        intrinsicFuncDecl = AddHLSLIntrinsicFunction(*m_context, m_hlslNSDecl, tableName, lowering, pIntrinsic, functionArgTypes, functionArgTypeCount);
         insertResult.first->setFunctionDecl(intrinsicFuncDecl);
       }
       else
@@ -3943,7 +3958,7 @@ public:
       SC_Extern, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
 
     // Add intrinsic attr
-    AddHLSLIntrinsicAttr(method, *m_context, tableName, intrinsic);
+    AddHLSLIntrinsicAttr(method, *m_context, tableName, "", intrinsic);
 
     // Record this function template specialization.
     TemplateArgumentList *argListCopy = TemplateArgumentList::CreateCopy(

+ 71 - 1
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -37,6 +37,7 @@
 #include "llvm/Support/FormattedStream.h"
 #include "dxc/Support/WinIncludes.h"  // For DxilPipelineStateValidation.h
 #include "dxc/HLSL/DxilPipelineStateValidation.h"
+#include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
 
 #ifdef _DEBUG
 #if defined(_MSC_VER)
@@ -1784,6 +1785,75 @@ static void PrintPipelineStateValidationRuntimeInfo(const char *pBuffer, DXIL::S
   OS << comment << "\n";
 }
 
+// Codegen helper installed on the compiler's CodeGenOptions. It collects the
+// semantic defines seen by the parser, runs them through the language
+// extensions helper for validation, and records the accepted ones as named
+// metadata in the module.
+class HLSLExtensionsCodegenHelperImpl : public HLSLExtensionsCodegenHelper {
+private:
+  CompilerInstance &m_CI;
+  DxcLangExtensionsHelper &m_langExtensionsHelper;
+
+  // The metadata format is a root node that has pointers to metadata
+  // nodes for each define. The metadata node for a define is a pair
+  // of (name, value) metadata strings.
+  //
+  // Example:
+  // !hlsl.semdefs = {!0, !1}
+  // !0 = !{!"FOO", !"BAR"}
+  // !1 = !{!"BOO", !"HOO"}
+  void WriteSemanticDefines(llvm::Module *M, const ParsedSemanticDefineList &defines) {
+    // Create all metadata nodes for each define. Each node is a (name, value) pair.
+    std::vector<MDNode *> mdNodes;
+    mdNodes.reserve(defines.size());
+    for (const ParsedSemanticDefine &define : defines) {
+      MDString *name  = MDString::get(M->getContext(), define.Name);
+      MDString *value = MDString::get(M->getContext(), define.Value);
+      mdNodes.push_back(MDNode::get(M->getContext(), { name, value }));
+    }
+
+    // Add root node with pointers to all define metadata nodes.
+    // The root is created even when there are no nodes to attach.
+    NamedMDNode *Root = M->getOrInsertNamedMetadata(m_langExtensionsHelper.GetSemanticDefineMetadataName());
+    for (MDNode *node : mdNodes)
+      Root->addOperand(node);
+  }
+
+  // Validates every define in `defines`.
+  //   validated - receives the defines that produced no error (a warning
+  //               alone does not reject a define).
+  //   errors    - receives one entry per reported error and per reported
+  //               warning, in the order the defines were visited.
+  // Results are delivered only through the two out-parameters; no copy of
+  // the error list is returned.
+  void GetValidatedSemanticDefines(const ParsedSemanticDefineList &defines, ParsedSemanticDefineList &validated, SemanticDefineErrorList &errors) {
+    for (const ParsedSemanticDefine &define : defines) {
+      DxcLangExtensionsHelper::SemanticDefineValidationResult result = m_langExtensionsHelper.ValidateSemanticDefine(define.Name, define.Value);
+      if (result.HasError())
+        errors.emplace_back(define.Location, SemanticDefineError::Level::Error, result.Error);
+      if (result.HasWarning())
+        errors.emplace_back(define.Location, SemanticDefineError::Level::Warning, result.Warning);
+      if (!result.HasError())
+        validated.push_back(define);
+    }
+  }
+
+public:
+  HLSLExtensionsCodegenHelperImpl(CompilerInstance &CI, DxcLangExtensionsHelper &langExtensionsHelper)
+  : m_CI(CI), m_langExtensionsHelper(langExtensionsHelper)
+  {}
+
+  // Write semantic defines as metadata in the module.
+  // Returns the validation errors and warnings that were produced; the list
+  // is empty when the parser saw no semantic defines.
+  virtual std::vector<SemanticDefineError> WriteSemanticDefines(llvm::Module *M) override {
+    // Grab the semantic defines seen by the parser.
+    ParsedSemanticDefineList defines =
+      CollectSemanticDefinesParsedByCompiler(m_CI, &m_langExtensionsHelper);
+
+    // Nothing to do if we have no defines.
+    SemanticDefineErrorList errors;
+    if (defines.empty())
+      return errors;
+
+    // Only defines that passed validation are written to the module.
+    ParsedSemanticDefineList validated;
+    GetValidatedSemanticDefines(defines, validated, errors);
+    WriteSemanticDefines(M, validated);
+    return errors;
+  }
+
+  // Forwards the intrinsic name lookup for `opcode` to the language extensions helper.
+  virtual std::string GetIntrinsicName(UINT opcode) override {
+    return m_langExtensionsHelper.GetIntrinsicName(opcode);
+  }
+};
+
 // Class to manage lifetime of llvm module and provide some utility
 // functions used for generating compiler output.
 class DxilCompilerLLVMModuleOutput {
@@ -1813,7 +1883,6 @@ private:
   std::unique_ptr<llvm::Module> m_llvmModuleWithDebugInfo;
 };
 
-
 class DxcCompiler : public IDxcCompiler, public IDxcLangExtensions, public IDxcContainerEvent {
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef)
@@ -2471,6 +2540,7 @@ public:
     compiler.getCodeGenOpts().setInlining(
         clang::CodeGenOptions::OnlyAlwaysInlining);
 
+    compiler.getCodeGenOpts().HLSLExtensionsCodegen = std::make_shared<HLSLExtensionsCodegenHelperImpl>(compiler, m_langExtensionsHelper);
   }
 };
 

+ 37 - 11
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -158,11 +158,21 @@ bool MacroPairCompareIsLessThan(const std::pair<const IdentifierInfo*, const Mac
 }
 
 static
-void WriteSemanticDefines(CompilerInstance& compiler, _In_ DxcLangExtensionsHelper* helper, raw_string_ostream& o)
-{
+void WriteSemanticDefines(CompilerInstance &compiler, _In_ DxcLangExtensionsHelper *helper, raw_string_ostream &o) {
+  // Emits every collected semantic define as a "#define NAME VALUE" line under
+  // a "// Macros:" header; nothing is written when no defines were collected.
+  const ParsedSemanticDefineList semanticDefines = CollectSemanticDefinesParsedByCompiler(compiler, helper);
+  if (semanticDefines.empty())
+    return;
+
+  o << "\n// Macros:\n";
+  for (const ParsedSemanticDefine &define : semanticDefines)
+    o << "#define " << define.Name << " " << define.Value << "\n";
+}
+
+ParsedSemanticDefineList hlsl::CollectSemanticDefinesParsedByCompiler(CompilerInstance &compiler, _In_ DxcLangExtensionsHelper *helper) {
+  ParsedSemanticDefineList parsedDefines;
   const llvm::SmallVector<std::string, 2>& defines = helper->GetSemanticDefines();
   if (defines.size() == 0) {
-    return;
+    return parsedDefines;
   }
 
   const llvm::SmallVector<std::string, 2>& defineExclusions = helper->GetSemanticDefineExclusions();
@@ -206,20 +216,36 @@ void WriteSemanticDefines(CompilerInstance& compiler, _In_ DxcLangExtensionsHelp
   }
 
   if (!macros.empty()) {
-    o << "\n// Macros:\n";
     std::sort(macros.begin(), macros.end(), MacroPairCompareIsLessThan);
+    SmallVector<std::string, 8> tokens;
     for (auto&& m : macros) {
-      o << "#define " << m.first->getName() << " ";
-      SmallString<128> SpellingBuffer;
-      MacroInfo::tokens_iterator tiEnd = m.second->tokens_end();
-      for (MacroInfo::tokens_iterator ti = m.second->tokens_begin(); ti != tiEnd; ++ti) {
+      std::string name = m.first->getName();
+      // Collect all macro token values into a vector.
+      // Put them in a vector to avoid repeated copying of data when appending strings.
+      tokens.clear();
+      tokens.reserve(m.second->getNumTokens());
+      for (MacroInfo::tokens_iterator ti = m.second->tokens_begin(), tiEnd = m.second->tokens_end(); ti != tiEnd; ++ti) {
         if (ti->hasLeadingSpace())
-          ' ';
-        o << pp.getSpelling(*ti, SpellingBuffer);
+          tokens.push_back(" ");
+        tokens.push_back(pp.getSpelling(*ti));
       }
-      o << "\n";
+
+      // Compute total size of defined string value.
+      size_t size = 0;
+      for (const std::string &s : tokens)
+        size += s.size();
+
+      // Concatenate all values into a single string.
+      std::string value;
+      value.reserve(size);
+      for (const std::string &s : tokens)
+        value.append(s);
+
+      parsedDefines.emplace_back(ParsedSemanticDefine{ name, value, m.second->getDefinitionLoc().getRawEncoding() });
     }
   }
+
+  return parsedDefines;
 }
 
 static

+ 1 - 0
tools/clang/unittests/HLSL/CMakeLists.txt

@@ -20,6 +20,7 @@ add_clang_library(clang-hlsl-tests SHARED
   DxilContainerTest.cpp
   DXIsenseTest.cpp
   ExecutionTest.cpp
+  ExtensionTest.cpp
   FileCheckerTest.cpp
   FileCheckForTest.cpp
   FunctionTest.cpp

+ 2 - 0
tools/clang/unittests/HLSL/DxcTestUtils.h

@@ -99,6 +99,8 @@ inline std::string BlobToUtf8(_In_ IDxcBlob *pBlob) {
 }
 
 std::wstring BlobToUtf16(_In_ IDxcBlob *pBlob);
+void CheckOperationSucceeded(IDxcOperationResult *pResult, IDxcBlob **ppBlob);
+std::string DisassembleProgram(dxc::DxcDllSupport &dllSupport, IDxcBlob *pProgram);
 void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val, _Outptr_ IDxcBlob **ppBlob);
 void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val, _Outptr_ IDxcBlobEncoding **ppBlob);
 void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const char *pVal, _Outptr_ IDxcBlobEncoding **ppBlob);

+ 538 - 0
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -0,0 +1,538 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// ExtensionTest.cpp                                                         //
+//                                                                           //
+// Provides tests for the language extension APIs.                           //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "CompilationResult.h"
+#include "WexTestClass.h"
+#include "HlslTestUtils.h"
+#include "DxcTestUtils.h"
+#include "dxc/Support/microcom.h"
+#include "dxc/dxcapi.internal.h"
+#include "dxc/HLSL/HLOperationLowerExtension.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Support for test intrinsics.
+
+// $result = test_fn(any_vector<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnArgs[] = {
+  { "test_fn", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
+};
+
+// void test_proc(any_vector<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestProcArgs[] = {
+  { "test_proc", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0 },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
+};
+
+// $result = test_poly(any_vector<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnCustomArgs[] = {
+  { "test_poly", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
+};
+
+// $result = test_int(int<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnIntArgs[] = {
+  { "test_int", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_INT, 1, IA_C },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_INT, 1, IA_C }
+};
+
+// $result = test_nolower(any_vector<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnNoLowerArgs[] = {
+  { "test_nolower", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
+};
+
+// void test_pack_0(any_vector<any_cardinality> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnPack0[] = {
+  { "test_pack_0", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0 },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
+};
+
+// $result = test_pack_1()
+static const HLSL_INTRINSIC_ARGUMENT TestFnPack1[] = {
+  { "test_pack_1", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
+};
+
+// $result = test_pack_2(any_vector<any_cardinality> value1, any_vector<any_cardinality> value2)
+static const HLSL_INTRINSIC_ARGUMENT TestFnPack2[] = {
+  { "test_pack_2", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "value1", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+  { "value2", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
+};
+
+// float = test_pack_3(float<2> value1)
+static const HLSL_INTRINSIC_ARGUMENT TestFnPack3[] = {
+  { "test_pack_3", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_FLOAT, 1, 1 },
+  { "value1", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, 2},
+};
+
+// float<2> = test_pack_4(float<3> value)
+static const HLSL_INTRINSIC_ARGUMENT TestFnPack4[] = {
+  { "test_pack_4", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_FLOAT, 1, 2 },
+  { "value", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, 3},
+};
+
+// One row of the test intrinsic table.
+struct Intrinsic {
+  LPCWSTR hlslName;      // Source-level function name matched by LookupIntrinsic.
+  const char *dxilName;  // Name reported by GetIntrinsicName for this opcode.
+  const char *strategy;  // Lowering strategy string reported by GetLoweringStrategy.
+  HLSL_INTRINSIC hlsl;   // Intrinsic description handed back from LookupIntrinsic.
+};
+// Empty dxil name; used by the rows that do not supply a custom name.
+const char * DEFAULT_NAME = "";
+
+// Number of elements in a built-in array, like llvm::array_lengthof, but
+// returned as a UINT instead of size_t.
+template <typename Elem, std::size_t Count>
+UINT countof(Elem (&)[Count]) {
+  return static_cast<UINT>(Count);
+}
+
+// Intrinsics served by TestIntrinsicTable.
+// Columns: hlsl name, dxil name, lowering strategy, HLSL_INTRINSIC initializer
+// (its first field is the opcode the lookups match against).
+// NOTE(review): strategy letters presumably mean "r" = replicate, "p" = pack,
+// "n" = no lowering, going by the test names below -- confirm against the
+// lowering implementation.
+Intrinsic Intrinsics[] = {
+  {L"test_fn",      DEFAULT_NAME,      "r", {  1, false, true, -1, countof(TestFnArgs), TestFnArgs }},
+  {L"test_proc",    DEFAULT_NAME,      "r", {  2, false, true, -1, countof(TestProcArgs), TestProcArgs }},
+  {L"test_poly",    "test_poly.$o",    "r", {  3, false, true, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
+  {L"test_int",     "test_int",        "r", {  4, false, true, -1, countof(TestFnIntArgs), TestFnIntArgs}},
+  {L"test_nolower", "test_nolower.$o", "n", {  5, false, true, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
+  {L"test_pack_0",  "test_pack_0.$o",  "p", {  6, false, true, -1, countof(TestFnPack0), TestFnPack0}},
+  {L"test_pack_1",  "test_pack_1.$o",  "p", {  7, false, true, -1, countof(TestFnPack1), TestFnPack1}},
+  {L"test_pack_2",  "test_pack_2.$o",  "p", {  8, false, true, -1, countof(TestFnPack2), TestFnPack2}},
+  {L"test_pack_3",  "test_pack_3.$o",  "p", {  9, false, true, -1, countof(TestFnPack3), TestFnPack3}},
+  {L"test_pack_4",  "test_pack_4.$o",  "p", { 10, false, true, -1, countof(TestFnPack4), TestFnPack4}},
+};
+
+// Intrinsic table registered with the compiler by the tests. Every query is
+// answered from the static Intrinsics array.
+class TestIntrinsicTable : public IDxcIntrinsicTable {
+private:
+  DXC_MICROCOM_REF_FIELD(m_dwRef);
+
+  // Returns the first row whose opcode equals `opcode`, or nullptr if none does.
+  static const Intrinsic *FindByOpcode(UINT opcode) {
+    for (const Intrinsic &row : Intrinsics) {
+      if (row.hlsl.Op == opcode)
+        return &row;
+    }
+    return nullptr;
+  }
+public:
+  TestIntrinsicTable() : m_dwRef(0) { }
+  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+  __override HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) {
+    return DoBasicQueryInterface<IDxcIntrinsicTable>(this, iid, ppvObject);
+  }
+
+  // The table always reports the name "test".
+  __override HRESULT STDMETHODCALLTYPE
+  GetTableName(_Outptr_ LPCSTR *pTableName) {
+    *pTableName = "test";
+    return S_OK;
+  }
+
+  // Looks up a free function by its HLSL name. Any non-empty type name is
+  // rejected with E_FAIL, as is an unknown function name. On success the
+  // lookup cookie is reset to 0.
+  __override HRESULT STDMETHODCALLTYPE LookupIntrinsic(
+      LPCWSTR typeName, LPCWSTR functionName, const HLSL_INTRINSIC **pIntrinsic,
+      _Inout_ UINT64 *pLookupCookie) {
+    const bool hasTypeName = typeName != nullptr && *typeName;
+    if (hasTypeName)
+      return E_FAIL;
+
+    for (const Intrinsic &row : Intrinsics) {
+      if (wcscmp(row.hlslName, functionName) != 0)
+        continue;
+      *pIntrinsic = &row.hlsl;
+      *pLookupCookie = 0;
+      return S_OK;
+    }
+    return E_FAIL;
+  }
+
+  // Reports the lowering strategy string of the row with the given opcode;
+  // E_FAIL when the opcode is unknown.
+  __override HRESULT STDMETHODCALLTYPE
+  GetLoweringStrategy(UINT opcode, _Outptr_ LPCSTR *pStrategy) {
+    const Intrinsic *row = FindByOpcode(opcode);
+    if (row == nullptr)
+      return E_FAIL;
+
+    *pStrategy = row->strategy;
+    return S_OK;
+  }
+
+  // Reports the dxil name of the row with the given opcode;
+  // E_FAIL when the opcode is unknown.
+  __override HRESULT STDMETHODCALLTYPE
+  GetIntrinsicName(UINT opcode, _Outptr_ LPCSTR *pName) {
+    const Intrinsic *row = FindByOpcode(opcode);
+    if (row == nullptr)
+      return E_FAIL;
+
+    *pName = row->dxilName;
+    return S_OK;
+  }
+};
+
+// A class to test semantic define validation.
+// It takes a list of defines that when present should cause errors
+// and defines that should cause warnings. A more realistic validator
+// would look at the values and make sure (for example) they are
+// the correct type (integer, string, etc).
+class TestSemanticDefineValidator : public IDxcSemanticDefineValidator {
+private:
+  DXC_MICROCOM_REF_FIELD(m_dwRef);
+  std::vector<std::string> m_errorDefines;    // Define names reported as errors.
+  std::vector<std::string> m_warningDefines;  // Define names reported as warnings.
+public:
+  TestSemanticDefineValidator(const std::vector<std::string> &errorDefines, const std::vector<std::string> &warningDefines)
+    : m_dwRef(0)
+    , m_errorDefines(errorDefines)
+    , m_warningDefines(warningDefines)
+  { }
+  DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
+
+  __override HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) {
+    return DoBasicQueryInterface<IDxcSemanticDefineValidator>(this, iid, ppvObject);
+  }
+
+  // Produces a "bad define: <name>" blob in *ppErrorBlob and/or *ppWarningBlob
+  // when pName is in the corresponding list. The value is not inspected.
+  // Returns E_FAIL when any argument is null, S_OK otherwise.
+  virtual HRESULT STDMETHODCALLTYPE GetSemanticDefineWarningsAndErrors(LPCSTR pName, LPCSTR pValue, IDxcBlobEncoding **ppWarningBlob, IDxcBlobEncoding **ppErrorBlob) {
+    const bool allArgsPresent = pName && pValue && ppWarningBlob && ppErrorBlob;
+    if (!allArgsPresent)
+      return E_FAIL;
+
+    // Writes the message blob only if the name is found in `listed`.
+    auto reportIfListed = [pName](const std::vector<std::string> &listed, IDxcBlobEncoding **ppBlob) {
+      const bool isListed = std::find(listed.begin(), listed.end(), pName) != listed.end();
+      if (!isListed)
+        return;
+      dxc::DxcDllSupport dllSupport;
+      VERIFY_SUCCEEDED(dllSupport.Initialize());
+      std::string message("bad define: ");
+      message.append(pName);
+      Utf8ToBlob(dllSupport, message.c_str(), ppBlob);
+    };
+    reportIfListed(m_errorDefines, ppErrorBlob);
+    reportIfListed(m_warningDefines, ppWarningBlob);
+
+    return S_OK;
+  }
+};
+// Verifies that the status can be read from the operation result and that
+// the status itself is a failure code.
+static void CheckOperationFailed(IDxcOperationResult *pResult) {
+  HRESULT status;
+  VERIFY_SUCCEEDED(pResult->GetStatus(&status));
+  VERIFY_FAILED(status);
+}
+
+static std::string GetCompileErrors(IDxcOperationResult *pResult) {
+  CComPtr<IDxcBlobEncoding> pErrors;
+  VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErrors));
+  if (!pErrors)
+    return "";
+  return BlobToUtf8(pErrors);
+}
+
+// Test wrapper around a compiler instance and its IDxcLangExtensions
+// interface. The source blob, the most recent compile result and the
+// registered extension objects are held in CComPtr members.
+class Compiler {
+public:
+  // Initializes the DLL support object, creates the compiler and queries its
+  // language extensions interface.
+  Compiler(dxc::DxcDllSupport &dll) : m_dllSupport(dll) {
+    VERIFY_SUCCEEDED(m_dllSupport.Initialize());
+    VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+    VERIFY_SUCCEEDED(pCompiler.QueryInterface(&pLangExtensions));
+  }
+  void RegisterSemanticDefine(LPCWSTR define) {
+    VERIFY_SUCCEEDED(pLangExtensions->RegisterSemanticDefine(define));
+  }
+  void RegisterSemanticDefineExclusion(LPCWSTR define) {
+    VERIFY_SUCCEEDED(pLangExtensions->RegisterSemanticDefineExclusion(define));
+  }
+  // Stores the validator in a member before registering it with the compiler.
+  void SetSemanticDefineValidator(IDxcSemanticDefineValidator *validator) {
+    pTestSemanticDefineValidator = validator;
+    VERIFY_SUCCEEDED(pLangExtensions->SetSemanticDefineValidator(pTestSemanticDefineValidator));
+  }
+  // Forwards the caller-supplied metadata name to the compiler.
+  void SetSemanticDefineMetaDataName(const char *name) {
+    VERIFY_SUCCEEDED(pLangExtensions->SetSemanticDefineMetaDataName(name));
+  }
+  // Stores the table in a member before registering it with the compiler.
+  void RegisterIntrinsicTable(IDxcIntrinsicTable *table) {
+    pTestIntrinsicTable = table;
+    VERIFY_SUCCEEDED(pLangExtensions->RegisterIntrinsicTable(pTestIntrinsicTable));
+  }
+  
+  // Compiles `program` with no extra arguments and no defines.
+  IDxcOperationResult *Compile(const char *program) {
+    return Compile(program, {}, {});
+  }
+
+  // Compiles `program` as "hlsl.hlsl" with entry point "main" and target
+  // "ps_6_0", passing the given arguments and defines. The returned pointer
+  // is the one held by pCompileResult.
+  IDxcOperationResult *Compile(const char *program, const std::vector<LPCWSTR> &arguments, const std::vector<DxcDefine> &defs) {
+    Utf8ToBlob(m_dllSupport, program, &pCodeBlob);
+    VERIFY_SUCCEEDED(pCompiler->Compile(pCodeBlob, L"hlsl.hlsl", L"main",
+      L"ps_6_0",
+      const_cast<LPCWSTR *>(arguments.data()), arguments.size(),
+      defs.data(), defs.size(),
+      nullptr, &pCompileResult));
+
+    return pCompileResult;
+  }
+
+  // Verifies that the last compile succeeded and returns the disassembly of
+  // its result blob.
+  std::string Disassemble() {
+    CComPtr<IDxcBlob> pBlob;
+    CheckOperationSucceeded(pCompileResult, &pBlob);
+    return DisassembleProgram(m_dllSupport, pBlob);
+  }
+
+  dxc::DxcDllSupport &m_dllSupport;
+  CComPtr<IDxcCompiler> pCompiler;
+  CComPtr<IDxcLangExtensions> pLangExtensions;
+  CComPtr<IDxcBlobEncoding> pCodeBlob;
+  CComPtr<IDxcOperationResult> pCompileResult;
+  CComPtr<IDxcSemanticDefineValidator> pTestSemanticDefineValidator;
+  CComPtr<IDxcIntrinsicTable> pTestIntrinsicTable;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Extension unit tests.
+
+// Test class for the language extension APIs: semantic define preservation
+// and validation, and extension intrinsics with their lowering strategies.
+class ExtensionTest {
+public:
+  BEGIN_TEST_CLASS(ExtensionTest)
+    TEST_METHOD_PROPERTY(L"Priority", L"0")
+  END_TEST_CLASS()
+
+  // Passed to each test's Compiler wrapper.
+  dxc::DxcDllSupport m_dllSupport;
+
+  TEST_METHOD(DefineWhenRegisteredThenPreserved);
+  TEST_METHOD(DefineValidationError);
+  TEST_METHOD(DefineValidationWarning);
+  TEST_METHOD(DefineNoValidatorOk);
+  TEST_METHOD(IntrinsicWhenAvailableThenUsed);
+  TEST_METHOD(CustomIntrinsicName);
+  TEST_METHOD(NoLowering);
+  TEST_METHOD(PackedLowering);
+};
+
+// Registers the pattern FOO* with FOOBAR excluded and the metadata name
+// "test.defs", then checks that every matching define (from the source and
+// from the define list passed to Compile) shows up as a (name, value)
+// metadata pair and that the excluded define does not.
+TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
+  Compiler c(m_dllSupport);
+  c.RegisterSemanticDefine(L"FOO*");
+  c.RegisterSemanticDefineExclusion(L"FOOBAR");
+  // FOOLALA is never defined below, so this validator reports nothing here.
+  c.SetSemanticDefineValidator(new TestSemanticDefineValidator({ "FOOLALA" }, {}));
+  c.SetSemanticDefineMetaDataName("test.defs");
+  c.Compile(
+    "#define FOOTBALL AWESOME\n"
+    "#define FOOTLOOSE TOO\n"
+    "#define FOOBAR 123\n"
+    "#define FOOD\n"
+    "#define FOO 1 2 3\n"
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n",
+    {L"/Vd"},
+    { { L"FOODEF", L"1"} }
+  );
+  std::string disassembly = c.Disassemble();
+  // Check for root named md node. It contains pointers to md nodes for each define.
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!test.defs"));
+  // #define FOODEF 1
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOODEF\", !\"1\"}"));
+  // #define FOOTBALL AWESOME
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOOTBALL\", !\"AWESOME\"}"));
+  // #define FOOTLOOSE TOO
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOOTLOOSE\", !\"TOO\"}"));
+  // #define FOOD
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOOD\", !\"\"}"));
+  // #define FOO 1 2 3
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOO\", !\"1 2 3\"}"));
+  // FOOBAR should be excluded.
+  VERIFY_IS_TRUE(
+    disassembly.npos ==
+    disassembly.find("!{!\"FOOBAR\""));
+}
+
+// A define that the validator reports as an error must fail the compile and
+// produce an error diagnostic at the define's source location.
+TEST_F(ExtensionTest, DefineValidationError) {
+  Compiler c(m_dllSupport);
+  c.RegisterSemanticDefine(L"FOO*");
+  c.SetSemanticDefineValidator(new TestSemanticDefineValidator({ "FOO" }, {}));
+  IDxcOperationResult *pCompileResult = c.Compile(
+    "#define FOO 1\n"
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n",
+    {L"/Vd"}, {}
+  );
+
+  // Check that validation error causes compile failure.
+  CheckOperationFailed(pCompileResult);
+  std::string errors = GetCompileErrors(pCompileResult);
+  // Check that the error message is for the validation failure.
+  VERIFY_IS_TRUE(
+    errors.npos !=
+    errors.find("hlsl.hlsl:1:9: error: bad define: FOO"));
+}
+
+// A define that the validator reports only as a warning must produce a
+// warning diagnostic and still be written to the module metadata.
+TEST_F(ExtensionTest, DefineValidationWarning) {
+  Compiler c(m_dllSupport);
+  c.RegisterSemanticDefine(L"FOO*");
+  c.SetSemanticDefineValidator(new TestSemanticDefineValidator({}, { "FOO" }));
+  IDxcOperationResult *pCompileResult = c.Compile(
+    "#define FOO 1\n"
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+
+  std::string errors = GetCompileErrors(pCompileResult);
+  // Check that the warning message is for the validation warning.
+  VERIFY_IS_TRUE(
+    errors.npos !=
+    errors.find("hlsl.hlsl:1:9: warning: bad define: FOO"));
+
+  // Check the define is still emitted.
+  std::string disassembly = c.Disassemble();
+  // Check for root named md node. It contains pointers to md nodes for each define.
+  // No metadata name is set in this test, so the name checked is "hlsl.semdefs".
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!hlsl.semdefs"));
+  // #define FOO 1
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOO\", !\"1\"}"));
+}
+
+// With no validator registered, a matching define is still written to the
+// module metadata.
+TEST_F(ExtensionTest, DefineNoValidatorOk) {
+  Compiler c(m_dllSupport);
+  c.RegisterSemanticDefine(L"FOO*");
+  c.Compile(
+    "#define FOO 1\n"
+    "float4 main() : SV_Target {\n"
+    "  return 0;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+
+  std::string disassembly = c.Disassemble();
+  // Check the define is emitted.
+  // #define FOO 1
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("!{!\"FOO\", !\"1\"}"));
+}
+
+
+// Calls test_proc and test_fn (both registered with strategy "r" and the
+// default name) and checks the shape of the calls in the disassembly.
+TEST_F(ExtensionTest, IntrinsicWhenAvailableThenUsed) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
+    "  test_proc(v);\n"
+    "  float2 a = test_fn(v);\n"
+    "  int2 b = test_fn(i);\n"
+    "  return a + b;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // Things to call out:
+  // - result is float, not a vector
+  // - mangled name contains the 'test' and '.r' parts
+  // - opcode is first i32 argument
+  // - second argument is float, ie it got scalarized
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call void @\"test.\\01?test_proc@hlsl@@YAXV?$vector@M$01@@@Z.r\"(i32 2, float %6)"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call float @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32 1, float"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i32 @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@H$01@@V2@@Z.r\"(i32 1, i32"));
+
+  // - attributes are added to the declaration (the # at the end of the decl)
+  //   TODO: would be nice to check for the actual attribute (e.g. readonly)
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("declare float @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32, float) #"));
+}
+
+// Checks that the dxil names registered in the table ("test_poly.$o" and
+// "test_int") are the names used for the calls in the disassembly.
+TEST_F(ExtensionTest, CustomIntrinsicName) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
+    "  float2 a = test_poly(v);\n"
+    "  int2   b = test_poly(i);\n"
+    "  int2   c = test_int(i);\n"
+    "  return a + b + c;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // - custom name works for polymorphic function
+  //   (the expected names carry the overload type where the table has "$o")
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call float @test_poly.float(i32 3, float"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i32 @test_poly.i32(i32 3, i32"));
+
+  // - custom name works for non-polymorphic function
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call i32 @test_int(i32 4, i32"));
+}
+
+// test_nolower is registered with strategy "n"; the calls are expected to
+// keep their vector arguments and vector results.
+TEST_F(ExtensionTest, NoLowering) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
+    "  float2 a = test_nolower(v);\n"
+    "  float2 b = test_nolower(i);\n"
+    "  return a + b;\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // - custom name works for non-lowered function
+  // - non-lowered function has vector type as argument
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call <2 x float> @test_nolower.float(i32 5, <2 x float>"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call <2 x i32> @test_nolower.i32(i32 5, <2 x i32>"));
+}
+
+// The test_pack_* intrinsics are registered with strategy "p"; vector
+// arguments and vector results are expected to appear as literal structs.
+TEST_F(ExtensionTest, PackedLowering) {
+  Compiler c(m_dllSupport);
+  c.RegisterIntrinsicTable(new TestIntrinsicTable());
+  c.Compile(
+    "float2 main(float2 v1 : V1, float2 v2 : V2, float3 v3 : V3) : SV_Target {\n"
+    "  test_pack_0(v1);\n"
+    "  int2   a = test_pack_1();\n"
+    "  float2 b = test_pack_2(v1, v2);\n"
+    "  float  c = test_pack_3(v1);\n"
+    "  float2 d = test_pack_4(v3);\n"
+    "  return a + b + float2(c, c);\n"
+    "}\n",
+    { L"/Vd" }, {}
+  );
+  std::string disassembly = c.Disassemble();
+
+  // - pack strategy changes vectors into structs
+  // - a scalar result (test_pack_3) stays a plain float
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call void @test_pack_0.float(i32 6, { float, float }"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call { float, float } @test_pack_1.float(i32 7)"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call { float, float } @test_pack_2.float(i32 8, { float, float }"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call float @test_pack_3.float(i32 9, { float, float }"));
+  VERIFY_IS_TRUE(
+    disassembly.npos !=
+    disassembly.find("call { float, float } @test_pack_4.float(i32 10, { float, float, float }"));
+}

+ 22 - 11
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -20,6 +20,27 @@
 
 using namespace std;
 
+// Verifies that the operation result reports a success status and fetches
+// its result blob into *ppBlob.
+void CheckOperationSucceeded(IDxcOperationResult *pResult, IDxcBlob **ppBlob) {
+  HRESULT status;
+  VERIFY_SUCCEEDED(pResult->GetStatus(&status));
+  VERIFY_SUCCEEDED(status);
+  VERIFY_SUCCEEDED(pResult->GetResult(ppBlob));
+}
+
+// Disassembles pProgram with a freshly created compiler instance and returns
+// the disassembly text. dllSupport is initialized first if it is not enabled yet.
+std::string DisassembleProgram(dxc::DxcDllSupport &dllSupport,
+                               IDxcBlob *pProgram) {
+  if (!dllSupport.IsEnabled()) {
+    VERIFY_SUCCEEDED(dllSupport.Initialize());
+  }
+
+  CComPtr<IDxcCompiler> pCompiler;
+  VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+
+  CComPtr<IDxcBlobEncoding> pDisassembly;
+  VERIFY_SUCCEEDED(pCompiler->Disassemble(pProgram, &pDisassembly));
+  return BlobToUtf8(pDisassembly);
+}
+
 class ValidationTest
 {
 public:
@@ -221,17 +242,7 @@ public:
   }
 
   void DisassembleProgram(IDxcBlob *pProgram, std::string *text) {
-    CComPtr<IDxcCompiler> pCompiler;
-    CComPtr<IDxcBlobEncoding> pDisassembly;
-
-    if (!m_dllSupport.IsEnabled()) {
-      VERIFY_SUCCEEDED(m_dllSupport.Initialize());
-    }
-
-    VERIFY_SUCCEEDED(
-      m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
-    VERIFY_SUCCEEDED(pCompiler->Disassemble(pProgram, &pDisassembly));
-    *text = BlobToUtf8(pDisassembly);
+    *text = ::DisassembleProgram(m_dllSupport, pProgram);
   }
 
   void RewriteAssemblyCheckMsg(LPCSTR pSource, LPCSTR pShaderModel,