Jelajahi Sumber

[spirv] Further updates to SPIRVEmitter for v2.

Ehsan 6 tahun lalu
induk
melakukan
7ebfcb09ad
2 mengubah file dengan 88 tambahan dan 93 penghapusan
  1. 83 89
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  2. 5 4
      tools/clang/lib/SPIRV/SPIRVEmitter.h

+ 83 - 89
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -8512,81 +8512,80 @@ SPIRVEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
   return castToInt(floatSign, arg->getType(), returnType, arg->getExprLoc());
   return castToInt(floatSign, arg->getType(), returnType, arg->getExprLoc());
 }
 }
 
 
-uint32_t SPIRVEmitter::processIntrinsicF16ToF32(const CallExpr *callExpr) {
+SpirvInstruction *
+SPIRVEmitter::processIntrinsicF16ToF32(const CallExpr *callExpr) {
   // f16tof32() takes in (vector of) uint and returns (vector of) float.
   // f16tof32() takes in (vector of) uint and returns (vector of) float.
   // The frontend should guarantee that by inserting implicit casts.
   // The frontend should guarantee that by inserting implicit casts.
-  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
-  const uint32_t f32TypeId = theBuilder.getFloat32Type();
-  const uint32_t u32TypeId = theBuilder.getUint32Type();
-  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);
+  auto *glsl = spvBuilder.getGLSLExtInstSet();
+  const QualType f32Type = astContext.FloatTy;
+  const QualType u32Type = astContext.UnsignedIntTy;
+  const QualType v2f32Type = astContext.getExtVectorType(f32Type, 2);
 
 
   const auto *arg = callExpr->getArg(0);
   const auto *arg = callExpr->getArg(0);
-  const uint32_t argId = doExpr(arg);
+  auto *argId = doExpr(arg);
 
 
   uint32_t elemCount = {};
   uint32_t elemCount = {};
 
 
   if (isVectorType(arg->getType(), nullptr, &elemCount)) {
   if (isVectorType(arg->getType(), nullptr, &elemCount)) {
     // The input is a vector. We need to handle each element separately.
     // The input is a vector. We need to handle each element separately.
-    llvm::SmallVector<uint32_t, 4> elements;
+    llvm::SmallVector<SpirvInstruction *, 4> elements;
 
 
     for (uint32_t i = 0; i < elemCount; ++i) {
     for (uint32_t i = 0; i < elemCount; ++i) {
-      const uint32_t srcElem =
-          theBuilder.createCompositeExtract(u32TypeId, argId, {i});
-      const uint32_t convert = theBuilder.createExtInst(
-          v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, srcElem);
+      auto *srcElem = spvBuilder.createCompositeExtract(u32Type, argId, {i});
+      auto *convert = spvBuilder.createExtInst(
+          v2f32Type, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, srcElem);
       elements.push_back(
       elements.push_back(
-          theBuilder.createCompositeExtract(f32TypeId, convert, {0}));
+          spvBuilder.createCompositeExtract(f32Type, convert, {0}));
     }
     }
-    return theBuilder.createCompositeConstruct(
-        theBuilder.getVecType(f32TypeId, elemCount), elements);
+    return spvBuilder.createCompositeConstruct(
+        astContext.getExtVectorType(f32Type, elemCount), elements);
   }
   }
 
 
-  const uint32_t convert = theBuilder.createExtInst(
-      v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, argId);
+  auto *convert = spvBuilder.createExtInst(
+      v2f32Type, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, argId);
   // f16tof32() converts the float16 stored in the low-half of the uint to
   // f16tof32() converts the float16 stored in the low-half of the uint to
   // a float. So just need to return the first component.
   // a float. So just need to return the first component.
-  return theBuilder.createCompositeExtract(f32TypeId, convert, {0});
+  return spvBuilder.createCompositeExtract(f32Type, convert, {0});
 }
 }
 
 
-uint32_t SPIRVEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
+SpirvInstruction *
+SPIRVEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
   // f32tof16() takes in (vector of) float and returns (vector of) uint.
   // f32tof16() takes in (vector of) float and returns (vector of) uint.
   // The frontend should guarantee that by inserting implicit casts.
   // The frontend should guarantee that by inserting implicit casts.
