Răsfoiți Sursa

[spirv] Support primitive types, their constants, and assignments (#458)

* [spirv] Convert primitive types

Also solves the local variable ordering problem.

* [spirv] Constants for primitive types: bool, int/uint, float

Also add support for constant variable initializers.

* [spirv] Add support for assignment for primitive types
Lei Zhang 8 ani în urmă
părinte
comite
5e4b7c46ba

+ 5 - 1
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -15,6 +15,7 @@
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/Structure.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace clang {
@@ -54,7 +55,8 @@ public:
   ///
   /// The corresponding pointer type of the given value type will be constructed
   /// for the variable itself.
-  uint32_t addFnVariable(uint32_t valueType, llvm::StringRef name = "");
+  uint32_t addFnVariable(uint32_t valueType, llvm::StringRef name = "",
+                         llvm::Optional<uint32_t> init = llvm::None);
 
   /// \brief Ends building of the current function. Returns true of success,
   /// false on failure. All basic blocks constructed from the beginning or
@@ -129,6 +131,7 @@ public:
   // === Type ===
 
   uint32_t getVoidType();
+  uint32_t getBoolType();
   uint32_t getInt32Type();
   uint32_t getUint32Type();
   uint32_t getFloat32Type();
@@ -139,6 +142,7 @@ public:
                            const std::vector<uint32_t> &paramTypes);
 
   // === Constant ===
+  uint32_t getConstantBool(bool value);
   uint32_t getConstantInt32(int32_t value);
   uint32_t getConstantUint32(uint32_t value);
   uint32_t getConstantFloat32(float value);

+ 4 - 7
tools/clang/include/clang/SPIRV/Structure.h

@@ -112,7 +112,8 @@ public:
   inline void addParameter(uint32_t paramResultType, uint32_t paramResultId);
 
   /// \brief Adds a local variable to this function.
-  inline void addVariable(uint32_t varResultType, uint32_t varResultId);
+  void addVariable(uint32_t varResultType, uint32_t varResultId,
+                   llvm::Optional<uint32_t> init);
 
   /// \brief Adds a basic block to this function.
   inline void addBasicBlock(std::unique_ptr<BasicBlock> block);
@@ -125,8 +126,8 @@ private:
 
   /// Parameter <result-type> and <result-id> pairs.
   std::vector<std::pair<uint32_t, uint32_t>> parameters;
-  /// Local variable <result-type> and <result-id> pairs.
-  std::vector<std::pair<uint32_t, uint32_t>> variables;
+  /// Local variables.
+  std::vector<Instruction> variables;
   std::vector<std::unique_ptr<BasicBlock>> blocks;
 };
 
@@ -312,10 +313,6 @@ void Function::addParameter(uint32_t rType, uint32_t rId) {
   parameters.emplace_back(rType, rId);
 }
 
-void Function::addVariable(uint32_t varType, uint32_t varId) {
-  variables.emplace_back(varType, varId);
-}
-
 void Function::addBasicBlock(std::unique_ptr<BasicBlock> block) {
   blocks.push_back(std::move(block));
 }

