InitListHandler.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. //===------- InitListHandler.cpp - Initializer List Handler -----*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements an initializer list handler that takes in an
  10. // InitListExpr and emits the corresponding SPIR-V instructions for it.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "InitListHandler.h"
  14. #include "llvm/ADT/SmallVector.h"
  15. namespace clang {
  16. namespace spirv {
  17. InitListHandler::InitListHandler(SPIRVEmitter &emitter)
  18. : theEmitter(emitter), theBuilder(emitter.getModuleBuilder()),
  19. typeTranslator(emitter.getTypeTranslator()),
  20. diags(emitter.getDiagnosticsEngine()) {}
  21. uint32_t InitListHandler::process(const InitListExpr *expr) {
  22. initializers.clear();
  23. scalars.clear();
  24. flatten(expr);
  25. const uint32_t init = createInitForType(expr->getType());
  26. /// We should have consumed all initializers and scalars extracted from them.
  27. assert(initializers.empty());
  28. assert(scalars.empty());
  29. return init;
  30. }
  31. void InitListHandler::flatten(const InitListExpr *expr) {
  32. const auto numInits = expr->getNumInits();
  33. for (uint32_t i = 0; i < numInits; ++i) {
  34. const Expr *init = expr->getInit(i);
  35. if (const auto *subInitList = dyn_cast<InitListExpr>(init)) {
  36. flatten(subInitList);
  37. } else if (const auto *subInitList = dyn_cast<InitListExpr>(
  38. // Ignore constructor casts which are no-ops
  39. // For cases like: <type>(<initializer-list>)
  40. init->IgnoreParenNoopCasts(theEmitter.getASTContext()))) {
  41. flatten(subInitList);
  42. } else {
  43. initializers.push_back(init);
  44. }
  45. }
  46. }
  47. void InitListHandler::decompose(const Expr *expr) {
  48. const QualType type = expr->getType();
  49. assert(!type->isBuiltinType()); // Cannot decompose builtin types
  50. if (hlsl::IsHLSLVecType(type)) {
  51. const uint32_t vec = theEmitter.loadIfGLValue(expr);
  52. const QualType elemType = hlsl::GetHLSLVecElementType(type);
  53. const auto size = hlsl::GetHLSLVecSize(type);
  54. if (size == 1) {
  55. // Decomposing of size-1 vector just results in the vector itself.
  56. scalars.emplace_back(vec, elemType);
  57. } else {
  58. const uint32_t elemTypeId = typeTranslator.translateType(elemType);
  59. for (uint32_t i = 0; i < size; ++i) {
  60. const uint32_t element =
  61. theBuilder.createCompositeExtract(elemTypeId, vec, {i});
  62. scalars.emplace_back(element, elemType);
  63. }
  64. }
  65. } else {
  66. emitError("decomposing type %0 in initializer list unimplemented") << type;
  67. }
  68. }
  69. uint32_t InitListHandler::createInitForType(QualType type) {
  70. type = type.getCanonicalType();
  71. if (type->isBuiltinType())
  72. return createInitForBuiltinType(type);
  73. if (hlsl::IsHLSLVecType(type))
  74. return createInitForVectorType(hlsl::GetHLSLVecElementType(type),
  75. hlsl::GetHLSLVecSize(type));
  76. if (hlsl::IsHLSLMatType(type)) {
  77. uint32_t rowCount = 0, colCount = 0;
  78. hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
  79. const QualType elemType = hlsl::GetHLSLMatElementType(type);
  80. return createInitForMatrixType(elemType, rowCount, colCount);
  81. }
  82. emitError("unimplemented initializer for type '%0'") << type;
  83. return 0;
  84. }
  85. uint32_t InitListHandler::createInitForBuiltinType(QualType type) {
  86. assert(type->isBuiltinType());
  87. if (!scalars.empty()) {
  88. const auto init = scalars.front();
  89. scalars.pop_front();
  90. return theEmitter.castToType(init.first, init.second, type);
  91. }
  92. const Expr *init = initializers.front();
  93. initializers.pop_front();
  94. if (!init->getType()->isBuiltinType()) {
  95. decompose(init);
  96. return createInitForBuiltinType(type);
  97. }
  98. const uint32_t value = theEmitter.loadIfGLValue(init);
  99. return theEmitter.castToType(value, init->getType(), type);
  100. }
  101. uint32_t InitListHandler::createInitForVectorType(QualType elemType,
  102. uint32_t count) {
  103. // If we don't have leftover scalars, we can try to see if there is a vector
  104. // of the same size in the original initializer list so that we can use it
  105. // directly. For all other cases, we need to construct a new vector as the
  106. // initializer.
  107. if (scalars.empty()) {
  108. const Expr *init = initializers.front();
  109. if (hlsl::IsHLSLVecType(init->getType()) &&
  110. hlsl::GetHLSLVecSize(init->getType()) == count) {
  111. initializers.pop_front();
  112. /// HLSL vector types are parameterized templates and we cannot
  113. /// construct them. So we construct an ExtVectorType here instead.
  114. /// This is unfortunate since it means we need to handle ExtVectorType
  115. /// in all type casting methods in SPIRVEmitter.
  116. const auto toVecType =
  117. theEmitter.getASTContext().getExtVectorType(elemType, count);
  118. return theEmitter.castToType(theEmitter.loadIfGLValue(init),
  119. init->getType(), toVecType);
  120. }
  121. }
  122. if (count == 1)
  123. return createInitForBuiltinType(elemType);
  124. llvm::SmallVector<uint32_t, 4> elements;
  125. for (uint32_t i = 0; i < count; ++i) {
  126. // All elements are scalars, which should already be casted to the correct
  127. // type if necessary.
  128. elements.push_back(createInitForBuiltinType(elemType));
  129. }
  130. const uint32_t elemTypeId = typeTranslator.translateType(elemType);
  131. const uint32_t vecType = theBuilder.getVecType(elemTypeId, count);
  132. // TODO: use OpConstantComposite when all components are constants
  133. return theBuilder.createCompositeConstruct(vecType, elements);
  134. }
  135. uint32_t InitListHandler::createInitForMatrixType(QualType elemType,
  136. uint32_t rowCount,
  137. uint32_t colCount) {
  138. // Same as the vector case, first try to see if we already have a matrix at
  139. // the beginning of the initializer queue.
  140. if (scalars.empty()) {
  141. const Expr *init = initializers.front();
  142. if (hlsl::IsHLSLMatType(init->getType())) {
  143. uint32_t initRowCount = 0, initColCount = 0;
  144. hlsl::GetHLSLMatRowColCount(init->getType(), initRowCount, initColCount);
  145. if (rowCount == initRowCount && colCount == initColCount) {
  146. initializers.pop_front();
  147. // TODO: We only support FP matrices now. Do type cast here after
  148. // adding more matrix types.
  149. return theEmitter.loadIfGLValue(init);
  150. }
  151. }
  152. }
  153. if (rowCount == 1)
  154. return createInitForVectorType(elemType, colCount);
  155. if (colCount == 1)
  156. return createInitForVectorType(elemType, rowCount);
  157. llvm::SmallVector<uint32_t, 4> vectors;
  158. for (uint32_t i = 0; i < rowCount; ++i) {
  159. // All elements are vectors, which should already be casted to the correct
  160. // type if necessary.
  161. vectors.push_back(createInitForVectorType(elemType, colCount));
  162. }
  163. const uint32_t elemTypeId = typeTranslator.translateType(elemType);
  164. const uint32_t vecType = theBuilder.getVecType(elemTypeId, colCount);
  165. const uint32_t matType = theBuilder.getMatType(vecType, rowCount);
  166. // TODO: use OpConstantComposite when all components are constants
  167. return theBuilder.createCompositeConstruct(matType, vectors);
  168. }
  169. } // end namespace spirv
  170. } // end namespace clang