-  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
-  const uint32_t f32TypeId = theBuilder.getFloat32Type();
-  const uint32_t u32TypeId = theBuilder.getUint32Type();
-  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);
-  const uint32_t zero = theBuilder.getConstantFloat32(0);
+  auto *glsl = spvBuilder.getGLSLExtInstSet();
+  const QualType f32Type = astContext.FloatTy;
+  const QualType u32Type = astContext.UnsignedIntTy;
+  const QualType v2f32Type = astContext.getExtVectorType(f32Type, 2);
+  auto *zero = spvContext.getConstantFloat32(0);
 
 
   const auto *arg = callExpr->getArg(0);
   const auto *arg = callExpr->getArg(0);
-  const uint32_t argId = doExpr(arg);
+  auto *argId = doExpr(arg);
   uint32_t elemCount = {};
   uint32_t elemCount = {};
 
 
   if (isVectorType(arg->getType(), nullptr, &elemCount)) {
   if (isVectorType(arg->getType(), nullptr, &elemCount)) {
     // The input is a vector. We need to handle each element separately.
     // The input is a vector. We need to handle each element separately.
-    llvm::SmallVector<uint32_t, 4> elements;
+    llvm::SmallVector<SpirvInstruction *, 4> elements;
 
 
     for (uint32_t i = 0; i < elemCount; ++i) {
     for (uint32_t i = 0; i < elemCount; ++i) {
-      const uint32_t srcElem =
-          theBuilder.createCompositeExtract(f32TypeId, argId, {i});
-      const uint32_t srcVec =
-          theBuilder.createCompositeConstruct(v2f32TypeId, {srcElem, zero});
+      auto *srcElem = spvBuilder.createCompositeExtract(f32Type, argId, {i});
+      auto *srcVec =
+          spvBuilder.createCompositeConstruct(v2f32Type, {srcElem, zero});
 
 
-      elements.push_back(theBuilder.createExtInst(
-          u32TypeId, glsl, GLSLstd450::GLSLstd450PackHalf2x16, srcVec));
+      elements.push_back(spvBuilder.createExtInst(
+          u32Type, glsl, GLSLstd450::GLSLstd450PackHalf2x16, srcVec));
     }
     }
-    return theBuilder.createCompositeConstruct(
-        theBuilder.getVecType(u32TypeId, elemCount), elements);
+    return spvBuilder.createCompositeConstruct(
+        astContext.getExtVectorType(u32Type, elemCount), elements);
   }
   }
 
 
   // f16tof32() stores the float into the low-half of the uint. So we need
   // f16tof32() stores the float into the low-half of the uint. So we need
   // to supply another zero to take the other half.
   // to supply another zero to take the other half.
-  const uint32_t srcVec =
-      theBuilder.createCompositeConstruct(v2f32TypeId, {argId, zero});
-  return theBuilder.createExtInst(u32TypeId, glsl,
+  auto *srcVec = spvBuilder.createCompositeConstruct(v2f32Type, {argId, zero});
+  return spvBuilder.createExtInst(u32Type, glsl,
                                   GLSLstd450::GLSLstd450PackHalf2x16, srcVec);
                                   GLSLstd450::GLSLstd450PackHalf2x16, srcVec);
 }
 }
 
 
