Преглед изворни кода

[spirv] LiteralTypeVisitor (and reverse visiting).

Ehsan Nasiri пре 6 година
родитељ
комит
52bfb4e86c

+ 12 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -112,6 +112,14 @@ uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
 /// constnesss and literalness.
 /// constnesss and literalness.
 bool canTreatAsSameScalarType(QualType type1, QualType type2);
 bool canTreatAsSameScalarType(QualType type1, QualType type2);
 
 
+/// \brief Returns true if the two types are the same scalar or vector type,
+/// regardless of constness and literalness.
+bool isSameScalarOrVecType(QualType type1, QualType type2);
+
+  /// \brief Returns true if the two types are the same type, regardless of
+  /// constness and literalness.
+bool isSameType(const ASTContext &, QualType type1, QualType type2);
+
 /// Returns true if all members in structType are of the same element
 /// Returns true if all members in structType are of the same element
 /// type and can be fit into a 4-component vector. Writes element type and
 /// type and can be fit into a 4-component vector. Writes element type and
 /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
 /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
@@ -131,6 +139,10 @@ QualType getTypeWithCustomBitwidth(const ASTContext &, QualType type,
 /// Returns true if the given type is a matrix or an array of matrices.
 /// Returns true if the given type is a matrix or an array of matrices.
 bool isMatrixOrArrayOfMatrix(const ASTContext &, QualType type);
 bool isMatrixOrArrayOfMatrix(const ASTContext &, QualType type);
 
 
+/// Returns true if the given type is a LitInt or LitFloat type or a vector of
+/// them. Returns false otherwise.
+bool isLitTypeOrVecOfLitType(QualType type);
+
 } // namespace spirv
 } // namespace spirv
 } // namespace clang
 } // namespace clang
 
 

+ 2 - 1
tools/clang/include/clang/SPIRV/SpirvBasicBlock.h

@@ -71,7 +71,8 @@ public:
   /// Handle SPIR-V basic block visitors.
   /// Handle SPIR-V basic block visitors.
   /// If a basic block is the first basic block in a function, it must include
   /// If a basic block is the first basic block in a function, it must include
   /// all the variable definitions of the entire function.
   /// all the variable definitions of the entire function.
-  bool invokeVisitor(Visitor *, llvm::ArrayRef<SpirvVariable *> vars = {});
+  bool invokeVisitor(Visitor *, llvm::ArrayRef<SpirvVariable *> vars,
+                     bool reverseOrder = false);
 
 
   /// \brief Adds the given basic block as a successsor to this basic block.
   /// \brief Adds the given basic block as a successsor to this basic block.
   void addSuccessor(SpirvBasicBlock *bb);
   void addSuccessor(SpirvBasicBlock *bb);

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -38,7 +38,7 @@ public:
   SpirvFunction &operator=(SpirvFunction &&) = delete;
   SpirvFunction &operator=(SpirvFunction &&) = delete;
 
 
   // Handle SPIR-V function visitors.
   // Handle SPIR-V function visitors.
-  bool invokeVisitor(Visitor *);
+  bool invokeVisitor(Visitor *, bool reverseOrder = false);
 
 
   uint32_t getResultId() const { return functionId; }
   uint32_t getResultId() const { return functionId; }
   void setResultId(uint32_t id) { functionId = id; }
   void setResultId(uint32_t id) { functionId = id; }

+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -126,6 +126,8 @@ public:
   Kind getKind() const { return kind; }
   Kind getKind() const { return kind; }
   spv::Op getopcode() const { return opcode; }
   spv::Op getopcode() const { return opcode; }
   QualType getAstResultType() const { return astResultType; }
   QualType getAstResultType() const { return astResultType; }
+  void setAstResultType(QualType type) { astResultType = type; }
+  bool hasAstResultType() const { return astResultType != QualType(); }
 
 
   uint32_t getResultTypeId() const { return resultTypeId; }
   uint32_t getResultTypeId() const { return resultTypeId; }
   void setResultTypeId(uint32_t id) { resultTypeId = id; }
   void setResultTypeId(uint32_t id) { resultTypeId = id; }
@@ -1688,6 +1690,7 @@ public:
   DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
   DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
 
 
   SpirvInstruction *getOperand() const { return operand; }
   SpirvInstruction *getOperand() const { return operand; }
+  bool isConversionOp() const;
 
 
 private:
 private:
   SpirvInstruction *operand;
   SpirvInstruction *operand;

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -49,7 +49,7 @@ public:
   SpirvModule &operator=(SpirvModule &&) = delete;
   SpirvModule &operator=(SpirvModule &&) = delete;
 
 
   // Handle SPIR-V module visitors.
   // Handle SPIR-V module visitors.
-  bool invokeVisitor(Visitor *);
+  bool invokeVisitor(Visitor *, bool reverseOrder = false);
 
 
   // Add a function to the list of module functions.
   // Add a function to the list of module functions.
   void addFunction(SpirvFunction *);
   void addFunction(SpirvFunction *);

+ 66 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -600,5 +600,71 @@ bool isMatrixOrArrayOfMatrix(const ASTContext &context, QualType type) {
   return false;
   return false;
 }
 }
 
 
