inst_debug_printf_pass.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. // Copyright (c) 2020 The Khronos Group Inc.
  2. // Copyright (c) 2020 Valve Corporation
  3. // Copyright (c) 2020 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "inst_debug_printf_pass.h"
  17. #include "source/util/string_utils.h"
  18. #include "spirv/unified1/NonSemanticDebugPrintf.h"
  19. namespace spvtools {
  20. namespace opt {
  21. void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst,
  22. std::vector<uint32_t>* val_ids,
  23. InstructionBuilder* builder) {
  24. uint32_t val_ty_id = val_inst->type_id();
  25. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  26. analysis::Type* val_ty = type_mgr->GetType(val_ty_id);
  27. switch (val_ty->kind()) {
  28. case analysis::Type::kVector: {
  29. analysis::Vector* v_ty = val_ty->AsVector();
  30. const analysis::Type* c_ty = v_ty->element_type();
  31. uint32_t c_ty_id = type_mgr->GetId(c_ty);
  32. for (uint32_t c = 0; c < v_ty->element_count(); ++c) {
  33. Instruction* c_inst = builder->AddIdLiteralOp(
  34. c_ty_id, spv::Op::OpCompositeExtract, val_inst->result_id(), c);
  35. GenOutputValues(c_inst, val_ids, builder);
  36. }
  37. return;
  38. }
  39. case analysis::Type::kBool: {
  40. // Select between uint32 zero or one
  41. uint32_t zero_id = builder->GetUintConstantId(0);
  42. uint32_t one_id = builder->GetUintConstantId(1);
  43. Instruction* sel_inst =
  44. builder->AddTernaryOp(GetUintId(), spv::Op::OpSelect,
  45. val_inst->result_id(), one_id, zero_id);
  46. val_ids->push_back(sel_inst->result_id());
  47. return;
  48. }
  49. case analysis::Type::kFloat: {
  50. analysis::Float* f_ty = val_ty->AsFloat();
  51. switch (f_ty->width()) {
  52. case 16: {
  53. // Convert float16 to float32 and recurse
  54. Instruction* f32_inst = builder->AddUnaryOp(
  55. GetFloatId(), spv::Op::OpFConvert, val_inst->result_id());
  56. GenOutputValues(f32_inst, val_ids, builder);
  57. return;
  58. }
  59. case 64: {
  60. // Bitcast float64 to uint64 and recurse
  61. Instruction* ui64_inst = builder->AddUnaryOp(
  62. GetUint64Id(), spv::Op::OpBitcast, val_inst->result_id());
  63. GenOutputValues(ui64_inst, val_ids, builder);
  64. return;
  65. }
  66. case 32: {
  67. // Bitcase float32 to uint32
  68. Instruction* bc_inst = builder->AddUnaryOp(
  69. GetUintId(), spv::Op::OpBitcast, val_inst->result_id());
  70. val_ids->push_back(bc_inst->result_id());
  71. return;
  72. }
  73. default:
  74. assert(false && "unsupported float width");
  75. return;
  76. }
  77. }
  78. case analysis::Type::kInteger: {
  79. analysis::Integer* i_ty = val_ty->AsInteger();
  80. switch (i_ty->width()) {
  81. case 64: {
  82. Instruction* ui64_inst = val_inst;
  83. if (i_ty->IsSigned()) {
  84. // Bitcast sint64 to uint64
  85. ui64_inst = builder->AddUnaryOp(GetUint64Id(), spv::Op::OpBitcast,
  86. val_inst->result_id());
  87. }
  88. // Break uint64 into 2x uint32
  89. Instruction* lo_ui64_inst = builder->AddUnaryOp(
  90. GetUintId(), spv::Op::OpUConvert, ui64_inst->result_id());
  91. Instruction* rshift_ui64_inst = builder->AddBinaryOp(
  92. GetUint64Id(), spv::Op::OpShiftRightLogical,
  93. ui64_inst->result_id(), builder->GetUintConstantId(32));
  94. Instruction* hi_ui64_inst = builder->AddUnaryOp(
  95. GetUintId(), spv::Op::OpUConvert, rshift_ui64_inst->result_id());
  96. val_ids->push_back(lo_ui64_inst->result_id());
  97. val_ids->push_back(hi_ui64_inst->result_id());
  98. return;
  99. }
  100. case 8: {
  101. Instruction* ui8_inst = val_inst;
  102. if (i_ty->IsSigned()) {
  103. // Bitcast sint8 to uint8
  104. ui8_inst = builder->AddUnaryOp(GetUint8Id(), spv::Op::OpBitcast,
  105. val_inst->result_id());
  106. }
  107. // Convert uint8 to uint32
  108. Instruction* ui32_inst = builder->AddUnaryOp(
  109. GetUintId(), spv::Op::OpUConvert, ui8_inst->result_id());
  110. val_ids->push_back(ui32_inst->result_id());
  111. return;
  112. }
  113. case 32: {
  114. Instruction* ui32_inst = val_inst;
  115. if (i_ty->IsSigned()) {
  116. // Bitcast sint32 to uint32
  117. ui32_inst = builder->AddUnaryOp(GetUintId(), spv::Op::OpBitcast,
  118. val_inst->result_id());
  119. }
  120. // uint32 needs no further processing
  121. val_ids->push_back(ui32_inst->result_id());
  122. return;
  123. }
  124. default:
  125. // TODO(greg-lunarg): Support non-32-bit int
  126. assert(false && "unsupported int width");
  127. return;
  128. }
  129. }
  130. default:
  131. assert(false && "unsupported type");
  132. return;
  133. }
  134. }
  135. void InstDebugPrintfPass::GenOutputCode(
  136. Instruction* printf_inst, uint32_t stage_idx,
  137. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  138. BasicBlock* back_blk_ptr = &*new_blocks->back();
  139. InstructionBuilder builder(
  140. context(), back_blk_ptr,
  141. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  142. // Gen debug printf record validation-specific values. The format string
  143. // will have its id written. Vectors will need to be broken down into
  144. // component values. float16 will need to be converted to float32. Pointer
  145. // and uint64 will need to be converted to two uint32 values. float32 will
  146. // need to be bitcast to uint32. int32 will need to be bitcast to uint32.
  147. std::vector<uint32_t> val_ids;
  148. bool is_first_operand = false;
  149. printf_inst->ForEachInId(
  150. [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) {
  151. // skip set operand
  152. if (!is_first_operand) {
  153. is_first_operand = true;
  154. return;
  155. }
  156. Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid);
  157. if (opnd_inst->opcode() == spv::Op::OpString) {
  158. uint32_t string_id_id = builder.GetUintConstantId(*iid);
  159. val_ids.push_back(string_id_id);
  160. } else {
  161. GenOutputValues(opnd_inst, &val_ids, &builder);
  162. }
  163. });
  164. GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
  165. &builder);
  166. context()->KillInst(printf_inst);
  167. }
  168. void InstDebugPrintfPass::GenDebugPrintfCode(
  169. BasicBlock::iterator ref_inst_itr,
  170. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  171. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  172. // If not DebugPrintf OpExtInst, return.
  173. Instruction* printf_inst = &*ref_inst_itr;
  174. if (printf_inst->opcode() != spv::Op::OpExtInst) return;
  175. if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return;
  176. if (printf_inst->GetSingleWordInOperand(1) !=
  177. NonSemanticDebugPrintfDebugPrintf)
  178. return;
  179. // Initialize DefUse manager before dismantling module
  180. (void)get_def_use_mgr();
  181. // Move original block's preceding instructions into first new block
  182. std::unique_ptr<BasicBlock> new_blk_ptr;
  183. MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  184. new_blocks->push_back(std::move(new_blk_ptr));
  185. // Generate instructions to output printf args to printf buffer
  186. GenOutputCode(printf_inst, stage_idx, new_blocks);
  187. // Caller expects at least two blocks with last block containing remaining
  188. // code, so end block after instrumentation, create remainder block, and
  189. // branch to it
  190. uint32_t rem_blk_id = TakeNextId();
  191. std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id));
  192. BasicBlock* back_blk_ptr = &*new_blocks->back();
  193. InstructionBuilder builder(
  194. context(), back_blk_ptr,
  195. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  196. (void)builder.AddBranch(rem_blk_id);
  197. // Gen remainder block
  198. new_blk_ptr.reset(new BasicBlock(std::move(rem_label)));
  199. builder.SetInsertPoint(&*new_blk_ptr);
  200. // Move original block's remaining code into remainder block and add
  201. // to new blocks
  202. MovePostludeCode(ref_block_itr, &*new_blk_ptr);
  203. new_blocks->push_back(std::move(new_blk_ptr));
  204. }
  205. void InstDebugPrintfPass::InitializeInstDebugPrintf() {
  206. // Initialize base class
  207. InitializeInstrument();
  208. }
  209. Pass::Status InstDebugPrintfPass::ProcessImpl() {
  210. // Perform printf instrumentation on each entry point function in module
  211. InstProcessFunction pfn =
  212. [this](BasicBlock::iterator ref_inst_itr,
  213. UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
  214. std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  215. return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx,
  216. new_blocks);
  217. };
  218. (void)InstProcessEntryPointCallTree(pfn);
  219. // Remove DebugPrintf OpExtInstImport instruction
  220. Instruction* ext_inst_import_inst =
  221. get_def_use_mgr()->GetDef(ext_inst_printf_id_);
  222. context()->KillInst(ext_inst_import_inst);
  223. // If no remaining non-semantic instruction sets, remove non-semantic debug
  224. // info extension from module and feature manager
  225. bool non_sem_set_seen = false;
  226. for (auto c_itr = context()->module()->ext_inst_import_begin();
  227. c_itr != context()->module()->ext_inst_import_end(); ++c_itr) {
  228. const std::string set_name = c_itr->GetInOperand(0).AsString();
  229. if (spvtools::utils::starts_with(set_name, "NonSemantic.")) {
  230. non_sem_set_seen = true;
  231. break;
  232. }
  233. }
  234. if (!non_sem_set_seen) {
  235. for (auto c_itr = context()->module()->extension_begin();
  236. c_itr != context()->module()->extension_end(); ++c_itr) {
  237. const std::string ext_name = c_itr->GetInOperand(0).AsString();
  238. if (ext_name == "SPV_KHR_non_semantic_info") {
  239. context()->KillInst(&*c_itr);
  240. break;
  241. }
  242. }
  243. context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info);
  244. }
  245. return Status::SuccessWithChange;
  246. }
  247. Pass::Status InstDebugPrintfPass::Process() {
  248. ext_inst_printf_id_ =
  249. get_module()->GetExtInstImportId("NonSemantic.DebugPrintf");
  250. if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange;
  251. InitializeInstDebugPrintf();
  252. return ProcessImpl();
  253. }
  254. } // namespace opt
  255. } // namespace spvtools