// convert_to_half_pass.cpp
  1. // Copyright (c) 2019 The Khronos Group Inc.
  2. // Copyright (c) 2019 Valve Corporation
  3. // Copyright (c) 2019 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "convert_to_half_pass.h"
  17. #include "source/opt/ir_builder.h"
  18. namespace {
  19. // Indices of operands in SPIR-V instructions
  20. static const int kImageSampleDrefIdInIdx = 2;
  21. } // anonymous namespace
  22. namespace spvtools {
  23. namespace opt {
  24. bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
  25. return target_ops_core_.count(inst->opcode()) != 0 ||
  26. (inst->opcode() == SpvOpExtInst &&
  27. inst->GetSingleWordInOperand(0) ==
  28. context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  29. target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
  30. }
  31. bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
  32. uint32_t ty_id = inst->type_id();
  33. if (ty_id == 0) return false;
  34. return Pass::IsFloat(ty_id, width);
  35. }
  36. bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
  37. uint32_t r_id = inst->result_id();
  38. for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
  39. if (r_inst->opcode() == SpvOpDecorate &&
  40. r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
  41. return true;
  42. return false;
  43. }
  44. bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
  45. return relaxed_ids_set_.count(id) > 0;
  46. }
  47. void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
  48. analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
  49. analysis::Float float_ty(width);
  50. return context()->get_type_mgr()->GetRegisteredType(&float_ty);
  51. }
  52. analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
  53. uint32_t width) {
  54. analysis::Type* reg_float_ty = FloatScalarType(width);
  55. analysis::Vector vec_ty(reg_float_ty, v_len);
  56. return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
  57. }
  58. analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
  59. uint32_t vty_id,
  60. uint32_t width) {
  61. Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  62. uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
  63. analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
  64. analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
  65. return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
  66. }
  67. uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
  68. analysis::Type* reg_equiv_ty;
  69. Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
  70. if (ty_inst->opcode() == SpvOpTypeMatrix)
  71. reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
  72. ty_inst->GetSingleWordInOperand(0), width);
  73. else if (ty_inst->opcode() == SpvOpTypeVector)
  74. reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
  75. else // SpvOpTypeFloat
  76. reg_equiv_ty = FloatScalarType(width);
  77. return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
  78. }