+bool isLitTypeOrVecOfLitType(QualType type) {
+  if (type == QualType())
+    return false;
+
+  if (type->isSpecificBuiltinType(BuiltinType::LitInt) ||
+      type->isSpecificBuiltinType(BuiltinType::LitFloat))
+    return true;
+
+  // For vector cases
+  {
+    QualType elemType = {};
+    uint32_t elemCount = 0;
+    if (isVectorType(type, &elemType, &elemCount))
+      return isLitTypeOrVecOfLitType(elemType);
+  }
+
+  return false;
+}
+
+bool isSameScalarOrVecType(QualType type1, QualType type2) {
+  { // Scalar types
+    QualType scalarType1 = {}, scalarType2 = {};
+    if (isScalarType(type1, &scalarType1) && isScalarType(type2, &scalarType2))
+      return canTreatAsSameScalarType(scalarType1, scalarType2);
+  }
+
+  { // Vector types
+    QualType elemType1 = {}, elemType2 = {};
+    uint32_t count1 = {}, count2 = {};
+    if (isVectorType(type1, &elemType1, &count1) &&
+        isVectorType(type2, &elemType2, &count2))
+      return count1 == count2 && canTreatAsSameScalarType(elemType1, elemType2);
+  }
+
+  return false;
+}
+
+bool isSameType(const ASTContext &astContext, QualType type1, QualType type2) {
+  if (isSameScalarOrVecType(type1, type2))
+    return true;
+
+  type1.removeLocalConst();
+  type2.removeLocalConst();
+
+  { // Matrix types
+    QualType elemType1 = {}, elemType2 = {};
+    uint32_t row1 = 0, row2 = 0, col1 = 0, col2 = 0;
+    if (isMxNMatrix(type1, &elemType1, &row1, &col1) &&
+        isMxNMatrix(type2, &elemType2, &row2, &col2))
+      return row1 == row2 && col1 == col2 &&
+             canTreatAsSameScalarType(elemType1, elemType2);
+  }
+
+  { // Array types
+    if (const auto *arrType1 = astContext.getAsConstantArrayType(type1))
+      if (const auto *arrType2 = astContext.getAsConstantArrayType(type2))
+        return hlsl::GetArraySize(type1) == hlsl::GetArraySize(type2) &&
+               isSameType(astContext, arrType1->getElementType(),
+                          arrType2->getElementType());
+  }
+
+  // TODO: support other types if needed
+
+  return false;
+}
+
 } // namespace spirv
 } // namespace spirv
 } // namespace clang
 } // namespace clang

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -16,6 +16,7 @@ add_clang_library(clangSPIRV
   InitListHandler.cpp
   InitListHandler.cpp
   InstBuilderAuto.cpp
   InstBuilderAuto.cpp
   InstBuilderManual.cpp
   InstBuilderManual.cpp
+  LiteralTypeVisitor.cpp
   LowerTypeVisitor.cpp
   LowerTypeVisitor.cpp
   ModuleBuilder.cpp
   ModuleBuilder.cpp
   SPIRVContext.cpp
   SPIRVContext.cpp

+ 0 - 1
tools/clang/lib/SPIRV/CapabilityVisitor.h

@@ -12,7 +12,6 @@
 
 
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SpirvVisitor.h"
 #include "clang/SPIRV/SpirvVisitor.h"
-#include "llvm/ADT/Optional.h"
 
 
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {

+ 44 - 47
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1057,54 +1057,15 @@ uint32_t EmitTypeHandler::getOrCreateConstantNull(SpirvConstantNull *inst) {
 
 
 uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
 uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
                                                    const SpirvType *type) {
                                                    const SpirvType *type) {
-  // If this constant has already been emitted, return its result-id.
-  auto valueTypePair = std::pair<uint64_t, const SpirvType *>(
-      value.bitcastToAPInt().getZExtValue(), type);
-  auto foundResultId = emittedConstantFloats.find(valueTypePair);
-  if (foundResultId != emittedConstantFloats.end())
-    return foundResultId->second;
-
   assert(isa<FloatType>(type));
   assert(isa<FloatType>(type));
   const auto *floatType = dyn_cast<FloatType>(type);
   const auto *floatType = dyn_cast<FloatType>(type);
   const auto typeBitwidth = floatType->getBitwidth();
   const auto typeBitwidth = floatType->getBitwidth();
   const auto valueBitwidth = llvm::APFloat::getSizeInBits(value.getSemantics());
   const auto valueBitwidth = llvm::APFloat::getSizeInBits(value.getSemantics());
+  auto valueToUse = value;
 
 
-  // Start constructing the instruction
-  const uint32_t constantResultId = takeNextIdFunction();
-  const uint32_t typeId = emitType(type);
-  initTypeInstruction(spv::Op::OpConstant);
-  curTypeInst.push_back(typeId);
-  curTypeInst.push_back(constantResultId);
-
-  // Start constructing the value word / words
-  if (valueBitwidth == typeBitwidth) {
-    if (typeBitwidth == 16) {
-      // According to the SPIR-V Spec:
-      // When the type's bit width is less than 32-bits, the literal's value
-      // appears in the low-order bits of the word, and the high-order bits must
-      // be 0 for a floating-point type.
-      curTypeInst.push_back(
-          static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()));
-    } else if (typeBitwidth == 32) {
-      curTypeInst.push_back(
-          cast::BitwiseCast<uint32_t, float>(value.convertToFloat()));
-    } else {
-      // TODO: The ordering of the 2 words depends on the endian-ness of the
-      // host machine.
-      struct wideFloat {
-        uint32_t word0;
-        uint32_t word1;
-      };
-      wideFloat words =
-          cast::BitwiseCast<wideFloat, double>(value.convertToDouble());
-      curTypeInst.push_back(words.word0);
-      curTypeInst.push_back(words.word1);
-    }
-  }
-  // The type and the value have different widths. We need to convert the value
-  // to the width of the type. Error out if the conversion is lossy.
-  else {
-    auto valueToUse = value;
+  // If the type and the value have different widths, we need to convert the
+  // value to the width of the type. Error out if the conversion is lossy.
+  if (valueBitwidth != typeBitwidth) {
     bool losesInfo = false;
     bool losesInfo = false;
     const llvm::fltSemantics &targetSemantics =
     const llvm::fltSemantics &targetSemantics =
         typeBitwidth == 16 ? llvm::APFloat::IEEEhalf
         typeBitwidth == 16 ? llvm::APFloat::IEEEhalf
@@ -1121,13 +1082,49 @@ uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
           // So only 32/64-bit values can reach here.
           // So only 32/64-bit values can reach here.
           << std::to_string(valueBitwidth == 32 ? valueToUse.convertToFloat()
           << std::to_string(valueBitwidth == 32 ? valueToUse.convertToFloat()
                                                 : valueToUse.convertToDouble());
                                                 : valueToUse.convertToDouble());
-      curTypeInst.push_back(0u);
-    } else {
-      curTypeInst.push_back(
-          cast::BitwiseCast<uint32_t, float>(valueToUse.convertToFloat()));
+      return 0;
     }
     }
   }
   }
 
 
