Bladeren bron

[spirv] Refactor and add more unit tests (#1944)

* [spirv] Add unit tests for SpirvConstant.

* [spirv] Add unit tests for SpirvType classes.
Ehsan 6 jaren geleden
bovenliggende
commit
e29556df26

+ 6 - 6
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -395,8 +395,8 @@ SpirvConstantBoolean::SpirvConstantBoolean(const BoolType *type, bool val,
       value(val) {}
 
 bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const {
-  return resultType == that.resultType && value == that.value &&
-         opcode == that.opcode;
+  return resultType == that.resultType && astResultType == that.astResultType &&
+         value == that.value && opcode == that.opcode;
 }
 
 SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val,
@@ -409,8 +409,8 @@ SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val,
 }
 
 bool SpirvConstantInteger::operator==(const SpirvConstantInteger &that) const {
-  return resultType == that.resultType && value == that.value &&
-         opcode == that.opcode;
+  return resultType == that.resultType && astResultType == that.astResultType &&
+         value == that.value && opcode == that.opcode;
 }
 
 SpirvConstantFloat::SpirvConstantFloat(QualType type, llvm::APFloat val,
@@ -423,8 +423,8 @@ SpirvConstantFloat::SpirvConstantFloat(QualType type, llvm::APFloat val,
 }
 
 bool SpirvConstantFloat::operator==(const SpirvConstantFloat &that) const {
-  return resultType == that.resultType && value.bitwiseIsEqual(that.value) &&
-         opcode == that.opcode;
+  return resultType == that.resultType && astResultType == that.astResultType &&
+         value.bitwiseIsEqual(that.value) && opcode == that.opcode;
 }
 
 SpirvConstantComposite::SpirvConstantComposite(

+ 1 - 0
tools/clang/unittests/SPIRV/CMakeLists.txt

@@ -13,6 +13,7 @@ add_clang_unittest(clang-spirv-tests
   SpirvBasicBlockTest.cpp
   SpirvContextTest.cpp
   SpirvTestOptions.cpp
+  SpirvTypeTest.cpp
   SpirvConstantTest.cpp
   StringTest.cpp
   TestMain.cpp

+ 110 - 100
tools/clang/unittests/SPIRV/SpirvConstantTest.cpp

@@ -7,98 +7,79 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-#include "clang/SPIRV/SpirvContext.h"
+#include "SpirvTestBase.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 
-namespace {
 using namespace clang::spirv;
 
-TEST(SpirvConstant, BoolFalse) {
+namespace {
+
+class SpirvConstantTest : public SpirvTestBase {};
+
+TEST_F(SpirvConstantTest, BoolFalse) {
   SpirvContext ctx;
   const bool val = false;
   SpirvConstantBoolean constant(ctx.getBoolType(), val);
   EXPECT_EQ(val, constant.getValue());
 }
 
-TEST(SpirvConstant, BoolTrue) {
+TEST_F(SpirvConstantTest, BoolTrue) {
   SpirvContext ctx;
   const bool val = true;
   SpirvConstantBoolean constant(ctx.getBoolType(), val);
   EXPECT_EQ(val, constant.getValue());
 }
 
-/*
-TEST(SpirvConstant, Uint16) {
-  SpirvContext ctx;
-  const uint16_t u16 = 12;
-  SpirvConstantInteger constant(ctx.getUIntType(16), u16);
-  EXPECT_EQ(u16, constant.getUnsignedInt16Value());
-}
-
-TEST(SpirvConstant, Int16) {
-  SpirvContext ctx;
-  const int16_t i16 = -12;
-  SpirvConstantInteger constant(ctx.getSIntType(16), i16);
-  EXPECT_EQ(i16, constant.getSignedInt16Value());
-}
-
-TEST(SpirvConstant, Uint32) {
-  SpirvContext ctx;
-  const uint32_t u32 = 65536;
-  SpirvConstantInteger constant(ctx.getUIntType(32), u32);
-  EXPECT_EQ(u32, constant.getUnsignedInt32Value());
+TEST_F(SpirvConstantTest, Uint16) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto u16 = llvm::APInt(16, 12u);
+  SpirvConstantInteger constant(astContext.UnsignedShortTy, u16);
+  EXPECT_EQ(u16, constant.getValue());
 }
 
-TEST(SpirvConstant, Int32) {
-  SpirvContext ctx;
-  const int32_t i32 = -65536;
-  SpirvConstantInteger constant(ctx.getSIntType(32), i32);
-  EXPECT_EQ(i32, constant.getSignedInt32Value());
+TEST_F(SpirvConstantTest, Int16) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i16 = llvm::APInt(16, -12, /*isSigned*/ true);
+  SpirvConstantInteger constant(astContext.ShortTy, i16);
+  EXPECT_EQ(i16, constant.getValue());
 }
 
-TEST(SpirvConstant, Uint64) {
-  SpirvContext ctx;
-  const uint64_t u64 = 4294967296;
-  SpirvConstantInteger constant(ctx.getUIntType(64), u64);
-  EXPECT_EQ(u64, constant.getUnsignedInt64Value());
+TEST_F(SpirvConstantTest, Uint32) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto u32 = llvm::APInt(32, 65536);
+  SpirvConstantInteger constant(astContext.UnsignedIntTy, u32);
+  EXPECT_EQ(u32, constant.getValue());
 }
 
-TEST(SpirvConstant, Int64) {
-  SpirvContext ctx;
-  const int64_t i64 = -4294967296;
-  SpirvConstantInteger constant(ctx.getSIntType(64), i64);
-  EXPECT_EQ(i64, constant.getSignedInt64Value());
+TEST_F(SpirvConstantTest, Int32) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i32 = llvm::APInt(32, -65536, /*isSigned*/ true);
+  SpirvConstantInteger constant(astContext.IntTy, i32);
+  EXPECT_EQ(i32, constant.getValue());
 }
-*/
 
-/*
-TEST(SpirvConstant, Float16) {
-  SpirvContext ctx;
-  const uint16_t f16 = 12;
-  SpirvConstantFloat constant(ctx.getFloatType(16), f16);
-  EXPECT_EQ(f16, constant.getValue16());
+TEST_F(SpirvConstantTest, Uint64) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto u64 = llvm::APInt(64, 4294967296);
+  SpirvConstantInteger constant(astContext.UnsignedLongLongTy, u64);
+  EXPECT_EQ(u64, constant.getValue());
 }
 
-TEST(SpirvConstant, Float32) {
-  SpirvContext ctx;
-  const float f32 = 1.5;
-  SpirvConstantFloat constant(ctx.getFloatType(32), f32);
-  EXPECT_EQ(f32, constant.getValue32());
+TEST_F(SpirvConstantTest, Int64) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i64 = llvm::APInt(64, -4294967296, /*isSigned*/ true);
+  SpirvConstantInteger constant(astContext.LongLongTy, i64);
+  EXPECT_EQ(i64, constant.getValue());
 }
 
-TEST(SpirvConstant, Float64) {
-  SpirvContext ctx;
-  const double f64 = 3.14;
-  SpirvConstantFloat constant(ctx.getFloatType(64), f64);
-  EXPECT_EQ(f64, constant.getValue64());
+TEST_F(SpirvConstantTest, Float32) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto f32 = llvm::APFloat(1.5f);
+  SpirvConstantFloat constant(astContext.FloatTy, f32);
+  EXPECT_EQ(1.5f, constant.getValue().convertToFloat());
 }
 
-*/
-
-TEST(SpirvConstant, CheckOperatorEqualOnBool) {
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnBool) {
   SpirvContext ctx;
   const bool val = true;
   SpirvConstantBoolean constant1(ctx.getBoolType(), val);
@@ -106,80 +87,109 @@ TEST(SpirvConstant, CheckOperatorEqualOnBool) {
   EXPECT_TRUE(constant1 == constant2);
 }
 
-/*
-TEST(SpirvConstant, CheckOperatorEqualOnInt) {
-  SpirvContext ctx;
-  const int32_t i32 = -65536;
-  SpirvConstantInteger constant1(ctx.getSIntType(32), i32);
-  SpirvConstantInteger constant2(ctx.getSIntType(32), i32);
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnInt) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i32 = llvm::APInt(32, -65536, /*isSigned*/ true);
+  SpirvConstantInteger constant1(astContext.IntTy, i32);
+  SpirvConstantInteger constant2(astContext.IntTy, i32);
   EXPECT_TRUE(constant1 == constant2);
 }
 
-TEST(SpirvConstant, CheckOperatorEqualOnFloat) {
-  SpirvContext ctx;
-  const double f64 = 3.14;
-  SpirvConstantFloat constant1(ctx.getFloatType(64), f64);
-  SpirvConstantFloat constant2(ctx.getFloatType(64), f64);
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnFloat) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto f32 = llvm::APFloat(1.5f);
+  SpirvConstantFloat constant1(astContext.FloatTy, f32);
+  SpirvConstantFloat constant2(astContext.FloatTy, f32);
   EXPECT_TRUE(constant1 == constant2);
 }
-*/
 
-TEST(SpirvConstant, CheckOperatorEqualOnNull) {
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnNull) {
   SpirvContext ctx;
   SpirvConstantNull constant1(ctx.getSIntType(32));
   SpirvConstantNull constant2(ctx.getSIntType(32));
   EXPECT_TRUE(constant1 == constant2);
 }
 
-TEST(SpirvConstant, CheckOperatorEqualOnBool2) {
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnBool2) {
   SpirvContext ctx;
   SpirvConstantBoolean constant1(ctx.getBoolType(), true);
   SpirvConstantBoolean constant2(ctx.getBoolType(), false);
   EXPECT_FALSE(constant1 == constant2);
 }
 
-/*
-TEST(SpirvConstant, CheckOperatorEqualOnInt2) {
-  SpirvContext ctx;
-  SpirvConstantInteger constant1(ctx.getSIntType(32), 5);
-  SpirvConstantInteger constant2(ctx.getSIntType(32), 7);
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnInt2) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i1 = llvm::APInt(32, 5, /*isSigned*/ true);
+  const auto i2 = llvm::APInt(32, 7, /*isSigned*/ true);
+  SpirvConstantInteger constant1(astContext.IntTy, i1);
+  SpirvConstantInteger constant2(astContext.IntTy, i2);
   EXPECT_FALSE(constant1 == constant2);
 }
 