// Rewrites *|val_idp| so it refers to a value of the float type with
// |width|-bit components equivalent to the value's current type.  If the
// value already has the requested width, nothing is done.  Any new
// instruction is inserted immediately before |inst|.
void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
                                   Instruction* inst) {
  Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
  uint32_t ty_id = val_inst->type_id();
  uint32_t nty_id = EquivFloatTypeId(ty_id, width);
  if (nty_id == ty_id) return;  // already the requested width
  Instruction* cvt_inst;
  InstructionBuilder builder(
      context(), inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Converting an undef is pointless; just produce an undef of the new type.
  if (val_inst->opcode() == SpvOpUndef)
    cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
  else
    cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
  *val_idp = cvt_inst->result_id();
}
// OpFConvert does not accept matrix operands.  If |inst| is an FConvert
// producing a matrix (generated earlier by this pass when a whole matrix
// value changed width), expand it into per-column extracts and converts
// followed by an OpCompositeConstruct, then neuter the original
// instruction into an OpCopyObject so the module stays valid.  Returns
// true if |inst| was rewritten.
bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
  if (inst->opcode() != SpvOpFConvert) return false;
  uint32_t mty_id = inst->type_id();
  Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
  if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
  // Destination matrix layout: column vector type and column count.
  uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
  uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
  Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
  Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
  InstructionBuilder builder(
      context(), inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Convert each component vector, combine them with OpCompositeConstruct
  // and replace original instruction.
  // The source matrix has the opposite component width of the destination.
  uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
  uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
  uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
  std::vector<Operand> opnds = {};
  for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
    // Extract column |vidx| at the source width, convert it to the
    // destination width, and collect it for the composite construct.
    Instruction* ext_inst = builder.AddIdLiteralOp(
        orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
    Instruction* cvt_inst =
        builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
    opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
  }
  uint32_t mat_id = TakeNextId();
  std::unique_ptr<Instruction> mat_inst(new Instruction(
      context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
  (void)builder.AddInstruction(std::move(mat_inst));
  context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
  // Turn original instruction into copy so it is valid.
  inst->SetOpcode(SpvOpCopyObject);
  inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
  get_def_use_mgr()->AnalyzeInstUse(inst);
  return true;
}
  132. bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
  133. return context()->get_decoration_mgr()->RemoveDecorationsFrom(
  134. id, [](const Instruction& dec) {
  135. if (dec.opcode() == SpvOpDecorate &&
  136. dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
  137. return true;
  138. else
  139. return false;
  140. });
  141. }
// Narrows a RelaxedPrecision arithmetic instruction to half precision.
// Returns true if |inst| was modified.
bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
  bool modified = false;
  // Convert all float32 based operands to float16 equivalent and change
  // instruction type to float16 equivalent.
  inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
    Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
    if (!IsFloat(op_inst, 32)) return;  // non-float32 operands untouched
    GenConvert(idp, 16, inst);
    modified = true;
  });
  if (IsFloat(inst, 32)) {
    inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
    // Remember narrowed results so later non-relaxed uses can convert back.
    converted_ids_.insert(inst->result_id());
    modified = true;
  }
  if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  return modified;
}
  160. bool ConvertToHalfPass::ProcessPhi(Instruction* inst) {
  161. // Add float16 converts of any float32 operands and change type
  162. // of phi to float16 equivalent. Operand converts need to be added to
  163. // preceeding blocks.
  164. uint32_t ocnt = 0;
  165. uint32_t* prev_idp;
  166. inst->ForEachInId([&ocnt, &prev_idp, this](uint32_t* idp) {
  167. if (ocnt % 2 == 0) {
  168. prev_idp = idp;
  169. } else {
  170. Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
  171. if (IsFloat(val_inst, 32)) {
  172. BasicBlock* bp = context()->get_instr_block(*idp);
  173. auto insert_before = bp->tail();
  174. if (insert_before != bp->begin()) {
  175. --insert_before;
  176. if (insert_before->opcode() != SpvOpSelectionMerge &&
  177. insert_before->opcode() != SpvOpLoopMerge)
  178. ++insert_before;
  179. }
  180. GenConvert(prev_idp, 16, &*insert_before);
  181. }
  182. }
  183. ++ocnt;
  184. });
  185. inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
  186. get_def_use_mgr()->AnalyzeInstUse(inst);
  187. converted_ids_.insert(inst->result_id());
  188. return true;
  189. }
// Processes an existing OpFConvert: relaxed float32 converts are retyped
// to float16, and converts whose source and destination types have become
// identical are demoted to OpCopyObject.  Always returns true.
bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
  // If float32 and relaxed, change to float16 convert
  if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
    inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
    get_def_use_mgr()->AnalyzeInstUse(inst);
    converted_ids_.insert(inst->result_id());
  }
  // If operand and result types are the same, change FConvert to CopyObject to
  // keep validator happy; simplification and DCE will clean it up
  // One way this can happen is if an FConvert generated during this pass
  // (likely by ProcessPhi) is later encountered here and its operand has been
  // changed to half.
  uint32_t val_id = inst->GetSingleWordInOperand(0);
  Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
  if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
  return true;  // modified
}
  207. bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
  208. bool modified = false;
  209. // If image reference, only need to convert dref args back to float32
  210. if (dref_image_ops_.count(inst->opcode()) != 0) {
  211. uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
  212. if (converted_ids_.count(dref_id) > 0) {
  213. GenConvert(&dref_id, 32, inst);
  214. inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
  215. get_def_use_mgr()->AnalyzeInstUse(inst);
  216. modified = true;
  217. }
  218. }
  219. return modified;
  220. }
  221. bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
  222. bool modified = false;
  223. // If non-relaxed instruction has changed operands, need to convert
  224. // them back to float32
  225. inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
  226. if (converted_ids_.count(*idp) == 0) return;
  227. uint32_t old_id = *idp;
  228. GenConvert(idp, 32, inst);
  229. if (*idp != old_id) modified = true;
  230. });
  231. if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  232. return modified;
  233. }
  234. bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
  235. bool modified = false;
  236. // Remember id for later deletion of RelaxedPrecision decoration
  237. bool inst_relaxed = IsRelaxed(inst->result_id());
  238. if (IsArithmetic(inst) && inst_relaxed)
  239. modified = GenHalfArith(inst);
  240. else if (inst->opcode() == SpvOpPhi && inst_relaxed)
  241. modified = ProcessPhi(inst);
  242. else if (inst->opcode() == SpvOpFConvert)
  243. modified = ProcessConvert(inst);
  244. else if (image_ops_.count(inst->opcode()) != 0)
  245. modified = ProcessImageRef(inst);
  246. else
  247. modified = ProcessDefault(inst);
  248. return modified;
  249. }
