Browse Source

[spirv] Adding test fixture for HLSL to SPIRV codegen flow (#383)

This change enables CodeGenSPIRV test flow.

New tests can be added by simply adding the test file to
`tools/clang/test/CodeGenSPIRV/` directory and running:

```cpp
TEST_F(WholeFileTest, NewTest) {
  EXPECT_TRUE(runWholeFileTest("new-test-name"));
}
```

The input file with the format described in `WholeFileCheck.h` is read in;
the HLSL portion is passed to the compiler with SPIR-V codegen enabled.
The resulting SPIR-V binary is disassembled and compared to the expected
result in the input file.
Ehsan 8 years ago
parent
commit
3e84980892

+ 1 - 1
appveyor.yml

@@ -27,7 +27,7 @@ test_script:
     powershell utils\appveyor\appveyor_test.ps1
 # Running SPIR-V tests
 - cmd: >-
-    %HLSL_BLD_DIR%\%CONFIGURATION%\bin\clang-spirv-tests.exe
+    %HLSL_BLD_DIR%\%CONFIGURATION%\bin\clang-spirv-tests.exe --spirv-test-root %HLSL_SRC_DIR%\tools\clang\test\CodeGenSPIRV
 
 notifications:
 - provider: GitHubPullRequest

+ 2 - 0
external/CMakeLists.txt

@@ -26,6 +26,8 @@ if (${ENABLE_SPIRV_CODEGEN})
   endif()
   if (NOT TARGET SPIRV-Tools)
     message(FATAL_ERROR "SPIRV-Tools was not found - required for SPIR-V codegen")
+  else()
+    set(SPIRV_TOOLS_INCLUDE_DIR ${SPIRV-Tools_SOURCE_DIR}/include PARENT_SCOPE)
   endif()
 
   set(SPIRV_DEP_TARGETS

+ 1 - 1
tools/clang/include/clang/SPIRV/String.h

@@ -23,7 +23,7 @@ std::vector<uint32_t> encodeSPIRVString(std::string s);
 /// \brief Reinterprets the given vector of 32-bit words as a string.
 /// Expectes that the words represent a NULL-terminated string.
 /// It follows the SPIR-V string encoding requirements.
-std::string decodeSPIRVString(std::vector<uint32_t> &vec);
+std::string decodeSPIRVString(const std::vector<uint32_t> &vec);
 
 } // end namespace string
 } // end namespace spirv

+ 33 - 23
tools/clang/include/clang/SPIRV/Type.h