-TEST(SpirvConstant, CheckOperatorEqualOnFloat2) {
-  SpirvContext ctx;
-  SpirvConstantFloat constant1(ctx.getFloatType(64), 3.14);
-  SpirvConstantFloat constant2(ctx.getFloatType(64), 3.15);
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnFloat2) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto f1 = llvm::APFloat(1.5f);
+  const auto f2 = llvm::APFloat(1.6f);
+  SpirvConstantFloat constant1(astContext.FloatTy, f1);
+  SpirvConstantFloat constant2(astContext.FloatTy, f2);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnInt3) {
+  // Different signedness should mean different constants.
+  clang::ASTContext &astContext = getAstContext();
+  const auto i32 = llvm::APInt(32, 7, /*isSigned*/ true);
+  SpirvConstantInteger constant1(astContext.UnsignedIntTy, i32);
+  SpirvConstantInteger constant2(astContext.IntTy, i32);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnFloat3) {
+  // Different bitwidth should mean different constants.
+  clang::ASTContext &astContext = getAstContext();
+  const auto f32 = llvm::APFloat(1.5f);
+  SpirvConstantFloat constant1(astContext.DoubleTy, f32);
+  SpirvConstantFloat constant2(astContext.FloatTy, f32);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnInt4) {
+  // Different bitwidth should mean different constants.
+  clang::ASTContext &astContext = getAstContext();
+  const auto i32 = llvm::APInt(32, 7, /*isSigned*/ true);
+  SpirvConstantInteger constant1(astContext.ShortTy, i32);
+  SpirvConstantInteger constant2(astContext.IntTy, i32);
   EXPECT_FALSE(constant1 == constant2);
 }