// One step of the relaxed-closure computation.  A float32 instruction is
// added to the relaxed set if (a) it is explicitly decorated
// RelaxedPrecision, (b) it is a closure opcode and all its float operands
// are relaxed, or (c) it is a closure opcode and all its uses are relaxed.
// Returns true only if |inst| was newly added to the set.
bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
  if (inst->result_id() == 0) return false;
  if (IsRelaxed(inst->result_id())) return false;  // already in the set
  if (!IsFloat(inst, 32)) return false;
  if (IsDecoratedRelaxed(inst)) {
    AddRelaxed(inst->result_id());
    return true;
  }
  if (closure_ops_.count(inst->opcode()) == 0) return false;
  // Can relax if all float operands are relaxed
  bool relax = true;
  inst->ForEachInId([&relax, this](uint32_t* idp) {
    Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
    if (!IsFloat(op_inst, 32)) return;  // non-float32 operands don't count
    if (!IsRelaxed(*idp)) relax = false;
  });
  if (relax) {
    AddRelaxed(inst->result_id());
    return true;
  }
  // Can relax if all uses are relaxed
  relax = true;
  get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
    // A use blocks relaxation if it has no result, is not float32, or is
    // itself neither decorated nor already-known relaxed.
    if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
        (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
      relax = false;
      return;
    }
  });
  if (relax) {
    AddRelaxed(inst->result_id());
    return true;
  }
  return false;
}
  285. bool ConvertToHalfPass::ProcessFunction(Function* func) {
  286. // Do a closure of Relaxed on composite and phi instructions
  287. bool changed = true;
  288. while (changed) {
  289. changed = false;
  290. cfg()->ForEachBlockInReversePostOrder(
  291. func->entry().get(), [&changed, this](BasicBlock* bb) {
  292. for (auto ii = bb->begin(); ii != bb->end(); ++ii)
  293. changed |= CloseRelaxInst(&*ii);
  294. });
  295. }
  296. // Do convert of relaxed instructions to half precision
  297. bool modified = false;
  298. cfg()->ForEachBlockInReversePostOrder(
  299. func->entry().get(), [&modified, this](BasicBlock* bb) {
  300. for (auto ii = bb->begin(); ii != bb->end(); ++ii)
  301. modified |= GenHalfInst(&*ii);
  302. });
  303. // Replace invalid converts of matrix into equivalent vector extracts,
  304. // converts and finally a composite construct
  305. cfg()->ForEachBlockInReversePostOrder(
  306. func->entry().get(), [&modified, this](BasicBlock* bb) {
  307. for (auto ii = bb->begin(); ii != bb->end(); ++ii)
  308. modified |= MatConvertCleanup(&*ii);
  309. });
  310. return modified;
  311. }
  312. Pass::Status ConvertToHalfPass::ProcessImpl() {
  313. Pass::ProcessFunction pfn = [this](Function* fp) {
  314. return ProcessFunction(fp);
  315. };
  316. bool modified = context()->ProcessReachableCallTree(pfn);
  317. // If modified, make sure module has Float16 capability
  318. if (modified) context()->AddCapability(SpvCapabilityFloat16);
  319. // Remove all RelaxedPrecision decorations from instructions and globals
  320. for (auto c_id : relaxed_ids_set_) {
  321. modified |= RemoveRelaxedDecoration(c_id);
  322. }
  323. for (auto& val : get_module()->types_values()) {
  324. uint32_t v_id = val.result_id();
  325. if (v_id != 0) {
  326. modified |= RemoveRelaxedDecoration(v_id);
  327. }
  328. }
  329. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  330. }
  331. Pass::Status ConvertToHalfPass::Process() {
  332. Initialize();
  333. return ProcessImpl();
  334. }