+ 176 - 50
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -29,7 +29,8 @@ namespace spirv {
 class SPIRVEmitter : public ASTConsumer {
 public:
   explicit SPIRVEmitter(CompilerInstance &ci)
-      : theCompilerInstance(ci), diags(ci.getDiagnostics()),
+      : theCompilerInstance(ci), astContext(ci.getASTContext()),
+        diags(ci.getDiagnostics()),
         entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
         shaderStage(getSpirvShaderStageFromHlslProfile(
             ci.getCodeGenOpts().HLSLProfile.c_str())),
@@ -203,7 +204,25 @@ public:
       const uint32_t ptrType = theBuilder.getPointerType(
           typeTranslator.translateType(decl->getType()),
           spv::StorageClass::Function);
-      const uint32_t varId = theBuilder.addFnVariable(ptrType, decl->getName());
+
+      // Handle initializer. SPIR-V requires that "initializer must be an <id>
+      // from a constant instruction or a global (module scope) OpVariable
+      // instruction."
+      llvm::Optional<uint32_t> init = llvm::None;
+      if (decl->hasInit()) {
+        const Expr *declInit = decl->getInit();
+        if (declInit->isConstantInitializer(astContext, /*ForRef*/ false)) {
+          APValue evalResult;
+          llvm::SmallVector<PartialDiagnosticAt, 0> notes;
+          declInit->EvaluateAsInitializer(evalResult, astContext, decl, notes);
+          init = llvm::Optional<uint32_t>(
+              translateAPValue(evalResult, decl->getType()));
+        }
+      }
+
+      const uint32_t varId =
+          theBuilder.addFnVariable(ptrType, decl->getName(), init);
+
       declIdMapper.registerDeclResultId(decl, varId);
     } else {
       // TODO: handle global variables
@@ -218,32 +237,18 @@ public:
       for (auto *decl : declStmt->decls()) {
         doDecl(decl);
       }
-    } else if (auto *binOp = dyn_cast<BinaryOperator>(stmt)) {
-      const auto opcode = binOp->getOpcode();
-      const uint32_t lhs = doExpr(binOp->getLHS());
-      const uint32_t rhs = doExpr(binOp->getRHS());
-
-      doBinaryOperator(opcode, lhs, rhs);
+    } else if (auto *expr = dyn_cast<Expr>(stmt)) {
+      // All cases for expressions used as statements
+      doExpr(expr);
     } else {
       emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
     }
-    // TODO: handle other statements
-  }
-
-  void doBinaryOperator(BinaryOperatorKind opcode, uint32_t lhs, uint32_t rhs) {
-    if (opcode == BO_Assign) {
-      theBuilder.createStore(lhs, rhs);
-    } else {
-      emitError("BinaryOperator '%0' is not supported yet.") << opcode;
-    }
   }
 
   void doReturnStmt(ReturnStmt *stmt) {
-    // First get the <result-id> of the value we want to return.
-    const uint32_t retValue = doExpr(stmt->getRetValue());
-
+    // For normal functions, just return in the normal way.
     if (curFunction->getName() != entryFunctionName) {
-      theBuilder.createReturnValue(retValue);
+      theBuilder.createReturnValue(doExpr(stmt->getRetValue()));
       return;
     }
 
@@ -254,13 +259,14 @@ public:
     // We need to walk through the return type, and for each subtype attached
     // with semantics, write out the value to the corresponding stage variable
     // mapped to the semantic.
+
     const uint32_t stageVarId =
         declIdMapper.getRemappedDeclResultId(curFunction);
 
     if (stageVarId) {
       // The return value is mapped to a single stage variable. We just need
       // to store the value into the stage variable instead.
-      theBuilder.createStore(stageVarId, retValue);
+      theBuilder.createStore(stageVarId, doExpr(stmt->getRetValue()));
       theBuilder.createReturn();
       return;
     }
@@ -268,7 +274,16 @@ public:
     QualType retType = stmt->getRetValue()->getType();
 
     if (const auto *structType = retType->getAsStructureType()) {
-      // We are trying to return a value of struct type. Go through all fields.
+      // We are trying to return a value of struct type.
+
+      // First get the return value. Clang AST will use an LValueToRValue cast
+      // for returning a struct variable. We need to ignore the cast to avoid
+      // creating OpLoad instruction since we need the pointer to the variable
+      // for creating access chain later.
+      const uint32_t retValue =
+          doExpr(stmt->getRetValue()->IgnoreParenLValueCasts());
+
+      // Then go through all fields.
       uint32_t fieldIndex = 0;
       for (const auto *field : structType->getDecl()->fields()) {
         // Load the value from the current field.
@@ -292,24 +307,12 @@ public:
     }
   }
 
-  uint32_t doExpr(Expr *expr) {
+  uint32_t doExpr(const Expr *expr) {
     if (auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
       // Returns the <result-id> of the referenced Decl.
       const NamedDecl *referredDecl = delRefExpr->getFoundDecl();
       assert(referredDecl && "found non-NamedDecl referenced");
       return declIdMapper.getDeclResultId(referredDecl);
-    } else if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
-      const uint32_t fromValue = doExpr(castExpr->getSubExpr());
-      // Using lvalue as rvalue will result in a ImplicitCast in Clang AST.
-      // This place gives us a place to inject the code for handling stage
-      // variables. Since using the <result-id> of a stage variable as
-      // rvalue means OpLoad it first. For normal values, it is not required.
-      if (declIdMapper.isStageVariable(fromValue)) {
-        const uint32_t resultType =
-            typeTranslator.translateType(castExpr->getType());
-        return theBuilder.createLoad(resultType, fromValue);
-      }
-      return fromValue;
     } else if (auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
       const uint32_t base = doExpr(memberExpr->getBase());
       auto *memberDecl = memberExpr->getMemberDecl();
@@ -325,41 +328,163 @@ public:
         emitError("Decl '%0' in MemberExpr is not supported yet.")
             << memberDecl->getDeclKindName();
       }
+    } else if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
+      return doImplicitCastExpr(castExpr);
     } else if (auto *cxxFunctionalCastExpr =
                    dyn_cast<CXXFunctionalCastExpr>(expr)) {
       // Explicit cast is a NO-OP (e.g. vector<float, 4> -> float4)
       if (cxxFunctionalCastExpr->getCastKind() == CK_NoOp) {
         return doExpr(cxxFunctionalCastExpr->getSubExpr());
-      } else {
-        emitError("Found unhandled CXXFunctionalCastExpr cast type: %0")
-            << cxxFunctionalCastExpr->getCastKindName();
       }
+      emitError("Found unhandled CXXFunctionalCastExpr cast type: %0")
+          << cxxFunctionalCastExpr->getCastKindName();
+      return 0;
     } else if (auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
-      const bool isConstantInitializer = expr->isConstantInitializer(
-          theCompilerInstance.getASTContext(), false);
       const uint32_t resultType =
           typeTranslator.translateType(initListExpr->getType());
+
       std::vector<uint32_t> constituents;
       for (size_t i = 0; i < initListExpr->getNumInits(); ++i) {
         constituents.push_back(doExpr(initListExpr->getInit(i)));
       }
-      if (isConstantInitializer) {
+
+      if (expr->isConstantInitializer(astContext, false)) {
         return theBuilder.getConstantComposite(resultType, constituents);
-      } else {
-        // TODO: use OpCompositeConstruct if it is not a constant initializer
-        // list.
-        emitError("Non-const initializer lists are currently not supported.");
       }
-    } else if (auto *floatingLiteral = dyn_cast<FloatingLiteral>(expr)) {
-      // TODO: use floatingLiteral->getType() to also handle float64 cases.
-      const float value = floatingLiteral->getValue().convertToFloat();
-      return theBuilder.getConstantFloat32(value);
+      // TODO: use OpCompositeConstruct for non-constant initializer lists.
+      emitError("Non-const initializer lists are currently not supported.");
+      return 0;
+    } else if (auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
+      const bool value = boolLiteral->getValue();
+      return theBuilder.getConstantBool(value);
+    } else if (auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
+      return translateAPInt(intLiteral->getValue(), expr->getType());
+    } else if (auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
+      return translateAPFloat(floatLiteral->getValue(), expr->getType());
+    } else if (auto *binOp = dyn_cast<BinaryOperator>(expr)) {
+      return doBinaryOperator(binOp);
     }
+
     emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
     // TODO: handle other expressions
     return 0;
   }
 
