Browse Source

[spirv] Create utility function to de-duplicate code (#876)

Lei Zhang 7 years ago
parent
commit
3076ec85de
1 changed files with 32 additions and 45 deletions
  1. 32 45
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp

+ 32 - 45
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -136,16 +136,10 @@ bool isFloatOrVecMatOfFloatType(QualType type) {
           hlsl::GetHLSLMatElementType(type)->isFloatingType());
 }
 
-bool isSpirvMatrixOp(spv::Op opcode) {
-  switch (opcode) {
-  case spv::Op::OpMatrixTimesMatrix:
-  case spv::Op::OpMatrixTimesVector:
-  case spv::Op::OpMatrixTimesScalar:
-    return true;
-  default:
-    break;
-  }
-  return false;
+inline bool isSpirvMatrixOp(spv::Op opcode) {
+  return opcode == spv::Op::OpMatrixTimesMatrix ||
+         opcode == spv::Op::OpMatrixTimesVector ||
+         opcode == spv::Op::OpMatrixTimesScalar;
 }
 
 /// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes
@@ -275,6 +269,19 @@ inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) {
   return false;
 }
 
+/// Creates an access chain to index into the given SPIR-V evaluation result
+/// and overwrites and returns the new SPIR-V evaluation result.
+inline SpirvEvalInfo &
+turnIntoElementPtr(SpirvEvalInfo &info, QualType elemType,
+                   const llvm::SmallVector<uint32_t, 4> &indices,
+                   ModuleBuilder &builder, TypeTranslator &translator) {
+  assert(!info.isRValue());
+  const uint32_t ptrType = builder.getPointerType(
+      translator.translateType(elemType, info.getLayoutRule()),
+      info.getStorageClass());
+  return info.setResultId(builder.createAccessChain(ptrType, info, indices));
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -1353,15 +1360,11 @@ void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
 SpirvEvalInfo
 SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
-  const auto *base = collectArrayStructIndices(expr, &indices);
-  auto info = doExpr(base);
+  auto info = doExpr(collectArrayStructIndices(expr, &indices));
 
   if (!indices.empty()) {
-    assert(!info.isRValue());
-    const uint32_t ptrType = theBuilder.getPointerType(
-        typeTranslator.translateType(expr->getType(), info.getLayoutRule()),
-        info.getStorageClass());
-    info.setResultId(theBuilder.createAccessChain(ptrType, info, indices));
+    (void)turnIntoElementPtr(info, expr->getType(), indices, theBuilder,
+                             typeTranslator);
   }
 
   return info;
@@ -2481,15 +2484,12 @@ SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
 
   const QualType structType =
       hlsl::GetHLSLResourceResultType(buffer->getType());
-  const uint32_t ptrType = theBuilder.getPointerType(
-      typeTranslator.translateType(structType, info.getLayoutRule()),
-      info.getStorageClass());
 
   const uint32_t zero = theBuilder.getConstantInt32(0);
   const uint32_t index = doExpr(expr->getArg(0));
 
-  info.setResultId(theBuilder.createAccessChain(ptrType, info, {zero, index}));
-  return info;
+  return turnIntoElementPtr(info, structType, {zero, index}, theBuilder,
+                            typeTranslator);
 }
 
 uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
@@ -2539,21 +2539,16 @@ SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   auto bufferInfo = declIdMapper.getDeclResultId(buffer);
 
   const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
-  const uint32_t bufferElemType =
-      typeTranslator.translateType(bufferElemTy, bufferInfo.getLayoutRule());
-  // Get the pointer inside the {Append|Consume}StructuredBuffer
-  const uint32_t bufferElemPtrType =
-      theBuilder.getPointerType(bufferElemType, bufferInfo.getStorageClass());
-  const uint32_t bufferElemPtr = theBuilder.createAccessChain(
-      bufferElemPtrType, bufferInfo, {zero, index});
+
+  (void)turnIntoElementPtr(bufferInfo, bufferElemTy, {zero, index}, theBuilder,
+                           typeTranslator);
 
   if (isAppend) {
     // Write out the value
-    bufferInfo.setResultId(bufferElemPtr);
     storeValue(bufferInfo, doExpr(expr->getArg(0)), bufferElemTy);
     return 0;
   } else {
-    return bufferInfo.setResultId(bufferElemPtr);
+    return bufferInfo;
   }
 }
 
@@ -2617,7 +2612,7 @@ void SPIRVEmitter::handleOptionalOffsetInMethodCall(
     else
       *varOffset = doExpr(expr->getArg(index));
   }
-};
+}
 
 SpirvEvalInfo
 SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
@@ -2999,12 +2994,8 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
     base = tempVar;
   }
 
-  const uint32_t ptrType = theBuilder.getPointerType(
-      typeTranslator.translateType(expr->getType(), base.getLayoutRule()),
-      base.getStorageClass());
-  base.setResultId(theBuilder.createAccessChain(ptrType, base, indices));
-
-  return base;
+  return turnIntoElementPtr(base, expr->getType(), indices, theBuilder,
+                            typeTranslator);
 }
 
 SpirvEvalInfo
@@ -3157,15 +3148,11 @@ SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
 
 SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
-
-  const Expr *base = collectArrayStructIndices(expr, &indices);
-  auto info = doExpr(base);
+  auto info = doExpr(collectArrayStructIndices(expr, &indices));
 
   if (!indices.empty()) {
-    const uint32_t ptrType = theBuilder.getPointerType(
-        typeTranslator.translateType(expr->getType(), info.getLayoutRule()),
-        info.getStorageClass());
-    info.setResultId(theBuilder.createAccessChain(ptrType, info, indices));
+    (void)turnIntoElementPtr(info, expr->getType(), indices, theBuilder,
+                             typeTranslator);
   }
 
   return info;