+  // If this constant has already been emitted, return its result-id.
+  auto valueTypePair = std::pair<uint64_t, const SpirvType *>(
+      valueToUse.bitcastToAPInt().getZExtValue(), type);
+  auto foundResultId = emittedConstantFloats.find(valueTypePair);
+  if (foundResultId != emittedConstantFloats.end())
+    return foundResultId->second;
+
+  // Start constructing the instruction
+  const uint32_t constantResultId = takeNextIdFunction();
+  const uint32_t typeId = emitType(type);
+  initTypeInstruction(spv::Op::OpConstant);
+  curTypeInst.push_back(typeId);
+  curTypeInst.push_back(constantResultId);
+
+  // Start constructing the value word / words
+
+  if (typeBitwidth == 16) {
+    // According to the SPIR-V Spec:
+    // When the type's bit width is less than 32-bits, the literal's value
+    // appears in the low-order bits of the word, and the high-order bits must
+    // be 0 for a floating-point type.
+    curTypeInst.push_back(
+        static_cast<uint32_t>(valueToUse.bitcastToAPInt().getZExtValue()));
+  } else if (typeBitwidth == 32) {
+    curTypeInst.push_back(
+        cast::BitwiseCast<uint32_t, float>(valueToUse.convertToFloat()));
+  } else {
+    // TODO: The ordering of the 2 words depends on the endian-ness of the
+    // host machine.
+    struct wideFloat {
+      uint32_t word0;
+      uint32_t word1;
+    };
+    wideFloat words =
+        cast::BitwiseCast<wideFloat, double>(valueToUse.convertToDouble());
+    curTypeInst.push_back(words.word0);
+    curTypeInst.push_back(words.word1);
+  }
+
   finalizeTypeInstruction();
   finalizeTypeInstruction();
 
 
   // Remember this constant for future
   // Remember this constant for future