+  uint32_t doBinaryOperator(const BinaryOperator *expr) {
+    const auto opcode = expr->getOpcode();
+    const uint32_t rhs = doExpr(expr->getRHS());
+    const uint32_t lhs = doExpr(expr->getLHS());
+
+    switch (opcode) {
+    case BO_Assign:
+      theBuilder.createStore(lhs, rhs);
+      // Assignment returns a rvalue.
+      return rhs;
+    default:
+      break;
+    }
+
+    emitError("BinaryOperator '%0' is not supported yet.") << opcode;
+    return 0;
+  }
+
+  uint32_t doImplicitCastExpr(const ImplicitCastExpr *expr) {
+    const Expr *subExpr = expr->getSubExpr();
+    const QualType toType = expr->getType();
+
+    switch (expr->getCastKind()) {
+    // Integer literals in the AST are represented using 64bit APInt
+    // themselves and then implicitly casted into the expected bitwidth.
+    // We need special treatment of integer literals here because generating
+    // a 64bit constant and then explicit casting in SPIR-V requires Int64
+    // capability. We should avoid introducing unnecessary capabilities to
+    // our best.
+    case CastKind::CK_IntegralCast: {
+      llvm::APSInt intValue;
+      if (expr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects)) {
+        return translateAPInt(intValue, toType);
+      } else {
+        emitError("Integral cast is not supported yet");
+        return 0;
+      }
+    }
+    case CastKind::CK_LValueToRValue: {
+      const uint32_t fromValue = doExpr(subExpr);
+      // Using lvalue as rvalue means we need to OpLoad the contents from
+      // the parameter/variable first.
+      const uint32_t resultType = typeTranslator.translateType(toType);
+      return theBuilder.createLoad(resultType, fromValue);
+    }
+    default:
+      emitError("ImplictCast Kind '%0' is not supported yet.")
+          << expr->getCastKind();
+      return 0;
+    }
+  }
+
+  uint32_t translateAPValue(const APValue &value, const QualType targetType) {
+    if (targetType->isBooleanType()) {
+      const bool boolValue = value.getInt().getBoolValue();
+      return theBuilder.getConstantBool(boolValue);
+    }
+
+    if (targetType->isIntegerType()) {
+      const llvm::APInt &intValue = value.getInt();
+      return translateAPInt(intValue, targetType);
+    }
+
+    if (targetType->isFloatingType()) {
+      const llvm::APFloat &floatValue = value.getFloat();
+      return translateAPFloat(floatValue, targetType);
+    }
+
+    emitError("APValue of type '%0' is not supported yet.") << value.getKind();
+    return 0;
+  }
+
+  uint32_t translateAPInt(const llvm::APInt &intValue, QualType targetType) {
+    const auto bitwidth = astContext.getIntWidth(targetType);
+
+    if (targetType->isSignedIntegerType()) {
+      const int64_t value = intValue.getSExtValue();
+      switch (bitwidth) {
+      case 32:
+        return theBuilder.getConstantInt32(static_cast<int32_t>(value));
+      default:
+        break;
+      }
+    } else {
+      const uint64_t value = intValue.getZExtValue();
+      switch (bitwidth) {
+      case 32:
+        return theBuilder.getConstantUint32(static_cast<uint32_t>(value));
+      default:
+        break;
+      }
+    }
+
+    emitError("APInt for target bitwidth '%0' is not supported yet.")
+        << bitwidth;
+    return 0;
+  }
+
+  uint32_t translateAPFloat(const llvm::APFloat &floatValue,
+                            QualType targetType) {
+    const auto &semantics = astContext.getFloatTypeSemantics(targetType);
+    const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
+
+    switch (bitwidth) {
+    case 32:
+      return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    default:
+      break;
+    }
+
+    emitError("APFloat for target bitwidth '%0' is not supported yet.")
+        << bitwidth;
+    return 0;
+  }
+
 private:
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.
@@ -380,6 +505,7 @@ private:
 
 private:
   CompilerInstance &theCompilerInstance;