-uint32_t SPIRVEmitter::processIntrinsicUsingSpirvInst(
+SpirvInstruction *SPIRVEmitter::processIntrinsicUsingSpirvInst(
     const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
     const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
   // Certain opcodes are only allowed in pixel shader
   // Certain opcodes are only allowed in pixel shader
   if (!shaderModel.IsPS())
   if (!shaderModel.IsPS())
@@ -8608,44 +8607,44 @@ uint32_t SPIRVEmitter::processIntrinsicUsingSpirvInst(
       break;
       break;
     }
     }
 
 
-  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
+  const QualType returnType = callExpr->getType();
   if (callExpr->getNumArgs() == 1u) {
   if (callExpr->getNumArgs() == 1u) {
     const Expr *arg = callExpr->getArg(0);
     const Expr *arg = callExpr->getArg(0);
-    const uint32_t argId = doExpr(arg);
+    auto *argId = doExpr(arg);
 
 
     // If the instruction does not operate on matrices, we can perform the
     // If the instruction does not operate on matrices, we can perform the
     // instruction on each vector of the matrix.
     // instruction on each vector of the matrix.
     if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
     if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
       const auto actOnEachVec = [this, opcode](uint32_t /*index*/,
       const auto actOnEachVec = [this, opcode](uint32_t /*index*/,
-                                               uint32_t vecType,
-                                               uint32_t curRowId) {
-        return theBuilder.createUnaryOp(opcode, vecType, curRowId);
+                                               QualType vecType,
+                                               SpirvInstruction *curRow) {
+        return spvBuilder.createUnaryOp(opcode, vecType, curRow);
       };
       };
       return processEachVectorInMatrix(arg, argId, actOnEachVec);
       return processEachVectorInMatrix(arg, argId, actOnEachVec);
     }
     }
-    return theBuilder.createUnaryOp(opcode, returnType, argId);
+    return spvBuilder.createUnaryOp(opcode, returnType, argId);
   } else if (callExpr->getNumArgs() == 2u) {
   } else if (callExpr->getNumArgs() == 2u) {
     const Expr *arg0 = callExpr->getArg(0);
     const Expr *arg0 = callExpr->getArg(0);
-    const uint32_t arg0Id = doExpr(arg0);
-    const uint32_t arg1Id = doExpr(callExpr->getArg(1));
+    auto *arg0Id = doExpr(arg0);
+    auto *arg1Id = doExpr(callExpr->getArg(1));
     // If the instruction does not operate on matrices, we can perform the
     // If the instruction does not operate on matrices, we can perform the
     // instruction on each vector of the matrix.
     // instruction on each vector of the matrix.
     if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
     if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
-      const auto actOnEachVec = [this, opcode, arg1Id](uint32_t index,
-                                                       uint32_t vecType,
-                                                       uint32_t arg0RowId) {
-        const uint32_t arg1RowId =
-            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
-        return theBuilder.createBinaryOp(opcode, vecType, arg0RowId, arg1RowId);
+      const auto actOnEachVec = [this, opcode,
+                                 arg1Id](uint32_t index, QualType vecType,
+                                         SpirvInstruction *arg0Row) {
+        auto *arg1Row =
+            spvBuilder.createCompositeExtract(vecType, arg1Id, {index});
+        return spvBuilder.createBinaryOp(opcode, vecType, arg0Row, arg1Row);
       };
       };
       return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
       return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
     }
     }
-    return theBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id);
+    return spvBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id);
   }
   }
 
 
   emitError("unsupported %0 intrinsic function", callExpr->getExprLoc())
   emitError("unsupported %0 intrinsic function", callExpr->getExprLoc())
       << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
       << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
-  return 0;
+  return nullptr;
 }
 }
 
 
 SpirvInstruction *SPIRVEmitter::processIntrinsicUsingGLSLInst(
 SpirvInstruction *SPIRVEmitter::processIntrinsicUsingGLSLInst(
@@ -8719,21 +8718,21 @@ SpirvInstruction *SPIRVEmitter::processIntrinsicUsingGLSLInst(
   return nullptr;
   return nullptr;
 }
 }
 
 
-uint32_t SPIRVEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
+SpirvInstruction *
+SPIRVEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
   // Since there is no log10 instruction in SPIR-V, we can use:
   // Since there is no log10 instruction in SPIR-V, we can use:
   // log10(x) = log2(x) * ( 1 / log2(10) )
   // log10(x) = log2(x) * ( 1 / log2(10) )
   // 1 / log2(10) = 0.30103
   // 1 / log2(10) = 0.30103
-  const auto scale = theBuilder.getConstantFloat32(0.30103f);
-  const auto log2 =
+  auto *scale = spvContext.getConstantFloat32(0.30103f);
+  auto *log2 =
       processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450Log2, true);
       processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450Log2, true);
   const auto returnType = callExpr->getType();
   const auto returnType = callExpr->getType();
-  const auto returnTypeId = typeTranslator.translateType(returnType);
   spv::Op scaleOp = isScalarType(returnType)
   spv::Op scaleOp = isScalarType(returnType)
                         ? spv::Op::OpFMul
                         ? spv::Op::OpFMul
                         : isVectorType(returnType)
                         : isVectorType(returnType)
                               ? spv::Op::OpVectorTimesScalar
                               ? spv::Op::OpVectorTimesScalar
                               : spv::Op::OpMatrixTimesScalar;
                               : spv::Op::OpMatrixTimesScalar;
