Просмотр исходного кода

Add IDxcLangExtensions2 to set target triple. (#2981)

* Add IDxcLangExtensions2 to set target triple.
Xiang Li 5 лет назад
Родитель
Сommit
19ef31e860

+ 238 - 0
include/dxc/Support/DxcLangExtensionsCommonHelper.h

@@ -0,0 +1,238 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxcLangExtensionsCommonHelper.h                                           //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Provides a helper class to implement language extensions to HLSL.         //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "dxc/Support/Unicode.h"
+#include "dxc/Support/FileIOHelper.h"
+#include "dxc/dxcapi.internal.h"
+#include <vector>
+
+namespace llvm {
+class raw_string_ostream;
+class CallInst;
+class Value;
+}
+
+namespace hlsl {
+
+class DxcLangExtensionsCommonHelper {
+private:
+  llvm::SmallVector<std::string, 2> m_semanticDefines;
+  llvm::SmallVector<std::string, 2> m_semanticDefineExclusions;
+  llvm::SmallVector<std::string, 2> m_defines;
+  llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2> m_intrinsicTables;
+  CComPtr<IDxcSemanticDefineValidator> m_semanticDefineValidator;
+  std::string m_semanticDefineMetaDataName;
+  std::string m_targetTriple;
+  HRESULT STDMETHODCALLTYPE RegisterIntoVector(LPCWSTR name, llvm::SmallVector<std::string, 2>& here)
+  {
+    try {
+      IFTPTR(name);
+      std::string s;
+      if (!Unicode::UTF16ToUTF8String(name, &s)) {
+        throw ::hlsl::Exception(E_INVALIDARG);
+      }
+      here.push_back(s);
+      return S_OK;
+    }
+    CATCH_CPP_RETURN_HRESULT();
+  }
+
+public:
+  const llvm::SmallVector<std::string, 2>& GetSemanticDefines() const { return m_semanticDefines; }
+  const llvm::SmallVector<std::string, 2>& GetSemanticDefineExclusions() const { return m_semanticDefineExclusions; }
+  const llvm::SmallVector<std::string, 2>& GetDefines() const { return m_defines; }
+  llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2>& GetIntrinsicTables(){ return m_intrinsicTables; }
+  const std::string &GetSemanticDefineMetadataName() { return m_semanticDefineMetaDataName; }
+  const std::string &GetTargetTriple() { return m_targetTriple; }
+
+  HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name)
+  {
+    return RegisterIntoVector(name, m_semanticDefines);
+  }
+
+  HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name)
+  {
+    return RegisterIntoVector(name, m_semanticDefineExclusions);
+  }
+
+  HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name)
+  {
+    return RegisterIntoVector(name, m_defines);
+  }
+
+  HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable* pTable)
+  {
+    try {
+      IFTPTR(pTable);
+      LPCSTR tableName = nullptr;
+      IFT(pTable->GetTableName(&tableName));
+      IFTPTR(tableName);
+      IFTARG(strcmp(tableName, "op") != 0);   // "op" is reserved for builtin intrinsics
+      for (auto &&table : m_intrinsicTables) {
+        LPCSTR otherTableName = nullptr;
+        IFT(table->GetTableName(&otherTableName));
+        IFTPTR(otherTableName);
+        IFTARG(strcmp(tableName, otherTableName) != 0); // Added a duplicate table name
+      }
+      m_intrinsicTables.push_back(pTable);
+      return S_OK;
+    }
+    CATCH_CPP_RETURN_HRESULT();
+  }
+
+  // Set the validator used to validate semantic defines.
+  // Only one validator stored and used to run validation.
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) {
+    if (pValidator == nullptr)
+      return E_POINTER;
+
+    m_semanticDefineValidator = pValidator;
+    return S_OK;
+  }
+
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) {
+    try {
+      m_semanticDefineMetaDataName = name;
+      return S_OK;
+    }
+    CATCH_CPP_RETURN_HRESULT();
+  }
+
+  HRESULT STDMETHODCALLTYPE SetTargetTriple(LPCSTR triple) {
+    try {
+      m_targetTriple = triple;
+      return S_OK;
+    }
+    CATCH_CPP_RETURN_HRESULT();
+  }
+
+  // Get the name of the dxil intrinsic function.
+  std::string GetIntrinsicName(UINT opcode) {
+    LPCSTR pName = "";
+    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
+      if (SUCCEEDED(table->GetIntrinsicName(opcode, &pName))) {
+        return pName;
+      }
+    }
+
+      return "";
+  }
+
+  // Get the dxil opcode for the extension opcode if one exists.
+  // Return true if the opcode was mapped successfully.
+  bool GetDxilOpCode(UINT opcode, UINT &dxilOpcode) {
+    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
+      if (SUCCEEDED(table->GetDxilOpCode(opcode, &dxilOpcode))) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Result of validating a semantic define.
+  // Stores any warning or error messages produced by the validator.
+  // Successful validation means that there are no warning or error messages.
+  struct SemanticDefineValidationResult {
+    std::string Warning;
+    std::string Error;
+
+    bool HasError() { return Error.size() > 0; }
+    bool HasWarning() { return Warning.size() > 0; }
+
+    static SemanticDefineValidationResult Success() {
+      return SemanticDefineValidationResult();
+    }
+  };
+
+  // Use the contained semantice define validator to validate the given semantic define.
+  SemanticDefineValidationResult ValidateSemanticDefine(const std::string &name, const std::string &value) {
+    if (!m_semanticDefineValidator)
+      return SemanticDefineValidationResult::Success();
+
+    // Blobs for getting restul from validator. Strings for returning results to caller.
+    CComPtr<IDxcBlobEncoding> pError;
+    CComPtr<IDxcBlobEncoding> pWarning;
+    std::string error;
+    std::string warning;
+
+    // Run semantic define validator.
+    HRESULT result = m_semanticDefineValidator->GetSemanticDefineWarningsAndErrors(name.c_str(), value.c_str(), &pWarning, &pError);
+
+
+    if (FAILED(result)) {
+      // Failure indicates it was not able to even run validation so
+      // we cannot say whether the define is invalid or not. Return a
+      // generic error message about failure to run the valiadator.
+      error = "failed to run semantic define validator for: ";
+      error.append(name); error.append("="); error.append(value);
+      return SemanticDefineValidationResult{ warning, error };
+    }
+
+    // Define a  little function to convert encoded blob into a string.
+    auto GetErrorAsString = [&name](const CComPtr<IDxcBlobEncoding> &pBlobString) -> std::string {
+      CComPtr<IDxcBlobUtf8> pUTF8BlobStr;
+      if (SUCCEEDED(hlsl::DxcGetBlobAsUtf8(pBlobString, DxcGetThreadMallocNoRef(), &pUTF8BlobStr)))
+        return std::string(pUTF8BlobStr->GetStringPointer(), pUTF8BlobStr->GetStringLength());
+      else
+        return std::string("invalid semantic define " + name);
+    };
+
+    // Check to see if any warnings or errors were produced.
+    if (pError && pError->GetBufferSize()) {
+      error = GetErrorAsString(pError);
+    }
+    if (pWarning && pWarning->GetBufferSize()) {
+      warning = GetErrorAsString(pWarning);
+    }
+
+    return SemanticDefineValidationResult{ warning, error };
+  }
+
+  DxcLangExtensionsCommonHelper()
+      : m_semanticDefineMetaDataName("hlsl.semdefs"),
+        m_targetTriple("dxil-ms-dx") {}
+};
+
+// Use this macro to embed an implementation that will delegate to a field.
+// Note that QueryInterface still needs to return the vtable.
+#define DXC_LANGEXTENSIONS_HELPER_IMPL(_helper_field_) \
+  HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable *pTable) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).RegisterIntrinsicTable(pTable); \
+  } \
+  HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).RegisterSemanticDefine(name); \
+  } \
+  HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).RegisterSemanticDefineExclusion(name); \
+  } \
+  HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).RegisterDefine(name); \
+  } \
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).SetSemanticDefineValidator(pValidator); \
+  } \
+  HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) override { \
+    DxcThreadMalloc TM(m_pMalloc); \
+    return (_helper_field_).SetSemanticDefineMetaDataName(name); \
+  } \
+  HRESULT STDMETHODCALLTYPE SetTargetTriple(LPCSTR name)  override { \
+    DxcThreadMalloc TM(m_pMalloc);                                   \
+    return (_helper_field_).SetTargetTriple(name);                   \
+  } \
+
+} // namespace hlsl