+ 357 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -0,0 +1,357 @@
+//===--- LiteralTypeVisitor.cpp - Literal Type Visitor -----------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "LiteralTypeVisitor.h"
+#include "clang/SPIRV/AstTypeProbe.h"
+#include "clang/SPIRV/SpirvBuilder.h"
+
+namespace clang {
+namespace spirv {
+
+// -- SpirvReturn (OpReturnValue)
+// -- SpirvCompositeExtract
+// -- SpirvCompositeInsert
+// -- SpirvExtInst
+// -- SpirvImageOp
+// -- SpirvImageQuery
+// -- SpirvImageTexelPointer
+// -- SpirvSpecConstantBinaryOp
+// -- SpirvSpecConstantUnaryOp
+
+// SpirvConstantComposite
+// SpirvComposite
+
+bool LiteralTypeVisitor::canDeduceTypeFromLitType(QualType litType,
+                                                  QualType newType) {
+  if (litType == QualType() || newType == QualType() || litType == newType)
+    return false;
+  if (!isLitTypeOrVecOfLitType(litType))
+    return false;
+  if (isLitTypeOrVecOfLitType(newType))
+    return false;
+
+  if (litType->isFloatingType() && newType->isFloatingType())
+    return true;
+  if ((litType->isIntegerType() && !litType->isBooleanType()) &&
+      (newType->isIntegerType() && !newType->isBooleanType()))
+    return true;
+
+  {
+    QualType elemType1 = {};
+    uint32_t elemCount1 = 0;
+    QualType elemType2 = {};
+    uint32_t elemCount2 = 0;
+    if (isVectorType(litType, &elemType1, &elemCount1) &&
+        isVectorType(newType, &elemType2, &elemCount2))
+      return elemCount1 == elemCount2 &&
+             canDeduceTypeFromLitType(elemType1, elemType2);
+  }
+
+  return false;
+}
+
+void LiteralTypeVisitor::updateTypeForInstruction(SpirvInstruction *inst,
+                                                  QualType newType) {
+  if (!inst)
+    return;
+
+  // We may only update LitInt to Int type and LitFloat to Float type.
+  if (!canDeduceTypeFromLitType(inst->getAstResultType(), newType))
+    return;
+
+  // Since LiteralTypeVisitor is run before lowering the types, we can simply
+  // update the AST result-type of the instruction to the new type. In the case
+  // of the instruction being a constant instruction, since we do not have
+  // unique constants at this point, chaing the QualType of the constant
+  // instruction is safe.
+  inst->setAstResultType(newType);
+}
+
+bool LiteralTypeVisitor::visitInstruction(SpirvInstruction *instr) {
+  // Instructions that don't have custom visitors cannot help with deducing the
+  // real type from the literal type.
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvVariable *var) {
+  updateTypeForInstruction(var->getInitializer(), var->getAstResultType());
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvAtomic *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForInstruction(inst->getValue(), resultType);
+  updateTypeForInstruction(inst->getComparator(), resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvUnaryOp *inst) {
+  // Do not try to make conclusions about types for bitwidth conversion
+  // operations.
+  // TODO: We can do more to deduce information in OpBitCast.
+  const auto opcode = inst->getopcode();
+  if (opcode == spv::Op::OpUConvert || opcode == spv::Op::OpSConvert ||
+      opcode == spv::Op::OpFConvert || opcode == spv::Op::OpBitcast) {
+    return true;
+  }
+
+  const auto resultType = inst->getAstResultType();
+  auto *arg = inst->getOperand();
+  const auto argType = arg->getAstResultType();
+
+  // OpNot, OpSNegate, and OpConvertXToY operations change the type, but may not
+  // change the bitwidth. So, for these operations, we can use the result type's
+  // bitwidth as a hint for the operand's bitwidth.
+  // --> get signedness and vector size (if any) from operand
+  // --> get bitwidth from result type
+  if (opcode == spv::Op::OpConvertFToU || opcode == spv::Op::OpConvertFToS ||
+      opcode == spv::Op::OpConvertSToF || opcode == spv::Op::OpConvertUToF ||
+      opcode == spv::Op::OpNot || opcode == spv::Op::OpSNegate) {
+    if (isLitTypeOrVecOfLitType(argType) &&
+        !isLitTypeOrVecOfLitType(resultType)) {
+      const uint32_t resultTypeBitwidth = getElementSpirvBitwidth(
+          astContext, resultType, spvOptions.enable16BitTypes);
+      const QualType newType =
+          getTypeWithCustomBitwidth(astContext, argType, resultTypeBitwidth);
+      updateTypeForInstruction(arg, newType);
+      return true;
+    }
+  }
+
+  // In all other cases, try to use the result type as a hint.
+  updateTypeForInstruction(arg, resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvBinaryOp *inst) {
+  const auto resultType = inst->getAstResultType();
+  const auto op = inst->getopcode();
+  auto *operand1 = inst->getOperand1();
+  auto *operand2 = inst->getOperand2();
+
+  // We should not modify operand2 type in these operations:
+  if (op == spv::Op::OpShiftRightLogical ||
+      op == spv::Op::OpShiftRightArithmetic ||
+      op == spv::Op::OpShiftLeftLogical) {
+    // Base (arg1) should have the same type as result type
+    updateTypeForInstruction(inst->getOperand1(), resultType);
+    return true;
+  }
+
+  // The following operations have a boolean return type, so we cannot deduce
+  // anything about the operand type from the result type. However, the two
+  // operands in these operations must have the same bitwidth.
+  if (op == spv::Op::OpIEqual || op == spv::Op::OpINotEqual ||
+      op == spv::Op::OpUGreaterThan || op == spv::Op::OpSGreaterThan ||
+      op == spv::Op::OpUGreaterThanEqual ||
+      op == spv::Op::OpSGreaterThanEqual || op == spv::Op::OpULessThan ||
+      op == spv::Op::OpSLessThan || op == spv::Op::OpULessThanEqual ||
+      op == spv::Op::OpSLessThanEqual || op == spv::Op::OpFOrdEqual ||
+      op == spv::Op::OpFUnordEqual || op == spv::Op::OpFOrdNotEqual ||
+      op == spv::Op::OpFUnordNotEqual || op == spv::Op::OpFOrdLessThan ||
+      op == spv::Op::OpFUnordLessThan || op == spv::Op::OpFOrdGreaterThan ||
+      op == spv::Op::OpFUnordGreaterThan ||
+      op == spv::Op::OpFOrdLessThanEqual ||
+      op == spv::Op::OpFUnordLessThanEqual ||
+      op == spv::Op::OpFOrdGreaterThanEqual ||
+      op == spv::Op::OpFUnordGreaterThanEqual) {
+    if (operand1->hasAstResultType() && operand2->hasAstResultType()) {
+      const auto operand1Type = operand1->getAstResultType();
+      const auto operand2Type = operand2->getAstResultType();
+      bool isLitOp1 = isLitTypeOrVecOfLitType(operand1Type);
+      bool isLitOp2 = isLitTypeOrVecOfLitType(operand2Type);
+
+      if (isLitOp1 && !isLitOp2) {
+        const uint32_t operand2Bitwidth = getElementSpirvBitwidth(
+            astContext, operand2Type, spvOptions.enable16BitTypes);
+        const QualType newType = getTypeWithCustomBitwidth(
+            astContext, operand1Type, operand2Bitwidth);
+        updateTypeForInstruction(operand1, newType);
+        return true;
+      }
+      if (isLitOp2 && !isLitOp1) {
+        const uint32_t operand1Bitwidth = getElementSpirvBitwidth(
+            astContext, operand1Type, spvOptions.enable16BitTypes);
+        const QualType newType = getTypeWithCustomBitwidth(
+            astContext, operand2Type, operand1Bitwidth);
+        updateTypeForInstruction(operand2, newType);
+        return true;
+      }
+    }
+  }
+
+  updateTypeForInstruction(operand1, resultType);
+  updateTypeForInstruction(operand2, resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvBitFieldInsert *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForInstruction(inst->getBase(), resultType);
+  updateTypeForInstruction(inst->getInsert(), resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvBitFieldExtract *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForInstruction(inst->getBase(), resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvSelect *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForInstruction(inst->getTrueObject(), resultType);
+  updateTypeForInstruction(inst->getFalseObject(), resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
+  const auto resultType = inst->getAstResultType();
+  if (inst->hasAstResultType() && !isLitTypeOrVecOfLitType(resultType)) {
+    auto *vec1 = inst->getVec1();
+    auto *vec2 = inst->getVec1();
+    assert(vec1 && vec2);
+    QualType resultElemType = {};
+    uint32_t resultElemCount = 0;
+    QualType vec1ElemType = {};
+    uint32_t vec1ElemCount = 0;
+    QualType vec2ElemType = {};
+    uint32_t vec2ElemCount = 0;
+    (void)isVectorType(resultType, &resultElemType, &resultElemCount);
+    (void)isVectorType(vec1->getAstResultType(), &vec1ElemType, &vec1ElemCount);
+    (void)isVectorType(vec2->getAstResultType(), &vec2ElemType, &vec2ElemCount);
+    if (isLitTypeOrVecOfLitType(vec1ElemType)) {
+      updateTypeForInstruction(
+          vec1, astContext.getExtVectorType(resultElemType, vec1ElemCount));
+    }
+    if (isLitTypeOrVecOfLitType(vec2ElemType)) {
+      updateTypeForInstruction(
+          vec2, astContext.getExtVectorType(resultElemType, vec2ElemCount));
+    }
+  }
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
+  // Went through each non-uniform binary operation and made sure the following
+  // does not result in a wrong type deduction.
+  updateTypeForInstruction(inst->getArg(), inst->getAstResultType());
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
+  // Went through each non-uniform unary operation and made sure the following
+  // does not result in a wrong type deduction.
+  updateTypeForInstruction(inst->getArg1(), inst->getAstResultType());
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvStore *inst) {
+  auto *object = inst->getObject();
+  auto *pointer = inst->getPointer();
+  if (pointer->hasAstResultType()) {
+    QualType type = pointer->getAstResultType();
+    if (const auto *ptrType = type->getAs<PointerType>())
+      type = ptrType->getPointeeType();
+    updateTypeForInstruction(object, type);
+  }
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvConstantComposite *inst) {
+  const auto resultType = inst->getAstResultType();
+  llvm::SmallVector<SpirvInstruction *, 4> constituents(
+      inst->getConstituents().begin(), inst->getConstituents().end());
+  updateTypeForCompositeMembers(resultType, constituents);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvComposite *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForCompositeMembers(resultType, inst->getConstituents());
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvCompositeExtract *inst) {
+  const auto resultType = inst->getAstResultType();
+  auto *base = inst->getComposite();
+  const auto baseType = base->getAstResultType();
+  if (isLitTypeOrVecOfLitType(baseType) &&
+      !isLitTypeOrVecOfLitType(resultType)) {
+    const uint32_t resultTypeBitwidth = getElementSpirvBitwidth(
+        astContext, resultType, spvOptions.enable16BitTypes);
+    const QualType newType =
+        getTypeWithCustomBitwidth(astContext, baseType, resultTypeBitwidth);
+    updateTypeForInstruction(base, newType);
+    return true;
+  }
+}
+
+bool LiteralTypeVisitor::updateTypeForCompositeMembers(
+    QualType compositeType, llvm::ArrayRef<SpirvInstruction *> constituents) {
+
+  if (compositeType == QualType())
+    return true;
+
+  // The constituents are the top level objects that create the result type.
+  // The result type may be one of the following:
+  // Vector, Array, Matrix, Struct
+
+  // TODO: This method is currently not recursive. We can use recursion if
+  // absolutely necessary.
+
+  { // Vector case
+    QualType elemType = {};
+    if (isVectorType(compositeType, &elemType)) {
+      for (auto *constituent : constituents)
+        updateTypeForInstruction(constituent, elemType);
+      return true;
+    }
+  }
+
+  { // Array case
+    if (const auto *arrType = dyn_cast<ConstantArrayType>(compositeType)) {
+      for (auto *constituent : constituents)
+        updateTypeForInstruction(constituent, arrType->getElementType());
+      return true;
+    }
+  }
+
+  { // Matrix case
+    QualType elemType = {};
+    if (isMxNMatrix(compositeType, &elemType)) {
+      for (auto *constituent : constituents) {
+        // Each constituent is a matrix column (a vector)
+        uint32_t colSize = 0;
+        if (isVectorType(constituent->getAstResultType(), nullptr, &colSize)) {
+          QualType newType = astContext.getExtVectorType(elemType, colSize);
+          updateTypeForInstruction(constituent, newType);
+        }
+      }
+      return true;
+    }
+  }
+
+  { // Struct case
+    if (const auto *structType = compositeType->getAs<RecordType>()) {
+      const auto *decl = structType->getDecl();
+      size_t i = 0;
+      for (const auto *field : decl->fields()) {
+        updateTypeForInstruction(constituents[i], field->getType());
+        ++i;
+      }
+      return true;
+    }
+  }
+
+  return true;
+}
+
+} // end namespace spirv
+} // namespace clang

+ 70 - 0
tools/clang/lib/SPIRV/LiteralTypeVisitor.h

@@ -0,0 +1,70 @@
+//===--- LiteralTypeVisitor.h - Literal Type Visitor -------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SPIRV_LITERALTYPEVISITOR_H
+#define LLVM_CLANG_LIB_SPIRV_LITERALTYPEVISITOR_H
+
+#include "clang/SPIRV/SPIRVContext.h"
+#include "clang/SPIRV/SpirvVisitor.h"
+
+namespace clang {
+namespace spirv {
+
+class SpirvBuilder;
+
+class LiteralTypeVisitor : public Visitor {
+public:
+  LiteralTypeVisitor(const ASTContext &ctx, SpirvContext &spvCtx,
+                     const SpirvCodeGenOptions &opts)
+      : Visitor(opts, spvCtx), astContext(ctx) {}
+
+  bool visit(SpirvVariable *);
+  bool visit(SpirvAtomic *);
+  bool visit(SpirvUnaryOp *);
+  bool visit(SpirvBinaryOp *);
+  bool visit(SpirvBitFieldInsert *);
+  bool visit(SpirvBitFieldExtract *);
+  bool visit(SpirvSelect *);
+  bool visit(SpirvVectorShuffle *);
+  bool visit(SpirvNonUniformUnaryOp *);
+  bool visit(SpirvNonUniformBinaryOp *);
+  bool visit(SpirvStore *);
+  bool visit(SpirvConstantComposite *);
+  bool visit(SpirvComposite *);
+  bool visit(SpirvCompositeExtract *);
+
+  /// The "sink" visit function for all instructions.
+  ///
+  /// By default, all other visit instructions redirect to this visit function.
+  /// So that you want override this visit function to handle all instructions,
+  /// regardless of their polymorphism.
+  bool visitInstruction(SpirvInstruction *instr);
+
+private:
+  /// Updates the result type of the given instruction to the new type.
+  void updateTypeForInstruction(SpirvInstruction *, QualType newType);
+
+  /// returns true if the given literal type can be deduced to the given
+  /// newType. In order for that to be true,
+  /// a) litType must be a literal type
+  /// b) litType and newType must be either scalar or vectors of the same size
+  /// c) they must have the same underlying type (both int or both float)
+  bool canDeduceTypeFromLitType(QualType litType, QualType newType);
+
+  bool updateTypeForCompositeMembers(
+      QualType compositeType, llvm::ArrayRef<SpirvInstruction *> constituents);
+
+private:
+  const ASTContext &astContext;
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SPIRV_LITERALTYPEVISITOR_H

+ 10 - 10
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2449,7 +2449,7 @@ SpirvInstruction *SPIRVEmitter::doCastExpr(const CastExpr *expr) {
     //        `- <rhs>
     //        `- <rhs>
     // This FlatConversion does not affect CodeGen, so that we can ignore it.
     // This FlatConversion does not affect CodeGen, so that we can ignore it.
     else if (subExprType->isArrayType() &&
     else if (subExprType->isArrayType() &&
-             typeTranslator.isSameType(expr->getType(), subExprType)) {
+             isSameType(astContext, expr->getType(), subExprType)) {
       return doExpr(subExpr);
       return doExpr(subExpr);
     }
     }
     // We can have casts changing the shape but without affecting memory order,
     // We can have casts changing the shape but without affecting memory order,
@@ -6180,7 +6180,7 @@ SpirvInstruction *SPIRVEmitter::turnIntoElementPtr(
 SpirvInstruction *SPIRVEmitter::castToBool(SpirvInstruction *fromVal,
 SpirvInstruction *SPIRVEmitter::castToBool(SpirvInstruction *fromVal,
                                            QualType fromType,
                                            QualType fromType,
                                            QualType toBoolType) {
                                            QualType toBoolType) {
-  if (TypeTranslator::isSameScalarOrVecType(fromType, toBoolType))
+  if (isSameScalarOrVecType(fromType, toBoolType))
     return fromVal;
     return fromVal;
 
 
   { // Special case handling for converting to a matrix of booleans.
   { // Special case handling for converting to a matrix of booleans.
@@ -6210,7 +6210,7 @@ SpirvInstruction *SPIRVEmitter::castToBool(SpirvInstruction *fromVal,
 SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
 SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
                                           QualType fromType, QualType toIntType,
                                           QualType fromType, QualType toIntType,
                                           SourceLocation srcLoc) {
                                           SourceLocation srcLoc) {
-  if (TypeTranslator::isSameScalarOrVecType(fromType, toIntType))
+  if (isSameScalarOrVecType(fromType, toIntType))
     return fromVal;
     return fromVal;
 
 
   if (isBoolOrVecOfBoolType(fromType)) {
   if (isBoolOrVecOfBoolType(fromType)) {
@@ -6224,7 +6224,7 @@ SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
     QualType convertedType = {};
     QualType convertedType = {};
     fromVal = convertBitwidth(fromVal, fromType, toIntType, &convertedType);
     fromVal = convertBitwidth(fromVal, fromType, toIntType, &convertedType);
     // If bitwidth conversion was the only thing we needed to do, we're done.
     // If bitwidth conversion was the only thing we needed to do, we're done.
-    if (convertedType == toIntType)
+    if (isSameScalarOrVecType(convertedType, toIntType))
       return fromVal;
       return fromVal;
     return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toIntType, fromVal);
     return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toIntType, fromVal);
   }
   }
@@ -6318,7 +6318,7 @@ SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
                                             QualType fromType,
                                             QualType fromType,
                                             QualType toFloatType,
                                             QualType toFloatType,
                                             SourceLocation srcLoc) {
                                             SourceLocation srcLoc) {
-  if (TypeTranslator::isSameScalarOrVecType(fromType, toFloatType))
+  if (isSameScalarOrVecType(fromType, toFloatType))
     return fromVal;
     return fromVal;
 
 
   if (isBoolOrVecOfBoolType(fromType)) {
   if (isBoolOrVecOfBoolType(fromType)) {
@@ -7816,7 +7816,7 @@ SpirvInstruction *SPIRVEmitter::processNonFpScalarTimesMatrix(
   uint32_t numRows = 0, numCols = 0;
   uint32_t numRows = 0, numCols = 0;
   const bool isMat = isMxNMatrix(matrixType, &elemType, &numRows, &numCols);
   const bool isMat = isMxNMatrix(matrixType, &elemType, &numRows, &numCols);
   assert(isMat);
   assert(isMat);
-  assert(typeTranslator.isSameType(scalarType, elemType));
+  assert(isSameType(astContext, scalarType, elemType));
   (void)isMat;
   (void)isMat;
 
 
   // We need to multiply the scalar by each vector of the matrix.
   // We need to multiply the scalar by each vector of the matrix.
@@ -7846,7 +7846,7 @@ SpirvInstruction *SPIRVEmitter::processNonFpVectorTimesMatrix(
   uint32_t vecSize = 0, numRows = 0, numCols = 0;
   uint32_t vecSize = 0, numRows = 0, numCols = 0;
   const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
   const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
   const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
   const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
-  assert(typeTranslator.isSameType(vecElemType, matElemType));
+  assert(isSameType(astContext, vecElemType, matElemType));
   assert(isVec);
   assert(isVec);
   assert(isMat);
   assert(isMat);
   assert(vecSize == numRows);
   assert(vecSize == numRows);
@@ -7880,7 +7880,7 @@ SpirvInstruction *SPIRVEmitter::processNonFpMatrixTimesVector(
   uint32_t vecSize = 0, numRows = 0, numCols = 0;
   uint32_t vecSize = 0, numRows = 0, numCols = 0;
   const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
   const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
   const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
   const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
-  assert(typeTranslator.isSameType(vecElemType, matElemType));
+  assert(isSameType(astContext, vecElemType, matElemType));
   assert(isVec);
   assert(isVec);
   assert(isMat);
   assert(isMat);
   assert(vecSize == numCols);
   assert(vecSize == numCols);
@@ -7912,7 +7912,7 @@ SpirvInstruction *SPIRVEmitter::processNonFpMatrixTimesMatrix(
       isMxNMatrix(lhsType, &lhsElemType, &lhsNumRows, &lhsNumCols);
       isMxNMatrix(lhsType, &lhsElemType, &lhsNumRows, &lhsNumCols);
   const bool rhsIsMat =
   const bool rhsIsMat =
       isMxNMatrix(rhsType, &rhsElemType, &rhsNumRows, &rhsNumCols);
       isMxNMatrix(rhsType, &rhsElemType, &rhsNumRows, &rhsNumCols);
-  assert(typeTranslator.isSameType(lhsElemType, rhsElemType));
+  assert(isSameType(astContext, lhsElemType, rhsElemType));
   assert(lhsIsMat && rhsIsMat);
   assert(lhsIsMat && rhsIsMat);
   assert(lhsNumCols == rhsNumRows);
   assert(lhsNumCols == rhsNumRows);
   (void)rhsIsMat;
   (void)rhsIsMat;
@@ -8290,7 +8290,7 @@ SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
   const QualType argType = arg0->getType();
   const QualType argType = arg0->getType();
 
 
   // Method 3 return type may be the same as arg type, so it would be a no-op.
   // Method 3 return type may be the same as arg type, so it would be a no-op.
-  if (typeTranslator.isSameType(returnType, argType))
+  if (isSameType(astContext, returnType, argType))
     return doExpr(arg0);
     return doExpr(arg0);
 
 
   switch (numArgs) {
   switch (numArgs) {

+ 24 - 9
tools/clang/lib/SPIRV/SpirvBasicBlock.cpp

@@ -22,20 +22,35 @@ bool SpirvBasicBlock::hasTerminator() const {
 }
 }
 
 
 bool SpirvBasicBlock::invokeVisitor(Visitor *visitor,
 bool SpirvBasicBlock::invokeVisitor(Visitor *visitor,
-                                    llvm::ArrayRef<SpirvVariable *> vars) {
+                                    llvm::ArrayRef<SpirvVariable *> vars,
+                                    bool reverseOrder) {
   if (!visitor->visit(this, Visitor::Phase::Init))
   if (!visitor->visit(this, Visitor::Phase::Init))
     return false;
     return false;
 
 
-  // If a basic block is the first basic block of a function, it should include
-  // all the variables of the function.
-  if (!vars.empty())
-    for (auto *var : vars)
-      if (!var->invokeVisitor(visitor))
+  if (reverseOrder) {
+    for (auto inst = instructions.rbegin(); inst != instructions.rend();
+         ++inst) {
+      if (!(*inst)->invokeVisitor(visitor))
         return false;
         return false;
+    }
+    // If a basic block is the first basic block of a function, it should
+    // include all the variables of the function.
+    if (!vars.empty())
+      for (auto var = vars.rbegin(); var != vars.rend(); ++var)
+        if (!(*var)->invokeVisitor(visitor))
+          return false;
+  } else {
+    // If a basic block is the first basic block of a function, it should
+    // include all the variables of the function.
+    if (!vars.empty())
+      for (auto *var : vars)
+        if (!var->invokeVisitor(visitor))
+          return false;
 
 
-  for (auto *inst : instructions)
-    if (!inst->invokeVisitor(visitor))
-      return false;
+    for (auto *inst : instructions)
+      if (!inst->invokeVisitor(visitor))
+        return false;
+  }
 
 
   if (!visitor->visit(this, Visitor::Phase::Done))
   if (!visitor->visit(this, Visitor::Phase::Done))
     return false;
     return false;

+ 4 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -9,6 +9,7 @@
 
 
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "CapabilityVisitor.h"
 #include "CapabilityVisitor.h"
+#include "LiteralTypeVisitor.h"
 #include "TypeTranslator.h"
 #include "TypeTranslator.h"
 #include "clang/SPIRV/EmitVisitor.h"
 #include "clang/SPIRV/EmitVisitor.h"
 #include "clang/SPIRV/LowerTypeVisitor.h"
 #include "clang/SPIRV/LowerTypeVisitor.h"
@@ -1042,10 +1043,13 @@ SpirvConstant *SpirvBuilder::getConstantNull(QualType type) {
 
 
 std::vector<uint32_t> SpirvBuilder::takeModule() {
 std::vector<uint32_t> SpirvBuilder::takeModule() {
   // Run necessary visitor passes first
   // Run necessary visitor passes first
+  LiteralTypeVisitor literalTypeVisitor(astContext, context, spirvOptions);
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
   CapabilityVisitor capabilityVisitor(context, spirvOptions, *this);
   CapabilityVisitor capabilityVisitor(context, spirvOptions, *this);
   EmitVisitor emitVisitor(astContext, context, spirvOptions);
   EmitVisitor emitVisitor(astContext, context, spirvOptions);
 
 
+  module->invokeVisitor(&literalTypeVisitor, true);
+
   // Lower types
   // Lower types
   module->invokeVisitor(&lowerTypeVisitor);
   module->invokeVisitor(&lowerTypeVisitor);
 
 

+ 12 - 4
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -21,7 +21,7 @@ SpirvFunction::SpirvFunction(QualType returnType, SpirvType *functionType,
       returnTypeId(0), fnType(functionType), fnTypeId(0),
       returnTypeId(0), fnType(functionType), fnTypeId(0),
       functionControl(control), functionLoc(loc), functionName(name) {}
       functionControl(control), functionLoc(loc), functionName(name) {}
 
 
-bool SpirvFunction::invokeVisitor(Visitor *visitor) {
+bool SpirvFunction::invokeVisitor(Visitor *visitor, bool reverseOrder) {
   if (!visitor->visit(this, Visitor::Phase::Init))
   if (!visitor->visit(this, Visitor::Phase::Init))
     return false;
     return false;
 
 
@@ -37,17 +37,25 @@ bool SpirvFunction::invokeVisitor(Visitor *visitor) {
     }).visit(basicBlocks.front());
     }).visit(basicBlocks.front());
   }
   }
 
 
+  if (reverseOrder)
+    std::reverse(orderedBlocks.begin(), orderedBlocks.end());
+
+  SpirvBasicBlock *firstBB =
+      orderedBlocks.empty()
+          ? nullptr
+          : reverseOrder ? orderedBlocks.back() : orderedBlocks[0];
+
   for (auto *bb : orderedBlocks) {
   for (auto *bb : orderedBlocks) {
     // The first basic block of the function should first visit the function
     // The first basic block of the function should first visit the function
     // variables.
     // variables.
-    if (bb == orderedBlocks[0]) {
-      if (!bb->invokeVisitor(visitor, variables))
+    if (bb == firstBB) {
+      if (!bb->invokeVisitor(visitor, variables, reverseOrder))
         return false;
         return false;
     }
     }
     // The rest of the basic blocks in the function do not need to visit
     // The rest of the basic blocks in the function do not need to visit
     // function variables.
     // function variables.
     else {
     else {
-      if (!bb->invokeVisitor(visitor))
+      if (!bb->invokeVisitor(visitor, {}, reverseOrder))
         return false;
         return false;
     }
     }
   }
   }

+ 8 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -726,6 +726,14 @@ SpirvUnaryOp::SpirvUnaryOp(spv::Op opcode, QualType resultType,
     : SpirvInstruction(IK_UnaryOp, opcode, resultType, resultId, loc),
     : SpirvInstruction(IK_UnaryOp, opcode, resultType, resultId, loc),
       operand(op) {}
       operand(op) {}
 
 
+bool SpirvUnaryOp::isConversionOp() const {
+  return opcode == spv::Op::OpConvertFToU || opcode == spv::Op::OpConvertFToS ||
+         opcode == spv::Op::OpConvertSToF || opcode == spv::Op::OpConvertUToF ||
+         opcode == spv::Op::OpUConvert || opcode == spv::Op::OpSConvert ||
+         opcode == spv::Op::OpFConvert || opcode == spv::Op::OpQuantizeToF16 ||
+         opcode == spv::Op::OpBitcast;
+}
+
 SpirvVectorShuffle::SpirvVectorShuffle(QualType resultType, uint32_t resultId,
 SpirvVectorShuffle::SpirvVectorShuffle(QualType resultType, uint32_t resultId,
                                        SourceLocation loc,
                                        SourceLocation loc,
                                        SpirvInstruction *vec1Inst,
                                        SpirvInstruction *vec1Inst,

+ 8 - 2
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -19,7 +19,13 @@ SpirvModule::SpirvModule()
       moduleProcesses({}), decorations({}), constants({}), variables({}),
       moduleProcesses({}), decorations({}), constants({}), variables({}),
       functions({}) {}
       functions({}) {}
 
 
-bool SpirvModule::invokeVisitor(Visitor *visitor) {
+bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
+  // Note: It is debatable whether reverse order of visiting the module should
+  // reverse everything in this method. For the time being, we just reverse the
+  // order of the function visitors, and keeping everything else the same.
+  // For example, it is not clear what the value would be of vising the last
+  // function first. We can update this methodology if needed.
+
   if (!visitor->visit(this, Visitor::Phase::Init))
   if (!visitor->visit(this, Visitor::Phase::Init))
     return false;
     return false;
 
 
@@ -66,7 +72,7 @@ bool SpirvModule::invokeVisitor(Visitor *visitor) {
       return false;
       return false;
 
 
   for (auto fn : functions)
   for (auto fn : functions)
-    if (!fn->invokeVisitor(visitor))
+    if (!fn->invokeVisitor(visitor, reverseOrder))
       return false;
       return false;
 
 
   if (!visitor->visit(this, Visitor::Phase::Done))
   if (!visitor->visit(this, Visitor::Phase::Done))

+ 1 - 1
tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl

@@ -1,6 +1,7 @@
 // Run: %dxc -T vs_6_0 -E main
 // Run: %dxc -T vs_6_0 -E main
 
 
 // CHECK: [[v4f32c:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
 // CHECK: [[v4f32c:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+// CHECK: [[v3f32c:%\d+]] = OpConstantComposite %v3float %float_2 %float_2 %float_2
 
 
 void main() {
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 // CHECK-LABEL: %bb_entry = OpLabel
@@ -9,7 +10,6 @@ void main() {
 // CHECK: OpStore %vf4 [[v4f32c]]
 // CHECK: OpStore %vf4 [[v4f32c]]
     float4 vf4 = 1;
     float4 vf4 = 1;
 
 
-// CHECK: [[v3f32c:%\d+]] = OpCompositeConstruct %v3float %float_2 %float_2 %float_2
 // CHECK-NEXT: OpStore %vf3 [[v3f32c]]
 // CHECK-NEXT: OpStore %vf3 [[v3f32c]]
     float3 vf3;
     float3 vf3;
     vf3 = float1(2);
     vf3 = float1(2);