-  return theBuilder.createBinaryOp(scaleOp, returnTypeId, log2, scale);
+  return spvBuilder.createBinaryOp(scaleOp, returnType, log2, scale);
 }
 }
 
 
 SpirvConstant *SPIRVEmitter::getValueZero(QualType type) {
 SpirvConstant *SPIRVEmitter::getValueZero(QualType type) {
@@ -8874,19 +8873,15 @@ SpirvConstant *SPIRVEmitter::getMaskForBitwidthValue(QualType type) {
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType, &count)) {
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType, &count)) {
     const auto bitwidth = typeTranslator.getElementSpirvBitwidth(elemType);
     const auto bitwidth = typeTranslator.getElementSpirvBitwidth(elemType);
     SpirvConstant *mask = nullptr;
     SpirvConstant *mask = nullptr;
-    const SpirvType *elemType nullptr;
     switch (bitwidth) {
     switch (bitwidth) {
     case 16:
     case 16:
       mask = spvContext.getConstantUint16(bitwidth - 1);
       mask = spvContext.getConstantUint16(bitwidth - 1);
-      elemType = spvContext.getUIntType(16);
       break;
       break;
     case 32:
     case 32:
       mask = spvContext.getConstantUint32(bitwidth - 1);
       mask = spvContext.getConstantUint32(bitwidth - 1);
-      elemType = spvContext.getUIntType(32);
       break;
       break;
     case 64:
     case 64:
       mask = spvContext.getConstantUint64(bitwidth - 1);
       mask = spvContext.getConstantUint64(bitwidth - 1);
-      elemType = spvContext.getUIntType(64);
       break;
       break;
     default:
     default:
       assert(false && "this method only supports 16-, 32-, and 64-bit types");
       assert(false && "this method only supports 16-, 32-, and 64-bit types");
@@ -8895,9 +8890,9 @@ SpirvConstant *SPIRVEmitter::getMaskForBitwidthValue(QualType type) {
     if (count == 1)
     if (count == 1)
       return mask;
       return mask;
 
 
-    const SpirvType *resultType = spvContext.getVectorType(elemType, count);
+    const QualType resultType = astContext.getExtVectorType(elemType, count);
     llvm::SmallVector<SpirvConstant *, 4> elements(size_t(count), mask);
     llvm::SmallVector<SpirvConstant *, 4> elements(size_t(count), mask);
-    return spvBuilder.getConstantComposite(resultType, elements);
+    return spvContext.getConstantComposite(resultType, elements);
   }
   }
 
 
   assert(false && "this method only supports scalars and vectors");
   assert(false && "this method only supports scalars and vectors");
@@ -8920,7 +8915,6 @@ SpirvConstant *SPIRVEmitter::translateAPValue(const APValue &value,
   } else if (targetType->isFloatingType()) {
   } else if (targetType->isFloatingType()) {
     result = translateAPFloat(value.getFloat(), targetType);
     result = translateAPFloat(value.getFloat(), targetType);
   } else if (hlsl::IsHLSLVecType(targetType)) {
   } else if (hlsl::IsHLSLVecType(targetType)) {
-    const uint32_t vecType = typeTranslator.translateType(targetType);
     const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
     const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
     const auto numElements = value.getVectorLength();
     const auto numElements = value.getVectorLength();
     // Special case for vectors of size 1. SPIR-V doesn't support this vector
     // Special case for vectors of size 1. SPIR-V doesn't support this vector
@@ -8932,7 +8926,7 @@ SpirvConstant *SPIRVEmitter::translateAPValue(const APValue &value,
       for (uint32_t i = 0; i < numElements; ++i) {
       for (uint32_t i = 0; i < numElements; ++i) {
         elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
         elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
       }
       }
-      result = spvContext.getConstantComposite(vecType, elements);
+      result = spvContext.getConstantComposite(targetType, elements);
     }
     }
   }
   }
 
 
@@ -8976,7 +8970,7 @@ SpirvConstant *SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                   "inforamtion",
                   "inforamtion",
                   {})
                   {})
             << std::to_string(intValue.getSExtValue());
             << std::to_string(intValue.getSExtValue());