+  ASTContext &astContext;
   DiagnosticsEngine &diags;
 
   /// Entry function name and shader stage. Both of them are derived from the

+ 18 - 6
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -64,11 +64,12 @@ uint32_t ModuleBuilder::addFnParameter(uint32_t type, llvm::StringRef name) {
   return paramId;
 }
 
-uint32_t ModuleBuilder::addFnVariable(uint32_t type, llvm::StringRef name) {
+uint32_t ModuleBuilder::addFnVariable(uint32_t type, llvm::StringRef name,
+                                      llvm::Optional<uint32_t> init) {
   assert(theFunction && "found detached local variable");
 
   const uint32_t varId = theContext.takeNextId();
-  theFunction->addVariable(type, varId);
+  theFunction->addVariable(type, varId, init);
   theModule.addDebugName(varId, name);
   return varId;
 }
@@ -214,6 +215,7 @@ uint32_t ModuleBuilder::get##ty##Type() {                                      \
 }
 
 IMPL_GET_PRIMITIVE_TYPE(Void)
+IMPL_GET_PRIMITIVE_TYPE(Bool)
 IMPL_GET_PRIMITIVE_TYPE(Int32)
 IMPL_GET_PRIMITIVE_TYPE(Uint32)
 IMPL_GET_PRIMITIVE_TYPE(Float32)