@@ -35,8 +35,14 @@ class SPIRVContext;
 /// context).
 class Type {
 public:
+  using DecorationSet = std::set<const Decoration *>;
+
   spv::Op getOpcode() const { return opcode; }
   const std::vector<uint32_t> &getArgs() const { return args; }
+  const std::set<const Decoration *> &getDecorations() const {
+    return decorations;
+  }
+  bool hasDecoration(const Decoration *) const;
 
   bool isBooleanType() const;
   bool isIntegerType() const;
@@ -51,10 +57,7 @@ public:
   bool isCompositeType() const;
   bool isImageType() const;
 
-  static const Type *getType(SPIRVContext &ctx, spv::Op op,
-                             std::vector<uint32_t> arg = {},
-                             std::set<const Decoration *> decs = {});
-
+  // Scalar types do not take any decorations.
   static const Type *getVoid(SPIRVContext &ctx);
   static const Type *getBool(SPIRVContext &ctx);
   static const Type *getInt8(SPIRVContext &ctx);
@@ -68,39 +71,46 @@ public:
   static const Type *getFloat16(SPIRVContext &ctx);
   static const Type *getFloat32(SPIRVContext &ctx);
   static const Type *getFloat64(SPIRVContext &ctx);
-  static const Type *getVector(SPIRVContext &ctx, uint32_t component_type,
-                               uint32_t vec_size);
   static const Type *getVec2(SPIRVContext &ctx, uint32_t component_type);
   static const Type *getVec3(SPIRVContext &ctx, uint32_t component_type);
   static const Type *getVec4(SPIRVContext &ctx, uint32_t component_type);
   static const Type *getMatrix(SPIRVContext &ctx, uint32_t column_type_id,
                                uint32_t column_count);
+
   static const Type *
   getImage(SPIRVContext &ctx, uint32_t sampled_type, spv::Dim dim,
            uint32_t depth, uint32_t arrayed, uint32_t ms, uint32_t sampled,
            spv::ImageFormat image_format,
-           llvm::Optional<spv::AccessQualifier> access_qualifier);
-  static const Type *getSampler(SPIRVContext &ctx);
-  static const Type *getSampledImage(SPIRVContext &ctx, uint32_t imag_type_id);
+           llvm::Optional<spv::AccessQualifier> access_qualifier = llvm::None,
+           DecorationSet decs = {});
+  static const Type *getSampler(SPIRVContext &ctx, DecorationSet decs = {});
+  static const Type *getSampledImage(SPIRVContext &ctx, uint32_t imag_type_id,
+                                     DecorationSet decs = {});
   static const Type *getArray(SPIRVContext &ctx, uint32_t component_type_id,
-                              uint32_t len_id);
+                              uint32_t len_id, DecorationSet decs = {});
   static const Type *getRuntimeArray(SPIRVContext &ctx,
-                                     uint32_t component_type_id);
+                                     uint32_t component_type_id,
+                                     DecorationSet decs = {});
   static const Type *getStruct(SPIRVContext &ctx,
-                               std::initializer_list<uint32_t> members);
-  static const Type *getOpaque(SPIRVContext &ctx, std::string name);
-  static const Type *getTyePointer(SPIRVContext &ctx,
-                                   spv::StorageClass storage_class,
-                                   uint32_t type);
+                               std::initializer_list<uint32_t> members,
+                               DecorationSet d = {});
+  static const Type *getOpaque(SPIRVContext &ctx, std::string name,
+                               DecorationSet decs = {});
+  static const Type *getPointer(SPIRVContext &ctx,
+                                spv::StorageClass storage_class, uint32_t type,
+                                DecorationSet decs = {});
   static const Type *getFunction(SPIRVContext &ctx, uint32_t return_type,
-                                 std::initializer_list<uint32_t> params);
-  static const Type *getEvent(SPIRVContext &ctx);
-  static const Type *getDeviceEvent(SPIRVContext &ctx);
-  static const Type *getQueue(SPIRVContext &ctx);
-  static const Type *getPipe(SPIRVContext &ctx, spv::AccessQualifier qualifier);
+                                 std::initializer_list<uint32_t> params,
+                                 DecorationSet decs = {});
+  static const Type *getEvent(SPIRVContext &ctx, DecorationSet decs = {});
+  static const Type *getDeviceEvent(SPIRVContext &ctx, DecorationSet decs = {});
+  static const Type *getReserveId(SPIRVContext &ctx, DecorationSet decs = {});
+  static const Type *getQueue(SPIRVContext &ctx, DecorationSet decs = {});
+  static const Type *getPipe(SPIRVContext &ctx, spv::AccessQualifier qualifier,
+                             DecorationSet decs = {});
   static const Type *getForwardPointer(SPIRVContext &ctx, uint32_t pointer_type,
-                                       spv::StorageClass storage_class);
-
+                                       spv::StorageClass storage_class,
+                                       DecorationSet decs = {});
   bool operator==(const Type &other) const {
     return opcode == other.opcode && args == other.args &&
            decorations == other.decorations;

+ 2 - 2
tools/clang/lib/SPIRV/String.cpp

@@ -39,10 +39,10 @@ std::vector<uint32_t> encodeSPIRVString(std::string s) {
 /// \brief Reinterprets the given vector of 32-bit words as a string.
 /// Expectes that the words represent a NULL-terminated string.
 /// Assumes Little Endian architecture.
-std::string decodeSPIRVString(std::vector<uint32_t>& vec) {
+std::string decodeSPIRVString(const std::vector<uint32_t> &vec) {
   std::string result;
   if (!vec.empty()) {
-    result = std::string(reinterpret_cast<const char*>(vec.data()));
+    result = std::string(reinterpret_cast<const char *>(vec.data()));
   }
   return result;
 }

+ 54 - 48
tools/clang/lib/SPIRV/Type.cpp

@@ -22,11 +22,11 @@ const Type *Type::getUniqueType(SPIRVContext &context, const Type &t) {
   return context.registerType(t);
 }
 const Type *Type::getVoid(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeVoid);
+  Type t = Type(spv::Op::OpTypeVoid, {});
   return getUniqueType(context, t);
 }
 const Type *Type::getBool(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeBool);
+  Type t = Type(spv::Op::OpTypeBool, {});
   return getUniqueType(context, t);
 }
 const Type *Type::getInt8(SPIRVContext &context) {
@@ -73,19 +73,17 @@ const Type *Type::getFloat64(SPIRVContext &context) {
   Type t = Type(spv::Op::OpTypeFloat, {64});
   return getUniqueType(context, t);
 }
-const Type *Type::getVector(SPIRVContext &context, uint32_t component_type,
-                            uint32_t vec_size) {
-  Type t = Type(spv::Op::OpTypeVector, {component_type, vec_size});
-  return getUniqueType(context, t);
-}
 const Type *Type::getVec2(SPIRVContext &context, uint32_t component_type) {
-  return getVector(context, component_type, 2u);
+  Type t = Type(spv::Op::OpTypeVector, {component_type, 2u});
+  return getUniqueType(context, t);
 }
 const Type *Type::getVec3(SPIRVContext &context, uint32_t component_type) {
-  return getVector(context, component_type, 3u);
+  Type t = Type(spv::Op::OpTypeVector, {component_type, 3u});
+  return getUniqueType(context, t);
 }
 const Type *Type::getVec4(SPIRVContext &context, uint32_t component_type) {
-  return getVector(context, component_type, 4u);
+  Type t = Type(spv::Op::OpTypeVector, {component_type, 4u});
+  return getUniqueType(context, t);
 }
 const Type *Type::getMatrix(SPIRVContext &context, uint32_t column_type_id,
                             uint32_t column_count) {
@@ -96,88 +94,92 @@ const Type *
 Type::getImage(SPIRVContext &context, uint32_t sampled_type, spv::Dim dim,
                uint32_t depth, uint32_t arrayed, uint32_t ms, uint32_t sampled,
                spv::ImageFormat image_format,
-               llvm::Optional<spv::AccessQualifier> access_qualifier) {
+               llvm::Optional<spv::AccessQualifier> access_qualifier,
+               DecorationSet d) {
   std::vector<uint32_t> args = {
       sampled_type, uint32_t(dim),         depth, arrayed, ms,
       sampled,      uint32_t(image_format)};
   if (access_qualifier.hasValue()) {
     args.push_back(static_cast<uint32_t>(access_qualifier.getValue()));
   }
-  Type t = Type(spv::Op::OpTypeImage, args);
+  Type t = Type(spv::Op::OpTypeImage, args, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getSampler(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeSampler);
+const Type *Type::getSampler(SPIRVContext &context, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeSampler, {}, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getSampledImage(SPIRVContext &context,
-                                  uint32_t image_type_id) {
-  Type t = Type(spv::Op::OpTypeSampledImage, {image_type_id});
+const Type *Type::getSampledImage(SPIRVContext &context, uint32_t image_type_id,
+                                  DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeSampledImage, {image_type_id}, d);
   return getUniqueType(context, t);
 }
 const Type *Type::getArray(SPIRVContext &context, uint32_t component_type_id,
-                           uint32_t len_id) {
-  Type t = Type(spv::Op::OpTypeArray, {component_type_id, len_id});
+                           uint32_t len_id, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeArray, {component_type_id, len_id}, d);
   return getUniqueType(context, t);
 }
 const Type *Type::getRuntimeArray(SPIRVContext &context,
-                                  uint32_t component_type_id) {
-  Type t = Type(spv::Op::OpTypeRuntimeArray, {component_type_id});
+                                  uint32_t component_type_id, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeRuntimeArray, {component_type_id}, d);
   return getUniqueType(context, t);
 }
 const Type *Type::getStruct(SPIRVContext &context,
-                            std::initializer_list<uint32_t> members) {
-  Type t = Type(spv::Op::OpTypeStruct, std::vector<uint32_t>(members));
+                            std::initializer_list<uint32_t> members,
+                            DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeStruct, std::vector<uint32_t>(members), d);
   return getUniqueType(context, t);
 }
-const Type *Type::getOpaque(SPIRVContext &context, std::string name) {
-  Type t = Type(spv::Op::OpTypeOpaque, string::encodeSPIRVString(name));
+const Type *Type::getOpaque(SPIRVContext &context, std::string name,
+                            DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeOpaque, string::encodeSPIRVString(name), d);
   return getUniqueType(context, t);
 }
-const Type *Type::getTyePointer(SPIRVContext &context,
-                                spv::StorageClass storage_class,
-                                uint32_t type) {
+const Type *Type::getPointer(SPIRVContext &context,
+                             spv::StorageClass storage_class, uint32_t type,
+                             DecorationSet d) {
   Type t = Type(spv::Op::OpTypePointer,
-                {static_cast<uint32_t>(storage_class), type});
+                {static_cast<uint32_t>(storage_class), type}, d);
   return getUniqueType(context, t);
 }
 const Type *Type::getFunction(SPIRVContext &context, uint32_t return_type,
-                              std::initializer_list<uint32_t> params) {
+                              std::initializer_list<uint32_t> params,
+                              DecorationSet d) {
   std::vector<uint32_t> args = {return_type};
   args.insert(args.end(), params.begin(), params.end());
-  Type t = Type(spv::Op::OpTypeFunction, args);
+  Type t = Type(spv::Op::OpTypeFunction, args, d);
+  return getUniqueType(context, t);
+}
+const Type *Type::getEvent(SPIRVContext &context, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeEvent, {}, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getEvent(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeEvent);
+const Type *Type::getDeviceEvent(SPIRVContext &context, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeDeviceEvent, {}, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getDeviceEvent(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeDeviceEvent);
+const Type *Type::getReserveId(SPIRVContext &context, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeReserveId, {}, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getQueue(SPIRVContext &context) {
-  Type t = Type(spv::Op::OpTypeQueue);
+const Type *Type::getQueue(SPIRVContext &context, DecorationSet d) {
+  Type t = Type(spv::Op::OpTypeQueue, {}, d);
   return getUniqueType(context, t);
 }
-const Type *Type::getPipe(SPIRVContext &context,
-                          spv::AccessQualifier qualifier) {
-  Type t = Type(spv::Op::OpTypePipe, {static_cast<uint32_t>(qualifier)});
+const Type *Type::getPipe(SPIRVContext &context, spv::AccessQualifier qualifier,
+                          DecorationSet d) {
+  Type t = Type(spv::Op::OpTypePipe, {static_cast<uint32_t>(qualifier)}, d);
   return getUniqueType(context, t);
 }
 const Type *Type::getForwardPointer(SPIRVContext &context,
                                     uint32_t pointer_type,
-                                    spv::StorageClass storage_class) {
+                                    spv::StorageClass storage_class,
+                                    DecorationSet d) {
   Type t = Type(spv::Op::OpTypeForwardPointer,
-                {pointer_type, static_cast<uint32_t>(storage_class)});
-  return getUniqueType(context, t);
-}
-const Type *Type::getType(SPIRVContext &context, spv::Op op,
-                          std::vector<uint32_t> arg,
-                          std::set<const Decoration *> dec) {
-  Type t = Type(op, arg, dec);
+                {pointer_type, static_cast<uint32_t>(storage_class)}, d);
   return getUniqueType(context, t);
 }
+
 bool Type::isBooleanType() const { return opcode == spv::Op::OpTypeBool; }
 bool Type::isIntegerType() const { return opcode == spv::Op::OpTypeInt; }
 bool Type::isFloatType() const { return opcode == spv::Op::OpTypeFloat; }
@@ -195,5 +197,9 @@ bool Type::isCompositeType() const {
 }
 bool Type::isImageType() const { return opcode == spv::Op::OpTypeImage; }
 
+bool Type::hasDecoration(const Decoration *d) const {
+  return decorations.find(d) != decorations.end();
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 14 - 0
tools/clang/test/CodeGenSPIRV/basic.hlsl2spv

@@ -0,0 +1,14 @@
+// Comments 1
+// Comments 2
+// Run: %dxc -T ps_6_0 -E main
+void main()
+{
+
+}
+
+// CHECK-WHOLE-SPIR-V:
+// ; SPIR-V
+// ; Version: 1.0
+// ; Generator: Google spiregg; 0
+// ; Bound: 1
+// ; Schema: 0

+ 6 - 0
tools/clang/unittests/SPIRV/CMakeLists.txt

@@ -5,20 +5,26 @@ set(LLVM_LINK_COMPONENTS
   )
 
 add_clang_unittest(clang-spirv-tests
+  CodeGenSPIRVTest.cpp
   DecorationTest.cpp
   InstBuilderTest.cpp
   ModuleBuilderTest.cpp
   SPIRVContextTest.cpp
+  SPIRVTestOptions.cpp
   StructureTest.cpp
   TestMain.cpp
   StringTest.cpp
   TypeTest.cpp
+  WholeFileCheck.cpp
   )
 
 target_link_libraries(clang-spirv-tests
   clangCodeGen
   clangFrontend
   clangSPIRV
+  SPIRV-Tools
   )
 
+target_include_directories(clang-spirv-tests PRIVATE ${SPIRV_TOOLS_INCLUDE_DIR})
+
 set_output_directory(clang-spirv-tests ${LLVM_RUNTIME_OUTPUT_INTDIR} ${LLVM_LIBRARY_OUTPUT_INTDIR})

+ 22 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -0,0 +1,22 @@
+//===- unittests/SPIRV/CodeGenSPIRVTest.cpp ---- Run CodeGenSPIRV tests ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include <fstream>
+
+#include "WholeFileCheck.h"
+#include "gtest/gtest.h"
+
+TEST_F(WholeFileTest, BringUp) {
+  // Ideally all generated SPIR-V must be valid, but this currently fails with
+  // this error message: "No OpEntryPoint instruction was found...".
+  // TODO: change this test such that it does run validation.
+  bool success = runWholeFileTest("basic.hlsl2spv", /*generateHeader*/ true,
+                                  /*runValidation*/ false);
+  EXPECT_TRUE(success);
+}

+ 4 - 4
tools/clang/unittests/SPIRV/SPIRVContextTest.cpp

@@ -60,12 +60,12 @@ TEST(ValidateSPIRVContext, ValidateUniqueIdForUniqueAggregateType) {
   const auto mem_0_position =
       Decoration::getBuiltIn(ctx, spv::BuiltIn::Position, 0);
 
-  const Type *struct_1 = Type::getType(
-      ctx, spv::Op::OpTypeStruct, {intt_id, boolt_id},
+  const Type *struct_1 = Type::getStruct(
+      ctx, {intt_id, boolt_id},
       {relaxed, bufferblock, mem_0_offset, mem_1_offset, mem_0_position});
 
-  const Type *struct_2 = Type::getType(
-      ctx, spv::Op::OpTypeStruct, {intt_id, boolt_id},
+  const Type *struct_2 = Type::getStruct(
+      ctx, {intt_id, boolt_id},
       {relaxed, bufferblock, mem_0_offset, mem_1_offset, mem_0_position});
 
   const uint32_t struct_1_id = ctx.getResultIdForType(struct_1);

+ 25 - 0
tools/clang/unittests/SPIRV/SPIRVTestOptions.cpp

@@ -0,0 +1,25 @@
+//===- unittests/SPIRV/SpirvTestOptions.cpp ----- Test Options Init -------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines and initializes command line options that can be passed to
+// SPIR-V gtests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTestOptions.h"
+
+namespace clang {
+namespace spirv {
+namespace testOptions {
+
+std::string inputDataDir = "";
+
+} // namespace testOptions
+} // namespace spirv
+} // namespace clang

+ 37 - 0
tools/clang/unittests/SPIRV/SPIRVTestOptions.h

@@ -0,0 +1,37 @@
+//===- unittests/SPIRV/SpirvTestOptions.h ----- Command Line Options ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the command line options that can be passed to SPIR-V
+// gtests. This file should be included in any test file that intends to use any
+// options.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_UNITTESTS_SPIRV_TEST_OPTIONS_H
+#define LLVM_CLANG_UNITTESTS_SPIRV_TEST_OPTIONS_H
+
+#include <string>
+
+namespace clang {
+namespace spirv {
+
+/// \brief Includes any command line options that may be passed to gtest for
+/// running the SPIR-V tests. New options should be added in this namespace.
+namespace testOptions {
+
+/// \brief Command line option that specifies the path to the directory that
+/// contains files that have the HLSL source code and expected SPIR-V code (used
+/// for the CodeGen test flow).
+extern std::string inputDataDir;
+
+} // namespace testOptions
+} // namespace spirv
+} // namespace clang
+
+#endif

+ 31 - 18
tools/clang/unittests/SPIRV/TestMain.cpp

@@ -7,16 +7,18 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Signals.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
+#include "llvm/Support/Signals.h"
+
+#include "SPIRVTestOptions.h"
+
 #if defined(_WIN32)
-# include <windows.h>
-# if defined(_MSC_VER)
-#   include <crtdbg.h>
-# endif
+#include <windows.h>
+#if defined(_MSC_VER)
+#include <crtdbg.h>
+#endif
 #endif
 
 const char *TestMainArgv0;
@@ -27,25 +29,36 @@ int main(int argc, char **argv) {
   // Initialize both gmock and gtest.
   testing::InitGoogleMock(&argc, argv);
 
-  llvm::cl::ParseCommandLineOptions(argc, argv);
+  for (int i = 1; i < argc; ++i) {
+    if (std::string("--spirv-test-root") == argv[i]) {
+      // Allow the user set the root directory for test input files.
+      if (i + 1 < argc) {
+        clang::spirv::testOptions::inputDataDir = argv[i + 1];
+        i++;
+      } else {
+        fprintf(stderr, "Error: --spirv-test-root requires an argument\n");
+        return 1;
+      }
+    }
+  }
 
   // Make it easy for a test to re-execute itself by saving argv[0].
   TestMainArgv0 = argv[0];
 
-# if defined(_WIN32)
+#if defined(_WIN32)
   // Disable all of the possible ways Windows conspires to make automated
   // testing impossible.
   ::SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
-#   if defined(_MSC_VER)
-    ::_set_error_mode(_OUT_TO_STDERR);
-    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
-    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR);
-    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
-    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDERR);
-    _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
-    _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR);
-#   endif
-# endif
+#if defined(_MSC_VER)
+  ::_set_error_mode(_OUT_TO_STDERR);
+  _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
+  _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR);
+  _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
+  _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDERR);
+  _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
+  _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR);
+#endif
+#endif
 
   return RUN_ALL_TESTS();
 }

+ 488 - 7
tools/clang/unittests/SPIRV/TypeTest.cpp

@@ -9,12 +9,14 @@
 
 #include "gmock/gmock.h"
 #include "clang/SPIRV/SPIRVContext.h"
+#include "clang/SPIRV/String.h"
 #include "clang/SPIRV/Type.h"
 #include "gtest/gtest.h"
 
 using namespace clang::spirv;
 
 namespace {
+using ::testing::ElementsAre;
 
 TEST(Type, SameTypeWoParameterShouldHaveSameAddress) {
   SPIRVContext context;
@@ -42,16 +44,16 @@ TEST(Type, SameAggregateTypeWithDecorationsShouldHaveSameAddress) {
   const Decoration *mem_0_position =
       Decoration::getBuiltIn(ctx, spv::BuiltIn::Position, 0);
 
-  const Type *struct_1 = Type::getType(
-      ctx, spv::Op::OpTypeStruct, {intt_id, boolt_id},
+  const Type *struct_1 = Type::getStruct(
+      ctx, {intt_id, boolt_id},
       {relaxed, bufferblock, mem_0_offset, mem_1_offset, mem_0_position});
 
-  const Type *struct_2 = Type::getType(
-      ctx, spv::Op::OpTypeStruct, {intt_id, boolt_id},
+  const Type *struct_2 = Type::getStruct(
+      ctx, {intt_id, boolt_id},
       {relaxed, bufferblock, mem_0_offset, mem_1_offset, mem_0_position});
 
-  const Type *struct_3 = Type::getType(
-      ctx, spv::Op::OpTypeStruct, {intt_id, boolt_id},
+  const Type *struct_3 = Type::getStruct(
+      ctx, {intt_id, boolt_id},
       {bufferblock, mem_0_offset, mem_0_position, mem_1_offset, relaxed});
 
   // 2 types with the same signature. We should get the same pointer.
@@ -61,6 +63,485 @@ TEST(Type, SameAggregateTypeWithDecorationsShouldHaveSameAddress) {
   EXPECT_EQ(struct_1, struct_3);
 }
 
-// TODO: Add Type tests for all types
+TEST(Type, Void) {
+  SPIRVContext ctx;
+  const Type *t = Type::getVoid(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeVoid);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Bool) {
+  SPIRVContext ctx;
+  const Type *t = Type::getBool(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeBool);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Int8) {
+  SPIRVContext ctx;
+  const Type *t = Type::getInt8(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(8, 1));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Uint8) {
+  SPIRVContext ctx;
+  const Type *t = Type::getUint8(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(8, 0));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Int16) {
+  SPIRVContext ctx;
+  const Type *t = Type::getInt16(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(16, 1));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Uint16) {
+  SPIRVContext ctx;
+  const Type *t = Type::getUint16(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(16, 0));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Int32) {
+  SPIRVContext ctx;
+  const Type *t = Type::getInt32(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(32, 1));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Uint32) {
+  SPIRVContext ctx;
+  const Type *t = Type::getUint32(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(32, 0));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Int64) {
+  SPIRVContext ctx;
+  const Type *t = Type::getInt64(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(64, 1));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Uint64) {
+  SPIRVContext ctx;
+  const Type *t = Type::getUint64(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeInt);
+  EXPECT_THAT(t->getArgs(), ElementsAre(64, 0));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Float16) {
+  SPIRVContext ctx;
+  const Type *t = Type::getFloat16(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeFloat);
+  EXPECT_THAT(t->getArgs(), ElementsAre(16));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Float32) {
+  SPIRVContext ctx;
+  const Type *t = Type::getFloat32(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeFloat);
+  EXPECT_THAT(t->getArgs(), ElementsAre(32));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Float64) {
+  SPIRVContext ctx;
+  const Type *t = Type::getFloat64(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeFloat);
+  EXPECT_THAT(t->getArgs(), ElementsAre(64));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Vec2) {
+  SPIRVContext ctx;
+  const Type *t = Type::getVec2(ctx, 1);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeVector);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1, 2));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Vec3) {
+  SPIRVContext ctx;
+  const Type *t = Type::getVec3(ctx, 1);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeVector);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1, 3));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Vec4) {
+  SPIRVContext ctx;
+  const Type *t = Type::getVec4(ctx, 1);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeVector);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1, 4));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, Matrix) {
+  SPIRVContext ctx;
+  const Type *t = Type::getMatrix(ctx, /*type-id*/ 7, /*column-count*/ 4);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeMatrix);
+  EXPECT_THAT(t->getArgs(), ElementsAre(7, 4));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, ImageWithoutAccessQualifier) {
+  SPIRVContext ctx;
+  const Type *t = Type::getImage(ctx, /*sampled-type*/ 5, spv::Dim::Cube,
+                                 /*depth*/ 1, /*arrayed*/ 1, /*multisampled*/ 0,
+                                 /*sampled*/ 2, spv::ImageFormat::Rgba32f);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeImage);
+  EXPECT_THAT(t->getArgs(),
+              ElementsAre(5, static_cast<uint32_t>(spv::Dim::Cube), 1, 1, 0, 2,
+                          static_cast<uint32_t>(spv::ImageFormat::Rgba32f)));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedImageWithoutAccessQualifier) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t =
+      Type::getImage(ctx, /*sampled-type*/ 5, spv::Dim::Cube, /*depth*/ 1,
+                     /*arrayed*/ 1, /*multisampled*/ 0, /*sampled*/ 2,
+                     spv::ImageFormat::Rgba32f, llvm::None, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeImage);
+  EXPECT_THAT(t->getArgs(),
+              ElementsAre(5, static_cast<uint32_t>(spv::Dim::Cube), 1, 1, 0, 2,
+                          static_cast<uint32_t>(spv::ImageFormat::Rgba32f)));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, ImageWithAccessQualifier) {
+  SPIRVContext ctx;
+  const Type *t = Type::getImage(
+      ctx, /*sampled-type*/ 5, spv::Dim::Cube, /*depth*/ 1, /*arrayed*/ 1,
+      /*multisampled*/ 0, /*sampled*/ 2, spv::ImageFormat::Rgba32f,
+      /*access-qualifier*/ spv::AccessQualifier::ReadWrite);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeImage);
+  EXPECT_THAT(
+      t->getArgs(),
+      ElementsAre(5, static_cast<uint32_t>(spv::Dim::Cube), 1, 1, 0, 2,
+                  static_cast<uint32_t>(spv::ImageFormat::Rgba32f),
+                  static_cast<uint32_t>(spv::AccessQualifier::ReadWrite)));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedImageWithAccessQualifier) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getImage(
+      ctx, /*sampled-type*/ 5, spv::Dim::Cube, /*depth*/ 1, /*arrayed*/ 1,
+      /*multisampled*/ 0, /*sampled*/ 2, spv::ImageFormat::Rgba32f,
+      /*access-qualifier*/ spv::AccessQualifier::ReadWrite, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeImage);
+  EXPECT_THAT(
+      t->getArgs(),
+      ElementsAre(5, static_cast<uint32_t>(spv::Dim::Cube), 1, 1, 0, 2,
+                  static_cast<uint32_t>(spv::ImageFormat::Rgba32f),
+                  static_cast<uint32_t>(spv::AccessQualifier::ReadWrite)));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, ImageWithAndWithoutAccessQualifierAreDifferentTypes) {
+  SPIRVContext ctx;
+  const Type *img1 =
+      Type::getImage(ctx, /*sampled-type*/ 5, spv::Dim::Cube,
+                     /*depth*/ 1, /*arrayed*/ 1, /*multisampled*/ 0,
+                     /*sampled*/ 2, spv::ImageFormat::Rgba32f);
+  const Type *img2 =
+      Type::getImage(ctx, /*sampled-type*/ 5, spv::Dim::Cube,
+                     /*depth*/ 1, /*arrayed*/ 1, /*multisampled*/ 0,
+                     /*sampled*/ 2, spv::ImageFormat::Rgba32f,
+                     /*access-qualifier*/ spv::AccessQualifier::ReadWrite);
+
+  // The only difference between these two types is the Access Qualifier which
+  // is an optional argument.
+  EXPECT_NE(img1, img2);
+}
+
+TEST(Type, Sampler) {
+  SPIRVContext ctx;
+  const Type *t = Type::getSampler(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeSampler);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedSampler) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getSampler(ctx, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeSampler);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, SampledImage) {
+  SPIRVContext ctx;
+  const Type *t = Type::getSampledImage(ctx, 1);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeSampledImage);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedSampledImage) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getSampledImage(ctx, 1, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeSampledImage);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Array) {
+  SPIRVContext ctx;
+  const Type *t = Type::getArray(ctx, 2, 4);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeArray);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2, 4));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedArray) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getArray(ctx, 2, 4, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeArray);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2, 4));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, RuntimeArray) {
+  SPIRVContext ctx;
+  const Type *t = Type::getRuntimeArray(ctx, 2);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeRuntimeArray);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedRuntimeArray) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getRuntimeArray(ctx, 2, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeRuntimeArray);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, StructBasic) {
+  SPIRVContext ctx;
+  const Type *t = Type::getStruct(ctx, {2, 3, 4});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeStruct);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2, 3, 4));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, StructWithDecoration) {
+  SPIRVContext ctx;
+  const Decoration *bufferblock = Decoration::getBufferBlock(ctx);
+  const Type *t = Type::getStruct(ctx, {2, 3, 4}, {bufferblock});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeStruct);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2, 3, 4));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(bufferblock));
+}
+
+TEST(Type, StructWithDecoratedMembers) {
+  SPIRVContext ctx;
+  const Decoration *relaxed = Decoration::getRelaxedPrecision(ctx);
+  const Decoration *bufferblock = Decoration::getBufferBlock(ctx);
+  const Decoration *mem_0_offset = Decoration::getOffset(ctx, 0u, 0);
+  const Decoration *mem_1_offset = Decoration::getOffset(ctx, 0u, 1);
+  const Decoration *mem_0_position =
+      Decoration::getBuiltIn(ctx, spv::BuiltIn::Position, 0);
+
+  const Type *t = Type::getStruct(
+      ctx, {2, 3, 4},
+      {relaxed, bufferblock, mem_0_position, mem_0_offset, mem_1_offset});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeStruct);
+  EXPECT_THAT(t->getArgs(), ElementsAre(2, 3, 4));
+  // Since decorations are an ordered set of pointers, it's better not to use
+  // ElementsAre()
+  EXPECT_EQ(t->getDecorations().size(), 5u);
+  EXPECT_TRUE(t->hasDecoration(relaxed));
+  EXPECT_TRUE(t->hasDecoration(bufferblock));
+  EXPECT_TRUE(t->hasDecoration(mem_0_offset));
+  EXPECT_TRUE(t->hasDecoration(mem_0_position));
+  EXPECT_TRUE(t->hasDecoration(mem_1_offset));
+}
+
+TEST(Type, Opaque) {
+  SPIRVContext ctx;
+  const Type *t = Type::getOpaque(ctx, "opaque_type");
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeOpaque);
+  EXPECT_EQ(string::decodeSPIRVString(t->getArgs()), "opaque_type");
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedOpaque) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getOpaque(ctx, "opaque_type", {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeOpaque);
+  EXPECT_EQ(string::decodeSPIRVString(t->getArgs()), "opaque_type");
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Pointer) {
+  SPIRVContext ctx;
+  const Type *t = Type::getPointer(ctx, spv::StorageClass::Uniform, 2);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypePointer);
+  EXPECT_THAT(
+      t->getArgs(),
+      ElementsAre(static_cast<uint32_t>(spv::StorageClass::Uniform), 2));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedPointer) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getPointer(ctx, spv::StorageClass::Uniform, 2, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypePointer);
+  EXPECT_THAT(
+      t->getArgs(),
+      ElementsAre(static_cast<uint32_t>(spv::StorageClass::Uniform), 2));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Function) {
+  SPIRVContext ctx;
+  const Type *t = Type::getFunction(ctx, 1, {2, 3, 4});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeFunction);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1, 2, 3, 4));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedFunction) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getFunction(ctx, 1, {2, 3, 4}, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeFunction);
+  EXPECT_THAT(t->getArgs(), ElementsAre(1, 2, 3, 4));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Event) {
+  SPIRVContext ctx;
+  const Type *t = Type::getEvent(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeEvent);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedEvent) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getEvent(ctx, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeEvent);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, DeviceEvent) {
+  SPIRVContext ctx;
+  const Type *t = Type::getDeviceEvent(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeDeviceEvent);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedDeviceEvent) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getDeviceEvent(ctx, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeDeviceEvent);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, ReserveId) {
+  SPIRVContext ctx;
+  const Type *t = Type::getReserveId(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeReserveId);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedReserveId) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getReserveId(ctx, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeReserveId);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Queue) {
+  SPIRVContext ctx;
+  const Type *t = Type::getQueue(ctx);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeQueue);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedQueue) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getQueue(ctx, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeQueue);
+  EXPECT_TRUE(t->getArgs().empty());
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, Pipe) {
+  SPIRVContext ctx;
+  const Type *t = Type::getPipe(ctx, spv::AccessQualifier::WriteOnly);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypePipe);
+  EXPECT_THAT(t->getArgs(), ElementsAre(static_cast<uint32_t>(
+                                spv::AccessQualifier::WriteOnly)));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedPipe) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t = Type::getPipe(ctx, spv::AccessQualifier::WriteOnly, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypePipe);
+  EXPECT_THAT(t->getArgs(), ElementsAre(static_cast<uint32_t>(
+                                spv::AccessQualifier::WriteOnly)));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
+
+TEST(Type, ForwardPointer) {
+  SPIRVContext ctx;
+  const Type *t = Type::getForwardPointer(ctx, 6, spv::StorageClass::Workgroup);
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeForwardPointer);
+  EXPECT_THAT(t->getArgs(), ElementsAre(6, static_cast<uint32_t>(
+                                               spv::StorageClass::Workgroup)));
+  EXPECT_TRUE(t->getDecorations().empty());
+}
+
+TEST(Type, DecoratedForwardPointer) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getAliased(ctx);
+  const Type *t =
+      Type::getForwardPointer(ctx, 6, spv::StorageClass::Workgroup, {d});
+  EXPECT_EQ(t->getOpcode(), spv::Op::OpTypeForwardPointer);
+  EXPECT_THAT(t->getArgs(), ElementsAre(6, static_cast<uint32_t>(
+                                               spv::StorageClass::Workgroup)));
+  EXPECT_THAT(t->getDecorations(), ElementsAre(d));
+}
 
 } // anonymous namespace