-*/
 
-TEST(SpirvConstant, CheckOperatorEqualOnNull2) {
+TEST_F(SpirvConstantTest, CheckOperatorEqualOnNull2) {
   SpirvContext ctx;
   SpirvConstantNull constant1(ctx.getSIntType(32));
   SpirvConstantNull constant2(ctx.getUIntType(32));
   EXPECT_FALSE(constant1 == constant2);
 }
 
-TEST(SpirvConstant, BoolConstNotEqualSpecConst) {
+TEST_F(SpirvConstantTest, BoolConstNotEqualSpecConst) {
   SpirvContext ctx;
   SpirvConstantBoolean constant1(ctx.getBoolType(), true, /*SpecConst*/ true);
   SpirvConstantBoolean constant2(ctx.getBoolType(), false, /*SpecConst*/ false);
   EXPECT_FALSE(constant1 == constant2);
 }
 
-// TEST(SpirvConstant, IntConstNotEqualSpecConst) {
-//  SpirvContext ctx;
-//  SpirvConstantInteger constant1(ctx.getSIntType(32), 5, /*SpecConst*/ true);
-//  SpirvConstantInteger constant2(ctx.getSIntType(32), 7, /*SpecConst*/ false);
-//  EXPECT_FALSE(constant1 == constant2);
-//}
+TEST_F(SpirvConstantTest, IntConstNotEqualSpecConst) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto i32 = llvm::APInt(32, 7, /*isSigned*/ true);
+  SpirvConstantInteger constant1(astContext.IntTy, i32, /*SpecConst*/ false);
+  SpirvConstantInteger constant2(astContext.IntTy, i32, /*SpecConst*/ true);
+  EXPECT_FALSE(constant1 == constant2);
+}
 
-// TEST(SpirvConstant, FloatConstNotEqualSpecConst) {
-//   SpirvContext ctx;
-//   SpirvConstantFloat constant1(ctx.getFloatType(64), 3.14, /*SpecConst*/
-//   true); SpirvConstantFloat constant2(ctx.getFloatType(64), 3.15,
-//   /*SpecConst*/ false); EXPECT_FALSE(constant1 == constant2);
-// }
+TEST_F(SpirvConstantTest, FloatConstNotEqualSpecConst) {
+  clang::ASTContext &astContext = getAstContext();
+  const auto f32 = llvm::APFloat(1.5f);
+  SpirvConstantFloat constant1(astContext.FloatTy, f32, /*SpecConst*/ false);
+  SpirvConstantFloat constant2(astContext.FloatTy, f32, /*SpecConst*/ true);
+  EXPECT_FALSE(constant1 == constant2);
+}
 
 } // anonymous namespace