+ 5 - 195
include/dxc/Support/DxcLangExtensionsHelper.h

@@ -14,6 +14,7 @@
 
 #include "dxc/Support/Unicode.h"
 #include "dxc/Support/FileIOHelper.h"
+#include "dxc/Support/DxcLangExtensionsCommonHelper.h"
 #include <vector>
 
 namespace llvm {
@@ -27,183 +28,22 @@ class CompilerInstance;
 
 namespace hlsl {
 
-class DxcLangExtensionsHelper : public DxcLangExtensionsHelperApply {
+class DxcLangExtensionsHelper : public DxcLangExtensionsCommonHelper, public DxcLangExtensionsHelperApply {
 private:
-  llvm::SmallVector<std::string, 2> m_semanticDefines;
-  llvm::SmallVector<std::string, 2> m_semanticDefineExclusions;
-  llvm::SmallVector<std::string, 2> m_defines;
-  llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2> m_intrinsicTables;
-  CComPtr<IDxcSemanticDefineValidator> m_semanticDefineValidator;
-  std::string m_semanticDefineMetaDataName;
-
-  HRESULT STDMETHODCALLTYPE RegisterIntoVector(LPCWSTR name, llvm::SmallVector<std::string, 2>& here)
-  {
-    try {
-      IFTPTR(name);
-      std::string s;
-      if (!Unicode::UTF16ToUTF8String(name, &s)) {
-        throw ::hlsl::Exception(E_INVALIDARG);
-      }
-      here.push_back(s);
-      return S_OK;
-    }
-    CATCH_CPP_RETURN_HRESULT();
-  }
 
 public:
-  const llvm::SmallVector<std::string, 2>& GetSemanticDefines() const { return m_semanticDefines; }
-  const llvm::SmallVector<std::string, 2>& GetSemanticDefineExclusions() const { return m_semanticDefineExclusions; }
-  const llvm::SmallVector<std::string, 2>& GetDefines() const { return m_defines; }
-  llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2>& GetIntrinsicTables(){ return m_intrinsicTables; }
-  const std::string &GetSemanticDefineMetadataName() { return m_semanticDefineMetaDataName; }
-
-  HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name)
-  {
-    return RegisterIntoVector(name, m_semanticDefines);
-  }
-
-  HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name)
-  {
-    return RegisterIntoVector(name, m_semanticDefineExclusions);
-  }
-
-  HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name)
-  {
-    return RegisterIntoVector(name, m_defines);
-  }
-
-  HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable* pTable)
-  {
-    try {
-      IFTPTR(pTable);
-      LPCSTR tableName = nullptr;
-      IFT(pTable->GetTableName(&tableName));
-      IFTPTR(tableName);
-      IFTARG(strcmp(tableName, "op") != 0);   // "op" is reserved for builtin intrinsics
-      for (auto &&table : m_intrinsicTables) {
-        LPCSTR otherTableName = nullptr;
-        IFT(table->GetTableName(&otherTableName));
-        IFTPTR(otherTableName);
-        IFTARG(strcmp(tableName, otherTableName) != 0); // Added a duplicate table name
-      }
-      m_intrinsicTables.push_back(pTable);
-      return S_OK;
-    }
-    CATCH_CPP_RETURN_HRESULT();
-  }
-
-  // Set the validator used to validate semantic defines.
-  // Only one validator stored and used to run validation.
-  HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) {
-    if (pValidator == nullptr)
-      return E_POINTER;
-
-    m_semanticDefineValidator = pValidator;
-    return S_OK;
-  }
-
-  HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) {
-    try {
-      m_semanticDefineMetaDataName = name;
-      return S_OK;
-    }
-    CATCH_CPP_RETURN_HRESULT();
-  }
-
-  // Get the name of the dxil intrinsic function.
-  std::string GetIntrinsicName(UINT opcode) {
-    LPCSTR pName = "";
-    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
-      if (SUCCEEDED(table->GetIntrinsicName(opcode, &pName))) {
-        return pName;
-      }
-    }
-
-      return "";
-  }
-
-  // Get the dxil opcode for the extension opcode if one exists.
-  // Return true if the opcode was mapped successfully.
-  bool GetDxilOpCode(UINT opcode, UINT &dxilOpcode) {
-    for (IDxcIntrinsicTable *table : m_intrinsicTables) {
-      if (SUCCEEDED(table->GetDxilOpCode(opcode, &dxilOpcode))) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  // Result of validating a semantic define.
-  // Stores any warning or error messages produced by the validator.
-  // Successful validation means that there are no warning or error messages.
-  struct SemanticDefineValidationResult {
-    std::string Warning;
-    std::string Error;
-
-    bool HasError() { return Error.size() > 0; }
-    bool HasWarning() { return Warning.size() > 0; }
-
-    static SemanticDefineValidationResult Success() {
-      return SemanticDefineValidationResult();
-    }
-  };
-
-  // Use the contained semantice define validator to validate the given semantic define.
-  SemanticDefineValidationResult ValidateSemanticDefine(const std::string &name, const std::string &value) {
-    if (!m_semanticDefineValidator)
-      return SemanticDefineValidationResult::Success();
-
-    // Blobs for getting restul from validator. Strings for returning results to caller.
-    CComPtr<IDxcBlobEncoding> pError;
-    CComPtr<IDxcBlobEncoding> pWarning;
-    std::string error;
-    std::string warning;
-
-    // Run semantic define validator.
-    HRESULT result = m_semanticDefineValidator->GetSemanticDefineWarningsAndErrors(name.c_str(), value.c_str(), &pWarning, &pError);
-
-
-    if (FAILED(result)) {
-      // Failure indicates it was not able to even run validation so
-      // we cannot say whether the define is invalid or not. Return a
-      // generic error message about failure to run the valiadator.
-      error = "failed to run semantic define validator for: ";
-      error.append(name); error.append("="); error.append(value);
-      return SemanticDefineValidationResult{ warning, error };
-    }
-
-    // Define a  little function to convert encoded blob into a string.
-    auto GetErrorAsString = [&name](const CComPtr<IDxcBlobEncoding> &pBlobString) -> std::string {
-      CComPtr<IDxcBlobUtf8> pUTF8BlobStr;
-      if (SUCCEEDED(hlsl::DxcGetBlobAsUtf8(pBlobString, DxcGetThreadMallocNoRef(), &pUTF8BlobStr)))
-        return std::string(pUTF8BlobStr->GetStringPointer(), pUTF8BlobStr->GetStringLength());
-      else
-        return std::string("invalid semantic define " + name);
-    };
-
-    // Check to see if any warnings or errors were produced.
-    if (pError && pError->GetBufferSize()) {
-      error = GetErrorAsString(pError);
-    }
-    if (pWarning && pWarning->GetBufferSize()) {
-      warning = GetErrorAsString(pWarning);
-    }
-
-    return SemanticDefineValidationResult{ warning, error };
-  }
-
   void SetupSema(clang::Sema &S) override {
     clang::ExternalASTSource *astSource = S.getASTContext().getExternalSource();
     if (clang::ExternalSemaSource *externalSema =
             llvm::dyn_cast_or_null<clang::ExternalSemaSource>(astSource)) {
-      for (auto &&table : m_intrinsicTables) {
+      for (auto &&table : GetIntrinsicTables()) {
         hlsl::RegisterIntrinsicTable(externalSema, table);
       }
     }
   }
 
   void SetupPreprocessorOptions(clang::PreprocessorOptions &PPOpts) override {
-    for (const auto & define : m_defines) {
+    for (const auto &define : GetDefines()) {
       PPOpts.addMacroDef(llvm::StringRef(define.c_str()));
     }
   }
@@ -212,39 +52,9 @@ public:
     return this;
   }
  
-  DxcLangExtensionsHelper()
-  : m_semanticDefineMetaDataName("hlsl.semdefs")
-  {}
+  DxcLangExtensionsHelper() {}
 };
 
-// Use this macro to embed an implementation that will delegate to a field.
-// Note that QueryInterface still needs to return the vtable.
-#define DXC_LANGEXTENSIONS_HELPER_IMPL(_helper_field_) \
-  HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable *pTable) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).RegisterIntrinsicTable(pTable); \
-  } \
-  HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).RegisterSemanticDefine(name); \
-  } \
-  HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).RegisterSemanticDefineExclusion(name); \
-  } \
-  HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).RegisterDefine(name); \
-  } \
-  HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).SetSemanticDefineValidator(pValidator); \
-  } \
-  HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) override { \
-    DxcThreadMalloc TM(m_pMalloc); \
-    return (_helper_field_).SetSemanticDefineMetaDataName(name); \
-  } \
-
 // A parsed semantic define is a semantic define that has actually been
 // parsed by the compiler. It has a name (required), a value (could be
 // the empty string), and a location. We use an encoded clang::SourceLocation