+ 229 - 0
tools/clang/unittests/SPIRV/WholeFileCheck.cpp

@@ -0,0 +1,229 @@
+//===- unittests/SPIRV/WholeFileCheck.cpp - WholeFileCheck Implementation -===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include <fstream>
+
+#include "WholeFileCheck.h"
+#include "gtest/gtest.h"
+
+WholeFileTest::WholeFileTest() : spirvTools(SPV_ENV_UNIVERSAL_1_0) {
+  spirvTools.SetMessageConsumer(
+      [](spv_message_level_t, const char *, const spv_position_t &,
+         const char *message) { fprintf(stdout, "%s\n", message); });
+}
+
+bool WholeFileTest::processRunCommandArgs(const std::string &runCommandLine) {
+  std::istringstream buf(runCommandLine);
+  std::istream_iterator<std::string> start(buf), end;
+  std::vector<std::string> tokens(start, end);
+  if (tokens[1].find("Run") == std::string::npos ||
+      tokens[2].find("%dxc") == std::string::npos) {
+    fprintf(stderr, "The only supported format is: \"// Run: %%dxc -T "
+                    "<profile> -E <entry>\"\n");
+    return false;
+  }
+
+  for (size_t i = 0; i < tokens.size(); ++i) {
+    if (tokens[i] == "-T" && i + 1 < tokens.size())
+      targetProfile = tokens[i + 1];
+    else if (tokens[i] == "-E" && i + 1 < tokens.size())
+      entryPoint = tokens[i + 1];
+  }
+  if (targetProfile.empty()) {
+    fprintf(stderr, "Error: Missing target profile argument (-T).\n");
+    return false;
+  }
+  if (entryPoint.empty()) {
+    fprintf(stderr, "Error: Missing entry point argument (-E).\n");
+    return false;
+  }
+  return true;
+}
+
+bool WholeFileTest::parseInputFile() {
+  bool foundRunCommand = false;
+  bool parseSpirv = false;
+  std::ostringstream outString;
+  std::ifstream inputFile;
+  inputFile.exceptions(std::ifstream::failbit);
+  try {
+    inputFile.open(inputFilePath);
+    for (std::string line; !inputFile.eof() && std::getline(inputFile, line);) {
+      if (line.find(hlslStartLabel) != std::string::npos) {
+        foundRunCommand = true;
+        if (!processRunCommandArgs(line)) {
+          // An error has occured when parsing the Run command.
+          return false;
+        }
+      } else if (line.find(spirvStartLabel) != std::string::npos) {
+        // HLSL source has ended.
+        // SPIR-V source starts on the next line.
+        parseSpirv = true;
+      } else if (parseSpirv) {
+        // Strip the leading "//" from the SPIR-V assembly (skip 2 characters)
+        if (line.size() > 2u) {
+          line = line.substr(2);
+        }
+        // Skip any leading whitespace
+        size_t found = line.find_first_not_of(" \t");
+        if (found != std::string::npos) {
+          line = line.substr(found);
+        }
+        outString << line << std::endl;
+      }
+    }
+
+    if (!foundRunCommand) {
+      fprintf(stderr, "Error: Missing \"Run:\" command.\n");
+      return false;
+    }
+    if (!parseSpirv) {
+      fprintf(stderr, "Error: Missing \"CHECK-WHOLE-SPIR-V:\" command.\n");
+      return false;
+    }
+
+    // Reached the end of the file. SPIR-V source has ended. Store it for
+    // comparison.
+    expectedSpirvAsm = outString.str();
+
+    // Close the input file.
+    inputFile.close();
+  } catch (...) {
+    fprintf(
+        stderr,
+        "Error: Exception occurred while opening/reading the input file %s\n",
+        inputFilePath.c_str());
+    return false;
+  }
+
+  // Everything was successful.
+  return true;
+}
+
+bool WholeFileTest::runCompilerWithSpirvGeneration() {
+  std::wstring srcFile(inputFilePath.begin(), inputFilePath.end());
+  std::wstring entry(entryPoint.begin(), entryPoint.end());
+  std::wstring profile(targetProfile.begin(), targetProfile.end());
+  bool success = true;
+
+  try {
+    dxc::DxcDllSupport dllSupport;
+    IFT(dllSupport.Initialize());
+
+    CComPtr<IDxcLibrary> pLibrary;
+    CComPtr<IDxcCompiler> pCompiler;
+    CComPtr<IDxcOperationResult> pResult;
+    CComPtr<IDxcBlobEncoding> pSource;
+    CComPtr<IDxcBlobEncoding> pErrorBuffer;
+    CComPtr<IDxcBlob> pCompiledBlob;
+    CComPtr<IDxcIncludeHandler> pIncludeHandler;
+    HRESULT resultStatus;
+
+    std::vector<LPCWSTR> flags;
+    flags.push_back(L"-E");
+    flags.push_back(entry.c_str());
+    flags.push_back(L"-T");
+    flags.push_back(profile.c_str());
+    flags.push_back(L"-spirv");
+
+    IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
+    IFT(pLibrary->CreateBlobFromFile(srcFile.c_str(), nullptr, &pSource));
+    IFT(pLibrary->CreateIncludeHandler(&pIncludeHandler));
+    IFT(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+    IFT(pCompiler->Compile(pSource, srcFile.c_str(), entry.c_str(),
+                           profile.c_str(), flags.data(), flags.size(), nullptr,
+                           0, pIncludeHandler, &pResult));
+    IFT(pResult->GetStatus(&resultStatus));
+
+    if (SUCCEEDED(resultStatus)) {
+      CComPtr<IDxcBlobEncoding> pStdErr;
+      IFT(pResult->GetResult(&pCompiledBlob));
+      convertIDxcBlobToUint32(pCompiledBlob);
+      success = true;
+    } else {
+      IFT(pResult->GetErrorBuffer(&pErrorBuffer));
+      fprintf(stderr, "%s\n", (char *)pErrorBuffer->GetBufferPointer());
+      success = false;
+    }
+  } catch (...) {
+    // An exception has occured while running the compiler with SPIR-V
+    // Generation
+    success = false;
+  }
+
+  return success;
+}
+
+bool WholeFileTest::disassembleSpirvBinary(bool generateHeader) {
+  uint32_t options = SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
+  if (!generateHeader)
+    options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER;
+  return spirvTools.Disassemble(generatedBinary, &generatedSpirvAsm, options);
+}
+
+bool WholeFileTest::validateSpirvBinary() {
+  return spirvTools.Validate(generatedBinary);
+}
+
+void WholeFileTest::convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob) {
+  size_t num32BitWords = (blob->GetBufferSize() + 3) / 4;
+  std::string binaryStr((char *)blob->GetBufferPointer(),
+                        blob->GetBufferSize());
+  binaryStr.resize(num32BitWords * 4, 0);
+  generatedBinary.resize(num32BitWords, 0);
+  memcpy(generatedBinary.data(), binaryStr.data(), binaryStr.size());
+}
+
+bool WholeFileTest::compareExpectedSpirvAndGeneratedSpirv() {
+  return generatedSpirvAsm == expectedSpirvAsm;
+}
+
+std::string
+WholeFileTest::getAbsPathOfInputDataFile(const std::string &filename) {
+  std::string path = clang::spirv::testOptions::inputDataDir;
+
+#ifdef _WIN32
+  const char sep = '\\';
+  std::replace(path.begin(), path.end(), '/', '\\');
+#else
+  const char sep = '/';
+#endif
+
+  if (path[path.size() - 1] != sep) {
+    path = path + sep;
+  }
+  path += filename;
+  return path;
+}
+
+bool WholeFileTest::runWholeFileTest(std::string filename, bool generateHeader,
+                                     bool runSpirvValidation) {
+  inputFilePath = getAbsPathOfInputDataFile(filename);
+
+  bool success = true;
+
+  // Parse the input file.
+  success = success && parseInputFile();
+
+  // Feed the HLSL source into the Compiler.
+  success = success && runCompilerWithSpirvGeneration();
+
+  // Disassemble the generated SPIR-V binary.
+  success = success && disassembleSpirvBinary(generateHeader);
+
+  // Run SPIR-V validation if requested.
+  if (runSpirvValidation) {
+    success = success && validateSpirvBinary();
+  }
+
+  // Compare the expected and the generted SPIR-V code.
+  success = success && compareExpectedSpirvAndGeneratedSpirv();
+
+  return success;
+}