+ 2 - 53
tools/clang/unittests/SPIRV/SpirvContextTest.cpp

@@ -7,64 +7,13 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/SPIRV/SpirvContext.h"
-#include "clang/AST/ASTContext.h"
-#include "clang/AST/Type.h"
-#include "clang/Basic/TargetInfo.h"
-#include "clang/Frontend/CompilerInstance.h"
-#include "clang/Frontend/TextDiagnosticPrinter.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
+#include "SpirvTestBase.h"
 
 using namespace clang::spirv;
 
 namespace {
 
-class SpirvContextTest : public ::testing::Test {
-public:
-  SpirvContextTest() : spvContext(), compilerInstance(), initialized(false) {}
-
-  SpirvContext &getSpirvContext() { return spvContext; }
-
-  clang::ASTContext &getAstContext() {
-    if (!initialized)
-      initialize();
-    return compilerInstance.getASTContext();
-  }
-
-private:
-  // We don't initialize the compiler instance unless it is asked for in order
-  // to make the tests run faster.
-  void initialize() {
-    std::string warnings;
-    llvm::raw_string_ostream w(warnings);
-    std::unique_ptr<clang::TextDiagnosticPrinter> diagPrinter =
-        llvm::make_unique<clang::TextDiagnosticPrinter>(
-            w, &compilerInstance.getDiagnosticOpts());
-
-    std::shared_ptr<clang::TargetOptions> targetOptions(
-        new clang::TargetOptions);
-    targetOptions->Triple = "dxil-ms-dx";
-    compilerInstance.createDiagnostics(diagPrinter.get(), false);
-    compilerInstance.createFileManager();
-    compilerInstance.createSourceManager(compilerInstance.getFileManager());
-    compilerInstance.setTarget(clang::TargetInfo::CreateTargetInfo(
-        compilerInstance.getDiagnostics(), targetOptions));
-
-    clang::HeaderSearchOptions &HSOpts = compilerInstance.getHeaderSearchOpts();
-    HSOpts.UseBuiltinIncludes = 0;
-
-    compilerInstance.createPreprocessor(
-        clang::TranslationUnitKind::TU_Complete);
-    compilerInstance.createASTContext();
-    initialized = true;
-  }
-
-private:
-  SpirvContext spvContext;
-  clang::CompilerInstance compilerInstance;
-  bool initialized;
-};
+class SpirvContextTest : public SpirvTestBase {};
 
 TEST_F(SpirvContextTest, VoidTypeUnique) {
   SpirvContext &spvContext = getSpirvContext();

+ 66 - 0
tools/clang/unittests/SPIRV/SpirvTestBase.h

@@ -0,0 +1,66 @@
+//===- unittests/SPIRV/SpirvTestBase.h ---- Base class for SPIR-V Tests ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/SpirvContext.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/TextDiagnosticPrinter.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace clang::spirv;
+
+class SpirvTestBase : public ::testing::Test {
+public:
+  SpirvTestBase() : spvContext(), compilerInstance(), initialized(false) {}
+
+  SpirvContext &getSpirvContext() { return spvContext; }
+
+  clang::ASTContext &getAstContext() {
+    if (!initialized)
+      initialize();
+    return compilerInstance.getASTContext();
+  }
+
+private:
+  // We don't initialize the compiler instance unless it is asked for in order
+  // to make the tests run faster.
+  void initialize() {
+    std::string warnings;
+    llvm::raw_string_ostream w(warnings);
+    std::unique_ptr<clang::TextDiagnosticPrinter> diagPrinter =
+        llvm::make_unique<clang::TextDiagnosticPrinter>(
+            w, &compilerInstance.getDiagnosticOpts());
+
+    std::shared_ptr<clang::TargetOptions> targetOptions(
+        new clang::TargetOptions);
+    targetOptions->Triple = "dxil-ms-dx";
+    compilerInstance.createDiagnostics(diagPrinter.get(), false);
+    compilerInstance.createFileManager();
+    compilerInstance.createSourceManager(compilerInstance.getFileManager());
+    compilerInstance.setTarget(clang::TargetInfo::CreateTargetInfo(
+        compilerInstance.getDiagnostics(), targetOptions));
+
+    clang::HeaderSearchOptions &HSOpts = compilerInstance.getHeaderSearchOpts();
+    HSOpts.UseBuiltinIncludes = 0;
+
+    compilerInstance.createPreprocessor(
+        clang::TranslationUnitKind::TU_Complete);
+    compilerInstance.createASTContext();
+    initialized = true;
+  }
+
+private:
+  SpirvContext spvContext;
+  clang::CompilerInstance compilerInstance;
+  bool initialized;
+};
+

+ 170 - 0
tools/clang/unittests/SPIRV/SpirvTypeTest.cpp

@@ -0,0 +1,170 @@
+//===- unittests/SPIRV/SpirvTypeTest.cpp - Tests For SPIR-V Type classes --===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/SpirvType.h"
+#include "SpirvTestBase.h"
+
+using namespace clang::spirv;
+
+namespace {
+
+class SpirvTypeTest : public SpirvTestBase {};
+
+TEST_F(SpirvTypeTest, VoidType) {
+  VoidType t;
+  EXPECT_TRUE(llvm::isa<VoidType>(t));
+}
+
+TEST_F(SpirvTypeTest, BoolType) {
+  BoolType t;
+  EXPECT_TRUE(llvm::isa<BoolType>(t));
+}
+
+TEST_F(SpirvTypeTest, IntType) {
+  IntegerType sint16(16, true);
+  IntegerType uint32(32, false);
+  EXPECT_TRUE(llvm::isa<IntegerType>(sint16));
+  EXPECT_TRUE(llvm::isa<IntegerType>(uint32));
+  EXPECT_TRUE(llvm::isa<NumericalType>(sint16));
+  EXPECT_TRUE(llvm::isa<NumericalType>(uint32));
+  EXPECT_EQ(16, sint16.getBitwidth());
+  EXPECT_EQ(32, uint32.getBitwidth());
+  EXPECT_EQ(true, sint16.isSignedInt());
+  EXPECT_EQ(false, uint32.isSignedInt());
+}
+
+TEST_F(SpirvTypeTest, FloatType) {
+  FloatType f16(16);
+  EXPECT_TRUE(llvm::isa<FloatType>(f16));
+  EXPECT_TRUE(llvm::isa<NumericalType>(f16));
+  EXPECT_EQ(16, f16.getBitwidth());
+}
+
+TEST_F(SpirvTypeTest, VectorType) {
+  FloatType f16(16);
+  VectorType float3(&f16, 3);
+  EXPECT_TRUE(llvm::isa<VectorType>(float3));
+  EXPECT_EQ(&f16, float3.getElementType());
+  EXPECT_EQ(3, float3.getElementCount());
+}
+
+TEST_F(SpirvTypeTest, MatrixType) {
+  FloatType f16(16);
+  VectorType float3(&f16, 3);
+  MatrixType mat2x3(&float3, 2);
+
+  EXPECT_TRUE(llvm::isa<MatrixType>(mat2x3));
+  EXPECT_EQ(&f16, float3.getElementType());
+  EXPECT_EQ(2, mat2x3.getVecCount());
+  EXPECT_EQ(2, mat2x3.numCols());
+  EXPECT_EQ(3, mat2x3.numRows());
+}
+
+TEST_F(SpirvTypeTest, ImageType) {
+  FloatType f16(16);
+  ImageType img(&f16, spv::Dim::Dim2D, ImageType::WithDepth::Yes,
+                /*isArrayed*/ false, /*isMultiSampled*/ true,
+                ImageType::WithSampler::No, spv::ImageFormat::R16f);
+
+  EXPECT_TRUE(llvm::isa<ImageType>(img));
+  EXPECT_EQ(&f16, img.getSampledType());
+  EXPECT_EQ(spv::Dim::Dim2D, img.getDimension());
+  EXPECT_EQ(ImageType::WithDepth::Yes, img.getDepth());
+  EXPECT_EQ(false, img.isArrayedImage());
+  EXPECT_EQ(true, img.isMSImage());
+  EXPECT_EQ(ImageType::WithSampler::No, img.withSampler());
+  EXPECT_EQ(spv::ImageFormat::R16f, img.getImageFormat());
+  EXPECT_EQ(img.getName(), "type.2d.image");
+  EXPECT_FALSE(SpirvType::isTexture(&img));
+  EXPECT_TRUE(SpirvType::isRWTexture(&img));
+}
+
+TEST_F(SpirvTypeTest, SamplerType) {
+  SamplerType t;
+  EXPECT_TRUE(llvm::isa<SamplerType>(t));
+  EXPECT_EQ(t.getName(), "type.sampler");
+}
+
+TEST_F(SpirvTypeTest, SampledImageType) {
+  FloatType f16(16);
+  ImageType img(&f16, spv::Dim::Dim2D, ImageType::WithDepth::Yes,
+                /*isArrayed*/ false, /*isMultiSampled*/ true,
+                ImageType::WithSampler::No, spv::ImageFormat::R16f);
+  SampledImageType s(&img);
+
+  EXPECT_TRUE(llvm::isa<SampledImageType>(s));
+  EXPECT_EQ(s.getName(), "type.sampled.image");
+  EXPECT_EQ(s.getImageType(), &img);
+}
+
+TEST_F(SpirvTypeTest, ArrayType) {
+  FloatType f16(16);
+  ArrayType arr5(&f16, 5, 2);
+  EXPECT_TRUE(llvm::isa<ArrayType>(arr5));
+  EXPECT_EQ(arr5.getElementType(), &f16);
+  EXPECT_EQ(arr5.getElementCount(), 5);
+  EXPECT_TRUE(arr5.getStride().hasValue());
+  EXPECT_EQ(arr5.getStride().getValue(), 2);
+}
+
+TEST_F(SpirvTypeTest, RuntimeArrayType) {
+  FloatType f16(16);
+  RuntimeArrayType ra(&f16, 2);
+  EXPECT_TRUE(llvm::isa<RuntimeArrayType>(ra));
+  EXPECT_EQ(ra.getElementType(), &f16);
+  EXPECT_TRUE(ra.getStride().hasValue());
+  EXPECT_EQ(ra.getStride().getValue(), 2);
+}
+
+TEST_F(SpirvTypeTest, StructType) {
+  IntegerType int32(32, true);
+  IntegerType uint32(32, false);
+
+  StructType::FieldInfo field0(&int32, "field1");
+  StructType::FieldInfo field1(&uint32, "field2", /*offset*/ 4,
+                               /*matrixStride*/ 16, /*isRowMajor*/ false);
+
+  StructType s({field0, field1}, "some_struct", /*isReadOnly*/ true,
+               StructInterfaceType::InternalStorage);
+
+  EXPECT_TRUE(llvm::isa<StructType>(s));
+  EXPECT_EQ(s.getName(), "some_struct");
+  EXPECT_EQ(s.getStructName(), "some_struct");
+
+  const auto &fields = s.getFields();
+  EXPECT_EQ(2, fields.size());
+  EXPECT_EQ(fields[0], field0);
+  EXPECT_EQ(fields[1], field1);
+  EXPECT_TRUE(s.isReadOnly());
+  EXPECT_EQ(s.getInterfaceType(), StructInterfaceType::InternalStorage);
+}
+
+TEST_F(SpirvTypeTest, SpirvPointerType) {
+  FloatType f16(16);
+  SpirvPointerType ptr(&f16, spv::StorageClass::UniformConstant);
+  EXPECT_TRUE(llvm::isa<SpirvPointerType>(ptr));
+  EXPECT_EQ(ptr.getStorageClass(), spv::StorageClass::UniformConstant);
+  EXPECT_EQ(ptr.getPointeeType(), &f16);
+}
+
+TEST_F(SpirvTypeTest, FunctionType) {
+  FloatType f16(16);
+  IntegerType uint32(32, false);
+  BoolType retType;
+  FunctionType fnType(&retType, {&f16, &uint32});
+  EXPECT_TRUE(llvm::isa<FunctionType>(fnType));
+  EXPECT_EQ(fnType.getReturnType(), &retType);
+  EXPECT_EQ(fnType.getParamTypes().size(), 2u);
+  EXPECT_EQ(fnType.getParamTypes()[0], &f16);
+  EXPECT_EQ(fnType.getParamTypes()[1], &uint32);
+}
+
+// TODO: Add tests for HybridTypes.
+
+} // anonymous namespace