-        return 0;
+        return nullptr;
       }
       }
       return spvContext.getConstantInt32(
       return spvContext.getConstantInt32(
           static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
           static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
@@ -8986,7 +8980,7 @@ SpirvConstant *SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                   "inforamtion",
                   "inforamtion",
                   {})
                   {})
             << std::to_string(intValue.getZExtValue());
             << std::to_string(intValue.getZExtValue());
-        return 0;
+        return nullptr;
       }
       }
       return spvContext.getConstantUint32(
       return spvContext.getConstantUint32(
           static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
           static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
@@ -9003,7 +8997,7 @@ SpirvConstant *SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
   emitError("APInt for target bitwidth %0 unimplemented", {})
   emitError("APInt for target bitwidth %0 unimplemented", {})
       << astContext.getIntWidth(targetType);
       << astContext.getIntWidth(targetType);
 
 
-  return 0;
+  return nullptr;
 }
 }
 
 
 bool SPIRVEmitter::isLiteralLargerThan32Bits(const Expr *expr) {
 bool SPIRVEmitter::isLiteralLargerThan32Bits(const Expr *expr) {
@@ -9106,7 +9100,7 @@ SpirvConstant *SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
           << std::to_string(valueBitwidth == 32
           << std::to_string(valueBitwidth == 32
                                 ? originalValue.convertToFloat()
                                 ? originalValue.convertToFloat()
                                 : originalValue.convertToDouble());
                                 : originalValue.convertToDouble());
-      return 0;
+      return nullptr;
     }
     }
   }
   }
 
 
@@ -9124,7 +9118,7 @@ SpirvConstant *SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
   }
   }
   emitError("APFloat for target bitwidth %0 unimplemented", {})
   emitError("APFloat for target bitwidth %0 unimplemented", {})
       << targetBitwidth;
       << targetBitwidth;
-  return 0;
+  return nullptr;
 }
 }
 
 
 SpirvConstant *SPIRVEmitter::tryToEvaluateAsConst(const Expr *expr) {
 SpirvConstant *SPIRVEmitter::tryToEvaluateAsConst(const Expr *expr) {
@@ -9769,8 +9763,8 @@ bool SPIRVEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
 }
 }
 
 
 void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
 void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
-    const Stmt *root, uint32_t *defaultBB,
-    std::vector<std::pair<uint32_t, uint32_t>> *targets) {
+    const Stmt *root, SpirvBasicBlock **defaultBB,
+    std::vector<std::pair<uint32_t, SpirvBasicBlock *>> *targets) {
   if (!root)
   if (!root)
     return;
     return;
 
 
@@ -9812,8 +9806,8 @@ void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
     caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
     caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
                 llvm::itostr(std::abs(value));
                 llvm::itostr(std::abs(value));
   }
   }
-  const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
-  theBuilder.addSuccessor(caseBB);
+  auto *caseBB = spvBuilder.createBasicBlock(caseLabel);
+  spvBuilder.addSuccessor(caseBB);
   stmtBasicBlock[root] = caseBB;
   stmtBasicBlock[root] = caseBB;
 
 
   // Add all cases to the 'targets' vector.
   // Add all cases to the 'targets' vector.
@@ -9860,15 +9854,15 @@ void SPIRVEmitter::processCaseStmtOrDefaultStmt(const Stmt *stmt) {
   auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
   auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
   assert(caseStmt || defaultStmt);
   assert(caseStmt || defaultStmt);
 
 
-  uint32_t caseBB = stmtBasicBlock[stmt];
-  if (!theBuilder.isCurrentBasicBlockTerminated()) {
+  auto *caseBB = stmtBasicBlock[stmt];
+  if (!spvBuilder.isCurrentBasicBlockTerminated()) {
     // We are about to handle the case passed in as parameter. If the current
     // We are about to handle the case passed in as parameter. If the current
     // basic block is not terminated, it means the previous case is a fall
     // basic block is not terminated, it means the previous case is a fall
     // through case. We need to link it to the case to be processed.
     // through case. We need to link it to the case to be processed.
-    theBuilder.createBranch(caseBB);
-    theBuilder.addSuccessor(caseBB);
+    spvBuilder.createBranch(caseBB);
+    spvBuilder.addSuccessor(caseBB);
   }
   }
-  theBuilder.setInsertPoint(caseBB);
+  spvBuilder.setInsertPoint(caseBB);
   doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
   doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
 }
 }
 
 