+ 102 - 0
tools/clang/unittests/SPIRV/WholeFileCheck.h

@@ -0,0 +1,102 @@
+//===- unittests/SPIRV/WholeFileCheck.h ---- WholeFileCheck Test Fixture --===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <fstream>
+
+#include "dxc/Support/Global.h"
+#include "dxc/Support/WinIncludes.h"
+#include "dxc/Support/dxcapi.use.h"
+#include "spirv-tools/libspirv.hpp"
+#include "gtest/gtest.h"
+
+#include "SpirvTestOptions.h"
+
+namespace {
+const char hlslStartLabel[] = "// Run:";
+const char spirvStartLabel[] = "// CHECK-WHOLE-SPIR-V:";
+}
+
+/// \brief The purpose of the this test class is to take in an input file with
+/// the following format:
+///
+///    // Comments...
+///    // More comments...
+///    // Run: %dxc -T ps_6_0 -E main
+///    ...
+///    <HLSL code goes here>
+///    ...
+///    // CHECK-WHOLE-SPIR-V:
+///    // ...
+///    // <SPIR-V code goes here>
+///    // ...
+///
+/// This file is fully read in as the HLSL source (therefore any non-HLSL must
+/// be commented out). It is fed to the DXC compiler with the SPIR-V Generation
+/// option. The resulting SPIR-V binary is then fed to the SPIR-V disassembler
+/// (via SPIR-V Tools) to get a SPIR-V assembly text. The resulting SPIR-V
+/// assembly text is compared to the second part of the input file (after the
+/// <CHECK-WHOLE-SPIR-V:> directive). If these match, the test is marked as a
+/// PASS, and marked as a FAILED otherwise.
+class WholeFileTest : public ::testing::Test {
+public:
+  WholeFileTest();
+
+  /// \brief Runs a WHOLE-FILE-TEST! (See class description for more info)
+  /// Returns true if the test passes; false otherwise.
+  /// Since SPIR-V headers may change, a test is more robust if the
+  /// disassembler does not include the header.
+  /// It is also important that all generated SPIR-V code is valid. Users of
+  /// WholeFileTest may choose not to run the SPIR-V Validator (for cases where
+  /// a certain feature has not been added to the Validator yet).
+  bool runWholeFileTest(std::string path, bool generateHeader = false,
+                        bool runSpirvValidation = true);
+
+private:
+  /// \brief Reads in the given input file.
+  /// Stores the SPIR-V portion of the file into the <expectedSpirvAsm>
+  /// member variable. All "//" are also removed from the SPIR-V assembly.
+  /// Returns true on success, and false on failure.
+  bool parseInputFile();
+
+  /// \brief Passes the HLSL input to the DXC compiler with SPIR-V CodeGen.
+  /// Writes the SPIR-V Binary to the output file.
+  /// Returns true on success, and false on failure.
+  bool runCompilerWithSpirvGeneration();
+
+  /// \brief Passes the SPIR-V Binary to the disassembler.
+  bool disassembleSpirvBinary(bool generatedHeader = false);
+
+  /// \brief Runs the SPIR-V tools validation on the SPIR-V Binary.
+  /// Returns true if validation is successful; false otherwise.
+  bool validateSpirvBinary();
+
+  /// \brief Compares the expected and the generated SPIR-V code.
+  /// Returns true if they match, and false otherwise.
+  bool compareExpectedSpirvAndGeneratedSpirv();
+
+  /// \brief Parses the Target Profile and Entry Point from the Run command
+  bool processRunCommandArgs(const std::string &runCommandLine);
+
+  /// \brief Converts an IDxcBlob that is the output of "%DXC -spirv" into a
+  /// vector of 32-bit unsigned integers that can be passed into the
+  /// disassembler. Stores the results in <generatedBinary>.
+  void convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob);
+
+  /// \brief Returns the absolute path to the input file of the test.
+  std::string getAbsPathOfInputDataFile(const std::string &filename);
+
+  std::string targetProfile;             ///< Target profile (argument of -T)
+  std::string entryPoint;                ///< Entry point name (argument of -E)
+  std::string inputFilePath;             ///< Path to the input test file
+  std::vector<uint32_t> generatedBinary; ///< The generated SPIR-V Binary
+  std::string expectedSpirvAsm;          ///< Expected SPIR-V parsed from input
+  std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
+  spvtools::SpirvTools spirvTools;       ///< SPIR-V Tools used by the test
+};

+ 1 - 1
utils/hct/hcttest.cmd

@@ -182,7 +182,7 @@ if "%TEST_SPIRV%"=="1" (
     exit /b 1
   )
   echo Running SPIRV tests ...
-  %BIN_DIR%\clang-spirv-tests.exe
+  %BIN_DIR%\clang-spirv-tests.exe --spirv-test-root %HLSL_SRC_DIR%\tools\clang\test\CodeGenSPIRV
   if errorlevel 1 (
     echo Failure occured in SPIRV unit tests
     exit /b 1