// Populates the opcode tables used by the pass and clears per-run state.
void ConvertToHalfPass::Initialize() {
  // Core opcodes whose relaxed instances may be narrowed to half.
  target_ops_core_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpConvertSToF,
      SpvOpConvertUToF,
      // SpvOpFConvert,
      // SpvOpQuantizeToF16,
      SpvOpFNegate,
      SpvOpFAdd,
      SpvOpFSub,
      SpvOpFMul,
      SpvOpFDiv,
      SpvOpFMod,
      SpvOpVectorTimesScalar,
      SpvOpMatrixTimesScalar,
      SpvOpVectorTimesMatrix,
      SpvOpMatrixTimesVector,
      SpvOpMatrixTimesMatrix,
      SpvOpOuterProduct,
      SpvOpDot,
      SpvOpSelect,
      SpvOpFOrdEqual,
      SpvOpFUnordEqual,
      SpvOpFOrdNotEqual,
      SpvOpFUnordNotEqual,
      SpvOpFOrdLessThan,
      SpvOpFUnordLessThan,
      SpvOpFOrdGreaterThan,
      SpvOpFUnordGreaterThan,
      SpvOpFOrdLessThanEqual,
      SpvOpFUnordLessThanEqual,
      SpvOpFOrdGreaterThanEqual,
      SpvOpFUnordGreaterThanEqual,
  };
  // GLSL.std.450 extended instructions that may likewise be narrowed.
  target_ops_450_ = {
      GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
      GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
      GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
      GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
      GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
      GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
      GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
      GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
      GLSLstd450MatrixInverse,
      // TODO(greg-lunarg): GLSLstd450ModfStruct,
      GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
      GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
      // TODO(greg-lunarg): GLSLstd450FrexpStruct,
      GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
      GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
      GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
  // Image instructions; their operands may need widening back to float32.
  image_ops_ = {SpvOpImageSampleImplicitLod,
                SpvOpImageSampleExplicitLod,
                SpvOpImageSampleDrefImplicitLod,
                SpvOpImageSampleDrefExplicitLod,
                SpvOpImageSampleProjImplicitLod,
                SpvOpImageSampleProjExplicitLod,
                SpvOpImageSampleProjDrefImplicitLod,
                SpvOpImageSampleProjDrefExplicitLod,
                SpvOpImageFetch,
                SpvOpImageGather,
                SpvOpImageDrefGather,
                SpvOpImageRead,
                SpvOpImageSparseSampleImplicitLod,
                SpvOpImageSparseSampleExplicitLod,
                SpvOpImageSparseSampleDrefImplicitLod,
                SpvOpImageSparseSampleDrefExplicitLod,
                SpvOpImageSparseSampleProjImplicitLod,
                SpvOpImageSparseSampleProjExplicitLod,
                SpvOpImageSparseSampleProjDrefImplicitLod,
                SpvOpImageSparseSampleProjDrefExplicitLod,
                SpvOpImageSparseFetch,
                SpvOpImageSparseGather,
                SpvOpImageSparseDrefGather,
                SpvOpImageSparseTexelsResident,
                SpvOpImageSparseRead};
  // Subset of image_ops_ that carry a Dref argument (must stay float32).
  dref_image_ops_ = {
      SpvOpImageSampleDrefImplicitLod,
      SpvOpImageSampleDrefExplicitLod,
      SpvOpImageSampleProjDrefImplicitLod,
      SpvOpImageSampleProjDrefExplicitLod,
      SpvOpImageDrefGather,
      SpvOpImageSparseSampleDrefImplicitLod,
      SpvOpImageSparseSampleDrefExplicitLod,
      SpvOpImageSparseSampleProjDrefImplicitLod,
      SpvOpImageSparseSampleProjDrefExplicitLod,
      SpvOpImageSparseDrefGather,
  };
  // Opcodes over which the RelaxedPrecision closure is computed.
  closure_ops_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpPhi,
  };
  // Reset per-run state.
  relaxed_ids_set_.clear();
  converted_ids_.clear();
}
  444. } // namespace opt
  445. } // namespace spvtools