@@ -9880,30 +9874,30 @@ void SPIRVEmitter::processSwitchStmtUsingSpirvOpSwitch(
   if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
   if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
     doDeclStmt(condVarDeclStmt);
     doDeclStmt(condVarDeclStmt);
 
 
-  const uint32_t selector = doExpr(switchStmt->getCond());
+  auto *selector = doExpr(switchStmt->getCond());
 
 
   // We need a merge block regardless of the number of switch cases.
   // We need a merge block regardless of the number of switch cases.
   // Since OpSwitch always requires a default label, if the switch statement
   // Since OpSwitch always requires a default label, if the switch statement
   // does not have a default branch, we use the merge block as the default
   // does not have a default branch, we use the merge block as the default
   // target.
   // target.
-  const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
-  theBuilder.setMergeTarget(mergeBB);
+  auto *mergeBB = spvBuilder.createBasicBlock("switch.merge");
+  spvBuilder.setMergeTarget(mergeBB);
   breakStack.push(mergeBB);
   breakStack.push(mergeBB);
-  uint32_t defaultBB = mergeBB;
+  auto *defaultBB = mergeBB;
 
 
   // (literal, labelId) pairs to pass to the OpSwitch instruction.
   // (literal, labelId) pairs to pass to the OpSwitch instruction.
-  std::vector<std::pair<uint32_t, uint32_t>> targets;
+  std::vector<std::pair<uint32_t, SpirvBasicBlock *>> targets;
   discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);
   discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);
 
 
   // Create the OpSelectionMerge and OpSwitch.
   // Create the OpSelectionMerge and OpSwitch.
-  theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);
+  spvBuilder.createSwitch(mergeBB, selector, defaultBB, targets);
 
 
   // Handle the switch body.
   // Handle the switch body.
   doStmt(switchStmt->getBody());
   doStmt(switchStmt->getBody());
 
 
-  if (!theBuilder.isCurrentBasicBlockTerminated())
-    theBuilder.createBranch(mergeBB);
-  theBuilder.setInsertPoint(mergeBB);
+  if (!spvBuilder.isCurrentBasicBlockTerminated())
+    spvBuilder.createBranch(mergeBB);
+  spvBuilder.setInsertPoint(mergeBB);
   breakStack.pop();
   breakStack.pop();
 }
 }
 
 

+ 5 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -680,8 +680,8 @@ private:
   /// label for the default basic block through the defaultBB parameter. This
   /// label for the default basic block through the defaultBB parameter. This
   /// method panics if it finds a case value that is not an integer literal.
   /// method panics if it finds a case value that is not an integer literal.
   void discoverAllCaseStmtInSwitchStmt(
   void discoverAllCaseStmtInSwitchStmt(
-      const Stmt *root, uint32_t *defaultBB,
-      std::vector<std::pair<uint32_t, uint32_t>> *targets);
+      const Stmt *root, SpirvBasicBlock **defaultBB,
+      std::vector<std::pair<uint32_t, SpirvBasicBlock *>> *targets);
 
 
   /// Flattens structured AST of the given switch statement into a vector of AST
   /// Flattens structured AST of the given switch statement into a vector of AST
   /// nodes and stores into flatSwitch.
   /// nodes and stores into flatSwitch.
@@ -786,8 +786,9 @@ private:
 
 
   /// \brief Handles .Gather{|Cmp}{Red|Green|Blue|Alpha}() calls on texture
   /// \brief Handles .Gather{|Cmp}{Red|Green|Blue|Alpha}() calls on texture
   /// types.
   /// types.
-  SpirvInstruction *processTextureGatherRGBACmpRGBA(const CXXMemberCallExpr *expr,
-                                           bool isCmp, uint32_t component);
+  SpirvInstruction *
+  processTextureGatherRGBACmpRGBA(const CXXMemberCallExpr *expr, bool isCmp,
+                                  uint32_t component);
 
 
   /// \brief Handles .GatherCmp() calls on texture types.
   /// \brief Handles .GatherCmp() calls on texture types.
   SpirvInstruction *processTextureGatherCmp(const CXXMemberCallExpr *expr);
   SpirvInstruction *processTextureGatherCmp(const CXXMemberCallExpr *expr);