瀏覽代碼

[spirv] Simplify the logic in InitListHandler::decompose. (#2090)

Ehsan 6 年之前
父節點
當前提交
456d770192
共有 2 個文件被更改,包括 24 次插入37 次删除
  1. 24 36
      tools/clang/lib/SPIRV/InitListHandler.cpp
  2. 0 1
      tools/clang/lib/SPIRV/InitListHandler.h

+ 24 - 36
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -84,46 +84,34 @@ void InitListHandler::flatten(const InitListExpr *expr) {
 void InitListHandler::decompose(SpirvInstruction *inst) {
 void InitListHandler::decompose(SpirvInstruction *inst) {
   const QualType type = inst->getAstResultType();
   const QualType type = inst->getAstResultType();
 
 
-  QualType elemType;
-  uint32_t elemCount;
-  if (isVectorType(type, &elemType, &elemCount)) {
-    decomposeVector(inst, elemType, elemCount);
-  } else if (hlsl::IsHLSLMatType(type)) {
-    elemType = hlsl::GetHLSLMatElementType(type);
-
-    uint32_t rowCount = 0, colCount = 0;
-    hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
-
-    if (rowCount == 1 || colCount == 1) {
-      // This also handles the scalar case
-      decomposeVector(inst, elemType, rowCount == 1 ? colCount : rowCount);
-    } else {
-      for (uint32_t i = 0; i < rowCount; ++i)
-        for (uint32_t j = 0; j < colCount; ++j) {
-          auto *element =
-              spvBuilder.createCompositeExtract(elemType, inst, {i, j});
-          scalars.emplace_back(element, elemType);
-        }
-    }
-  } else if (isScalarType(type, &elemType)) {
+  QualType elemType = {};
+  uint32_t elemCount = 0, rowCount = 0, colCount = 0;
+
+  // Scalar cases, including vec1 and mat1x1.
+  if (isScalarType(type, &elemType)) {
     scalars.emplace_back(inst, elemType);
     scalars.emplace_back(inst, elemType);
-  } else {
-    llvm_unreachable(
-        "decompose() should only handle scalar or vector or matrix types");
   }
   }
-}
-
-void InitListHandler::decomposeVector(SpirvInstruction *vec, QualType elemType,
-                                      uint32_t size) {
-  if (size == 1) {
-    // Decomposing of size-1 vector just results in the vector itself.
-    scalars.emplace_back(vec, elemType);
-  } else {
-    for (uint32_t i = 0; i < size; ++i) {
-      auto *element = spvBuilder.createCompositeExtract(elemType, vec, {i});
+  // Vector cases, including mat1xN and matNx1 where N > 1.
+  else if (isVectorType(type, &elemType, &elemCount)) {
+    for (uint32_t i = 0; i < elemCount; ++i) {
+      auto *element = spvBuilder.createCompositeExtract(elemType, inst, {i});
       scalars.emplace_back(element, elemType);
       scalars.emplace_back(element, elemType);
     }
     }
   }
   }
+  // MxN matrix cases, where M > 1 and N > 1.
+  else if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
+    for (uint32_t i = 0; i < rowCount; ++i)
+      for (uint32_t j = 0; j < colCount; ++j) {
+        auto *element =
+            spvBuilder.createCompositeExtract(elemType, inst, {i, j});
+        scalars.emplace_back(element, elemType);
+      }
+  }
+  // The decompose method only supports scalar, vector, and matrix types.
+  else {
+    llvm_unreachable(
+        "decompose() should only handle scalar or vector or matrix types");
+  }
 }
 }
 
 
 bool InitListHandler::tryToSplitStruct() {
 bool InitListHandler::tryToSplitStruct() {
@@ -199,7 +187,7 @@ SpirvInstruction *InitListHandler::createInitForType(QualType type,
   if (type->isBuiltinType())
   if (type->isBuiltinType())
     return createInitForBuiltinType(type, srcLoc);
     return createInitForBuiltinType(type, srcLoc);
 
 
-  QualType elemType;
+  QualType elemType = {};
   uint32_t elemCount = 0;
   uint32_t elemCount = 0;
   if (isVectorType(type, &elemType, &elemCount))
   if (isVectorType(type, &elemType, &elemCount))
     return createInitForVectorType(elemType, elemCount, srcLoc);
     return createInitForVectorType(elemType, elemCount, srcLoc);

+ 0 - 1
tools/clang/lib/SPIRV/InitListHandler.h

@@ -112,7 +112,6 @@ private:
   /// Decomposes the given SpirvInstruction and puts all elements into the end
   /// Decomposes the given SpirvInstruction and puts all elements into the end
   /// of the scalars queue.
   /// of the scalars queue.
   void decompose(SpirvInstruction *inst);
   void decompose(SpirvInstruction *inst);
-  void decomposeVector(SpirvInstruction *vec, QualType elemType, uint32_t size);
 
 
   /// If the next initializer is a struct, replaces it with OpCompositeExtract
   /// If the next initializer is a struct, replaces it with OpCompositeExtract
   /// its members and returns true. Otherwise, does nothing and return false.
   /// its members and returns true. Otherwise, does nothing and return false.