@@ -268,7 +270,17 @@ ModuleBuilder::getFunctionType(uint32_t returnType,
   return typeId;
 }
 
-#define IMPL_GET_PRIMITIVE_VALUE(builderTy, cppTy)                             \
+uint32_t ModuleBuilder::getConstantBool(bool value) {
+  const uint32_t typeId = getBoolType();
+  const Constant *constant = value ? Constant::getTrue(theContext, typeId)
+                                   : Constant::getFalse(theContext, typeId);
+
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
+}
+
+#define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy)                             \
   \
 uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
     const uint32_t typeId = get##builderTy##Type();                            \
@@ -280,9 +292,9 @@ uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
   \
 }
 
-IMPL_GET_PRIMITIVE_VALUE(Int32, int32_t)
-IMPL_GET_PRIMITIVE_VALUE(Uint32, uint32_t)
-IMPL_GET_PRIMITIVE_VALUE(Float32, float)
+IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
+IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_CONST(Float32, float)
 
 #undef IMPL_GET_PRIMITIVE_VALUE
 

+ 15 - 7
tools/clang/lib/SPIRV/Structure.cpp

@@ -96,6 +96,14 @@ void Function::clear() {
   blocks.clear();
 }
 
+void Function::addVariable(uint32_t varType, uint32_t varId,
+                           llvm::Optional<uint32_t> init) {
+  variables.emplace_back(
+      InstBuilder(nullptr)
+          .opVariable(varType, varId, spv::StorageClass::Function, init)
+          .take());
+}
+
 void Function::take(InstBuilder *builder) {
   builder->opFunction(resultType, resultId, funcControl, funcType).x();
 
@@ -105,12 +113,12 @@ void Function::take(InstBuilder *builder) {
   }
 
   // Preprend all local variables to the entry block.
-  for (auto &var : variables) {
-    blocks.front()->prependInstruction(
-        builder
-            ->opVariable(var.first, var.second, spv::StorageClass::Function,
-                         llvm::None)
-            .take());
+  // This is necessary since SPIR-V requires all local variables to be defined
+  // at the very begining of the entry block.
+  // We need to do it in the reverse order to guarantee variables have the
+  // same definition order in SPIR-V as in the source code.
+  for (auto it = variables.rbegin(), ie = variables.rend(); it != ie; ++it) {
+    blocks.front()->prependInstruction(std::move(*it));
   }
 
   // Write out all basic blocks.
@@ -181,7 +189,7 @@ void SPIRVModule::takeConstantForArrayType(const Type *arrType,
   // If it finds the constant, feeds it into the consumer, and removes it
   // from the constants collection.
   constants.remove_if([&consumer, arrayLengthResultId](
-                          std::pair<const Constant *, uint32_t> &item) {
+      std::pair<const Constant *, uint32_t> &item) {
     const bool isArrayLengthConstant = (item.second == arrayLengthResultId);
     if (isArrayLengthConstant)
       consumer(item.first->withResultId(item.second));

+ 15 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -21,6 +21,12 @@ uint32_t TypeTranslator::translateType(QualType type) {
     switch (builtinType->getKind()) {
     case BuiltinType::Void:
       return theBuilder.getVoidType();
+    case BuiltinType::Bool:
+      return theBuilder.getBoolType();
+    case BuiltinType::Int:
+      return theBuilder.getInt32Type();
+    case BuiltinType::UInt:
+      return theBuilder.getUint32Type();
     case BuiltinType::Float:
       return theBuilder.getFloat32Type();
     default:
@@ -30,11 +36,20 @@ uint32_t TypeTranslator::translateType(QualType type) {
     }
   }
 
+  if (const auto *typedefType = dyn_cast<TypedefType>(typePtr)) {
+    return translateType(typedefType->desugar());
+  }
+
   // In AST, vector types are TypedefType of TemplateSpecializationType.
   // We handle them via HLSL type inspection functions.
   if (hlsl::IsHLSLVecType(type)) {
     const auto elemType = hlsl::GetHLSLVecElementType(type);
     const auto elemCount = hlsl::GetHLSLVecSize(type);
+    // In SPIR-V, vectors must have two or more elements. So translate vectors
+    // of size 1 into the underlying primitive types directly.
+    if (elemCount == 1) {
+      return translateType(elemType);
+    }
     return theBuilder.getVecType(translateType(elemType), elemCount);
   }
 

+ 23 - 0
tools/clang/test/CodeGenSPIRV/binary-op.assign.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// TODO: assignment for composite types
+
+void main() {
+    int a, b, c;
+
+// CHECK: [[b0:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: OpStore %a [[b0]]
+    a = b;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: OpStore %b [[c0]]
+// CHECK-NEXT: OpStore %a [[c0]]
+    a = b = c;
+
+// CHECK-NEXT: [[a0:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpStore %a [[a0]]
+    a = a;
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpStore %a [[a1]]
+// CHECK-NEXT: OpStore %a [[a1]]
+    a = a = a;
+}

+ 0 - 10
tools/clang/test/CodeGenSPIRV/check-entrypoint.hlsl

@@ -1,10 +0,0 @@
-// Run: %dxc -T ps_6_0 -E main
-
-// CHECK: OpCapability Shader
-// CHECK-NOT: OpCapability Kernel
-// CHECK-NEXT: OpMemoryModel Logical GLSL450
-void main()
-// CHECK-NEXT: OpEntryPoint Fragment %main "main"
-{
-
-}

+ 49 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// TODO
+// 16bit & 64bit integer & floats (require additional capability)
+// float: denormalized numbers, Inf, NaN
+
+void main() {
+  // Boolean constants
+// CHECK-DAG: %true = OpConstantTrue %bool
+  bool c_bool_t = true;
+// CHECK-DAG: %false = OpConstantFalse %bool
+  bool c_bool_f = false;
+
+  // Signed integer constants
+// CHECK-DAG: %int_0 = OpConstant %int 0
+  int c_int_0 = 0;
+// CHECK-DAG: %int_1 = OpConstant %int 1
+  int c_int_1 = 1;
+// CHECK-DAG: %int_n1 = OpConstant %int -1
+  int c_int_n1 = -1;
+// CHECK-DAG: %int_42 = OpConstant %int 42
+  int c_int_42 = 42;
+// CHECK-DAG: %int_n42 = OpConstant %int -42
+  int c_int_n42 = -42;
+// CHECK-DAG: %int_2147483647 = OpConstant %int 2147483647
+  int c_int_max = 2147483647;
+// CHECK-DAG: %int_n2147483648 = OpConstant %int -2147483648
+  int c_int_min = -2147483648;
+
+  // Unsigned integer constants
+// CHECK-DAG: %uint_0 = OpConstant %uint 0
+  uint c_uint_0 = 0;
+// CHECK-DAG: %uint_1 = OpConstant %uint 1
+  uint c_uint_1 = 1;
+// CHECK-DAG: %uint_38 = OpConstant %uint 38
+  uint c_uint_38 = 38;
+// CHECK-DAG: %uint_4294967295 = OpConstant %uint 4294967295
+  uint c_uint_max = 4294967295;
+
+  // Float constants
+// CHECK-DAG: %float_0 = OpConstant %float 0
+  float c_float_0 = 0.;
+// CHECK-DAG: %float_n0 = OpConstant %float -0
+  float c_float_n0 = -0.;
+// CHECK-DAG: %float_4_2 = OpConstant %float 4.2
+  float c_float_4_2 = 4.2;
+// CHECK-DAG: %float_n4_2 = OpConstant %float -4.2
+  float c_float_n4_2 = -4.2;
+}

+ 6 - 6
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -48,12 +48,12 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %VSmain = OpFunction %void None %10
 // %bb_entry = OpLabel
 // %result = OpVariable %_ptr_Function__struct_13 Function
-// %19 = OpAccessChain %_ptr_Function_v4float %result %int_0
-// %20 = OpLoad %v4float %7
-// OpStore %19 %20
-// %22 = OpAccessChain %_ptr_Function_v4float %result %int_1
-// %23 = OpLoad %v4float %8
-// OpStore %22 %23
+// %16 = OpLoad %v4float %7
+// %20 = OpAccessChain %_ptr_Function_v4float %result %int_0
+// OpStore %20 %16
+// %21 = OpLoad %v4float %8
+// %23 = OpAccessChain %_ptr_Function_v4float %result %int_1
+// OpStore %23 %21
 // %24 = OpAccessChain %_ptr_Function_v4float %result %int_0
 // %25 = OpLoad %v4float %24
 // OpStore %gl_Position %25

+ 38 - 0
tools/clang/test/CodeGenSPIRV/type.scalar.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// TODO
+// - 16bit & 64bit integers/floats (require additional capabilities)
+
+// CHECK-DAG: %void = OpTypeVoid
+// CHECK-DAG: %{{[0-9]+}} = OpTypeFunction %void
+void main() {
+// CHECK-DAG: %bool = OpTypeBool
+// CHECK-DAG: %_ptr_Function_bool = OpTypePointer Function %bool
+  bool a;
+// CHECK-DAG: %int = OpTypeInt 32 1
+// CHECK-DAG: %_ptr_Function_int = OpTypePointer Function %int
+  int b;
+// CHECK-DAG: %uint = OpTypeInt 32 0
+// CHECK-DAG: %_ptr_Function_uint = OpTypePointer Function %uint
+  uint c;
+  dword d;
+// CHECK-DAG: %float = OpTypeFloat 32
+// CHECK-DAG: %_ptr_Function_float = OpTypePointer Function %float
+  float e;
+  bool1 a1;
+  int1 b1;
+  uint1 c1;
+  dword1 d1;
+  float1 e1;
+
+// CHECK: %a = OpVariable %_ptr_Function_bool Function
+// CHECK-NEXT: %b = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %c = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT: %d = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT: %e = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %a1 = OpVariable %_ptr_Function_bool Function
+// CHECK-NEXT: %b1 = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %c1 = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT: %d1 = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT: %e1 = OpVariable %_ptr_Function_float Function
+}

+ 14 - 16
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -14,32 +14,30 @@ namespace {
 using clang::spirv::FileTest;
 using clang::spirv::WholeFileTest;
 
+// === Whole output tests ===
+
 TEST_F(WholeFileTest, EmptyVoidMain) {
-  runWholeFileTest("empty-void-main.hlsl2spv",
-                   /*generateHeader*/ true,
-                   /*runValidation*/ true);
+  runWholeFileTest("empty-void-main.hlsl2spv", /*generateHeader*/ true);
 }
 
 TEST_F(WholeFileTest, PassThruPixelShader) {
-  runWholeFileTest("passthru-ps.hlsl2spv",
-                   /*generateHeader*/ true,
-                   /*runValidation*/ true);
+  runWholeFileTest("passthru-ps.hlsl2spv", /*generateHeader*/ true);
 }
 
 TEST_F(WholeFileTest, PassThruVertexShader) {
-  runWholeFileTest("passthru-vs.hlsl2spv",
-                   /*generateHeader*/ true,
-                   /*runValidation*/ true);
+  runWholeFileTest("passthru-vs.hlsl2spv", /*generateHeader*/ true);
 }
 
 TEST_F(WholeFileTest, ConstantPixelShader) {
-  runWholeFileTest("constant-ps.hlsl2spv",
-                   /*generateHeader*/ true,
-                   /*runValidation*/ true);
+  runWholeFileTest("constant-ps.hlsl2spv", /*generateHeader*/ true);
 }
 
-TEST_F(FileTest, CheckMemoryModelAndEntryPoint) {
-  runFileTest("check-entrypoint.hlsl",
-              /*runValidation*/ true);
-}
+// === Partial output tests ===
+
+TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
+
+TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
+
+TEST_F(FileTest, BinaryOpAssign) { runFileTest("binary-op.assign.hlsl"); }
+
 } // namespace