+ 7 - 0
include/dxc/dxcapi.internal.h

@@ -184,6 +184,13 @@ public:
   DECLARE_CROSS_PLATFORM_UUIDOF(IDxcLangExtensions)
 };
 
+struct __declspec(uuid("2490C368-89EE-4491-A4B2-C6547B6C9381"))
+IDxcLangExtensions2 : public IDxcLangExtensions {
+public:
+  virtual HRESULT STDMETHODCALLTYPE SetTargetTriple(LPCSTR name) = 0;
+  DECLARE_CROSS_PLATFORM_UUIDOF(IDxcLangExtensions2)
+};
+
 struct __declspec(uuid("454b764f-3549-475b-958c-a7a6fcd05fbc"))
 IDxcSystemAccess : public IUnknown
 {

+ 3 - 2
lib/HLSL/DxilValidation.cpp

@@ -5579,8 +5579,9 @@ void GetValidationVersion(_Out_ unsigned *pMajor, _Out_ unsigned *pMinor) {
   // VALRULE-TEXT:END
 }
 
-_Use_decl_annotations_ HRESULT
-ValidateDxilModule(llvm::Module *pModule, llvm::Module *pDebugModule) {
+_Use_decl_annotations_ HRESULT ValidateDxilModule(
+    llvm::Module *pModule,
+    llvm::Module *pDebugModule) {
   DxilModule *pDxilModule = DxilModule::TryGetDxilModule(pModule);
   if (!pDxilModule) {
     return DXC_E_IR_VERIFICATION_FAILED;

+ 6 - 1
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -68,6 +68,7 @@ using namespace hlsl;
 using std::string;
 
 DEFINE_CROSS_PLATFORM_UUIDOF(IDxcLangExtensions)
+DEFINE_CROSS_PLATFORM_UUIDOF(IDxcLangExtensions2)
 
 // This declaration is used for the locally-linked validator.
 HRESULT CreateDxcValidator(_In_ REFIID riid, _Out_ LPVOID *ppv);
@@ -392,7 +393,7 @@ static void CreateDefineStrings(
 }
 
 class DxcCompiler : public IDxcCompiler3,
-                    public IDxcLangExtensions,
+                    public IDxcLangExtensions2,
                     public IDxcContainerEvent,
 #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
                     public IDxcVersionInfo2
@@ -428,6 +429,7 @@ public:
     HRESULT hr = DoBasicQueryInterface<
       IDxcCompiler3,
       IDxcLangExtensions,
+      IDxcLangExtensions2,
       IDxcContainerEvent,
       IDxcVersionInfo
 #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
@@ -997,6 +999,9 @@ public:
     // Setup a compiler instance.
     std::shared_ptr<TargetOptions> targetOptions(new TargetOptions);
     targetOptions->Triple = "dxil-ms-dx";
+    if (helper) {
+      targetOptions->Triple = helper->GetTargetTriple();
+    }
     targetOptions->DescriptionString = Opts.Enable16BitTypes
       ? hlsl::DXIL::kNewLayoutString
       : hlsl::DXIL::kLegacyLayoutString;

+ 3 - 2
tools/clang/tools/libclang/dxcisenseimpl.h

@@ -188,7 +188,7 @@ public:
       _Outptr_result_nullonfailure_ IDxcTranslationUnit** pTranslationUnit) override;
 };
 
-class DxcIntelliSense : public IDxcIntelliSense, public IDxcLangExtensions {
+class DxcIntelliSense : public IDxcIntelliSense, public IDxcLangExtensions2 {
 private:
   DXC_MICROCOM_TM_REF_FIELDS()
   hlsl::DxcLangExtensionsHelper m_langHelper;
@@ -198,7 +198,8 @@ public:
   DXC_LANGEXTENSIONS_HELPER_IMPL(m_langHelper);
 
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) override {
-    return DoBasicQueryInterface<IDxcIntelliSense, IDxcLangExtensions>(
+    return DoBasicQueryInterface<IDxcIntelliSense, IDxcLangExtensions,
+                                 IDxcLangExtensions2>(
         this, iid, ppvObject);
   }
 

+ 6 - 3
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -922,7 +922,7 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
   return S_OK;
 }
 
-class DxcRewriter : public IDxcRewriter2, public IDxcLangExtensions {
+class DxcRewriter : public IDxcRewriter2, public IDxcLangExtensions2 {
 private:
   DXC_MICROCOM_TM_REF_FIELDS()
   DxcLangExtensionsHelper m_langExtensionsHelper;
@@ -931,8 +931,11 @@ public:
   DXC_MICROCOM_TM_CTOR(DxcRewriter)
   DXC_LANGEXTENSIONS_HELPER_IMPL(m_langExtensionsHelper)
 
-  HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) override {
-    return DoBasicQueryInterface<IDxcRewriter2, IDxcRewriter, IDxcLangExtensions>(this, iid, ppvObject);
+  HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid,
+                                           void **ppvObject) override {
+    return DoBasicQueryInterface<IDxcRewriter2, IDxcRewriter,
+                                 IDxcLangExtensions, IDxcLangExtensions2>(
+        this, iid, ppvObject);
   }
 
   HRESULT STDMETHODCALLTYPE RemoveUnusedGlobals(_In_ IDxcBlobEncoding *pSource,

+ 18 - 1
tools/clang/unittests/HLSL/ExtensionTest.cpp

@@ -393,6 +393,10 @@ public:
   void SetSemanticDefineMetaDataName(const char *name) {
     VERIFY_SUCCEEDED(pLangExtensions->SetSemanticDefineMetaDataName("test.defs"));
   }
+  void SetTargetTriple(const char *name) {
+    VERIFY_SUCCEEDED(
+        pLangExtensions->SetTargetTriple(name));
+  }
   void RegisterIntrinsicTable(IDxcIntrinsicTable *table) {
     pTestIntrinsicTable = table;
     VERIFY_SUCCEEDED(pLangExtensions->RegisterIntrinsicTable(pTestIntrinsicTable));
@@ -421,7 +425,7 @@ public:
 
   dxc::DxcDllSupport &m_dllSupport;
   CComPtr<IDxcCompiler> pCompiler;
-  CComPtr<IDxcLangExtensions> pLangExtensions;
+  CComPtr<IDxcLangExtensions2> pLangExtensions;
   CComPtr<IDxcBlobEncoding> pCodeBlob;
   CComPtr<IDxcOperationResult> pCompileResult;
   CComPtr<IDxcSemanticDefineValidator> pTestSemanticDefineValidator;
@@ -449,6 +453,7 @@ public:
   TEST_METHOD(DefineValidationWarning)
   TEST_METHOD(DefineNoValidatorOk)
   TEST_METHOD(DefineFromMacro)
+  TEST_METHOD(TargetTriple)
   TEST_METHOD(IntrinsicWhenAvailableThenUsed)
   TEST_METHOD(CustomIntrinsicName)
   TEST_METHOD(NoLowering)
@@ -603,6 +608,18 @@ TEST_F(ExtensionTest, DefineFromMacro) {
     disassembly.find("!{!\"FOO\", !\"1\"}"));
 }
 
+TEST_F(ExtensionTest, TargetTriple) {
+  Compiler c(m_dllSupport);
+  c.SetTargetTriple("dxil-ms-win32");
+  c.Compile("float4 main() : SV_Target {\n"
+            "  return 0;\n"
+            "}\n",
+            {L"/Vd"}, {});
+
+  std::string disassembly = c.Disassemble();
+  // Check the triple is updated.
+  VERIFY_IS_TRUE(disassembly.npos != disassembly.find("dxil-ms-win32"));
+}
 
 TEST_F(ExtensionTest, IntrinsicWhenAvailableThenUsed) {
   Compiler c(m_dllSupport);