convert_to_half_pass.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. // Copyright (c) 2019 The Khronos Group Inc.
  2. // Copyright (c) 2019 Valve Corporation
  3. // Copyright (c) 2019 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "convert_to_half_pass.h"
  17. #include "source/opt/ir_builder.h"
  namespace {
  // In-operand index of the Dref (depth reference) value in the Dref image
  // instructions that ProcessImageRef converts back to float32.
  constexpr int kImageSampleDrefIdInIdx = 2;
  }  // anonymous namespace
  22. namespace spvtools {
  23. namespace opt {
  24. bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
  25. return target_ops_core_.count(inst->opcode()) != 0 ||
  26. (inst->opcode() == SpvOpExtInst &&
  27. inst->GetSingleWordInOperand(0) ==
  28. context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
  29. target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
  30. }
  31. bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
  32. uint32_t ty_id = inst->type_id();
  33. if (ty_id == 0) return false;
  34. return Pass::IsFloat(ty_id, width);
  35. }
  36. bool ConvertToHalfPass::IsRelaxed(Instruction* inst) {
  37. uint32_t r_id = inst->result_id();
  38. for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
  39. if (r_inst->opcode() == SpvOpDecorate &&
  40. r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
  41. return true;
  42. return false;
  43. }
  44. analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
  45. analysis::Float float_ty(width);
  46. return context()->get_type_mgr()->GetRegisteredType(&float_ty);
  47. }
  48. analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
  49. uint32_t width) {
  50. analysis::Type* reg_float_ty = FloatScalarType(width);
  51. analysis::Vector vec_ty(reg_float_ty, v_len);
  52. return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
  53. }
  54. analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
  55. uint32_t vty_id,
  56. uint32_t width) {
  57. Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  58. uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
  59. analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
  60. analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
  61. return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
  62. }
  63. uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
  64. analysis::Type* reg_equiv_ty;
  65. Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
  66. if (ty_inst->opcode() == SpvOpTypeMatrix)
  67. reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
  68. ty_inst->GetSingleWordInOperand(0), width);
  69. else if (ty_inst->opcode() == SpvOpTypeVector)
  70. reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
  71. else // SpvOpTypeFloat
  72. reg_equiv_ty = FloatScalarType(width);
  73. return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
  74. }
  75. void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
  76. InstructionBuilder* builder) {
  77. Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
  78. uint32_t ty_id = val_inst->type_id();
  79. uint32_t nty_id = EquivFloatTypeId(ty_id, width);
  80. if (nty_id == ty_id) return;
  81. Instruction* cvt_inst;
  82. if (val_inst->opcode() == SpvOpUndef)
  83. cvt_inst = builder->AddNullaryOp(nty_id, SpvOpUndef);
  84. else
  85. cvt_inst = builder->AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
  86. *val_idp = cvt_inst->result_id();
  87. }
  88. bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
  89. if (inst->opcode() != SpvOpFConvert) return false;
  90. uint32_t mty_id = inst->type_id();
  91. Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
  92. if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
  93. uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
  94. uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
  95. Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
  96. uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
  97. Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
  98. InstructionBuilder builder(
  99. context(), inst,
  100. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  101. // Convert each component vector, combine them with OpCompositeConstruct
  102. // and replace original instruction.
  103. uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
  104. uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
  105. uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
  106. std::vector<Operand> opnds = {};
  107. for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
  108. Instruction* ext_inst = builder.AddIdLiteralOp(
  109. orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
  110. Instruction* cvt_inst =
  111. builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
  112. opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
  113. }
  114. uint32_t mat_id = TakeNextId();
  115. std::unique_ptr<Instruction> mat_inst(new Instruction(
  116. context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
  117. (void)builder.AddInstruction(std::move(mat_inst));
  118. context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
  119. // Turn original instruction into copy so it is valid.
  120. inst->SetOpcode(SpvOpCopyObject);
  121. inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
  122. get_def_use_mgr()->AnalyzeInstUse(inst);
  123. return true;
  124. }
  125. void ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
  126. context()->get_decoration_mgr()->RemoveDecorationsFrom(
  127. id, [](const Instruction& dec) {
  128. if (dec.opcode() == SpvOpDecorate &&
  129. dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
  130. return true;
  131. else
  132. return false;
  133. });
  134. }
  135. bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
  136. bool modified = false;
  137. // Convert all float32 based operands to float16 equivalent and change
  138. // instruction type to float16 equivalent.
  139. InstructionBuilder builder(
  140. context(), inst,
  141. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  142. inst->ForEachInId([&builder, &modified, this](uint32_t* idp) {
  143. Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
  144. if (!IsFloat(op_inst, 32)) return;
  145. GenConvert(idp, 16, &builder);
  146. modified = true;
  147. });
  148. if (IsFloat(inst, 32)) {
  149. inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
  150. modified = true;
  151. }
  152. if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  153. return modified;
  154. }
  155. bool ConvertToHalfPass::ProcessPhi(Instruction* inst) {
  156. // Skip if not float32
  157. if (!IsFloat(inst, 32)) return false;
  158. // Skip if no relaxed operands.
  159. bool relaxed_found = false;
  160. uint32_t ocnt = 0;
  161. inst->ForEachInId([&ocnt, &relaxed_found, this](uint32_t* idp) {
  162. if (ocnt % 2 == 0) {
  163. Instruction* val_inst = get_def_use_mgr()->GetDef(*idp);
  164. if (IsRelaxed(val_inst)) relaxed_found = true;
  165. }
  166. ++ocnt;
  167. });
  168. if (!relaxed_found) return false;
  169. // Add float16 converts of any float32 operands and change type
  170. // of phi to float16 equivalent. Operand converts need to be added to
  171. // preceeding blocks.
  172. ocnt = 0;
  173. uint32_t* prev_idp;
  174. inst->ForEachInId([&ocnt, &prev_idp, this](uint32_t* idp) {
  175. if (ocnt % 2 == 0) {
  176. prev_idp = idp;
  177. } else {
  178. Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
  179. if (IsFloat(val_inst, 32)) {
  180. BasicBlock* bp = context()->get_instr_block(*idp);
  181. auto insert_before = bp->tail();
  182. if (insert_before != bp->begin()) {
  183. --insert_before;
  184. if (insert_before->opcode() != SpvOpSelectionMerge &&
  185. insert_before->opcode() != SpvOpLoopMerge)
  186. ++insert_before;
  187. }
  188. InstructionBuilder builder(context(), &*insert_before,
  189. IRContext::kAnalysisDefUse |
  190. IRContext::kAnalysisInstrToBlockMapping);
  191. GenConvert(prev_idp, 16, &builder);
  192. }
  193. }
  194. ++ocnt;
  195. });
  196. inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
  197. get_def_use_mgr()->AnalyzeInstUse(inst);
  198. return true;
  199. }
  200. bool ConvertToHalfPass::ProcessExtract(Instruction* inst) {
  201. bool modified = false;
  202. uint32_t comp_id = inst->GetSingleWordInOperand(0);
  203. Instruction* comp_inst = get_def_use_mgr()->GetDef(comp_id);
  204. // If extract is relaxed float32 based type and the composite is a relaxed
  205. // float32 based type, convert it to float16 equivalent. This is slightly
  206. // aggressive and pushes any likely conversion to apply to the whole
  207. // composite rather than apply to each extracted component later. This
  208. // can be a win if the platform can convert the entire composite in the same
  209. // time as one component. It risks converting components that may not be
  210. // used, although empirical data on a large set of real-world shaders seems
  211. // to suggest this is not common and the composite convert is the best choice.
  212. if (IsFloat(inst, 32) && IsRelaxed(inst) && IsFloat(comp_inst, 32) &&
  213. IsRelaxed(comp_inst)) {
  214. InstructionBuilder builder(
  215. context(), inst,
  216. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  217. GenConvert(&comp_id, 16, &builder);
  218. inst->SetInOperand(0, {comp_id});
  219. comp_inst = get_def_use_mgr()->GetDef(comp_id);
  220. modified = true;
  221. }
  222. // If the composite is a float16 based type, make sure the type of the
  223. // extract agrees.
  224. if (IsFloat(comp_inst, 16) && !IsFloat(inst, 16)) {
  225. inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
  226. modified = true;
  227. }
  228. if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  229. return modified;
  230. }
  231. bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
  232. // If float32 and relaxed, change to float16 convert
  233. if (IsFloat(inst, 32) && IsRelaxed(inst)) {
  234. inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
  235. get_def_use_mgr()->AnalyzeInstUse(inst);
  236. }
  237. // If operand and result types are the same, replace result with operand
  238. // and change convert to copy to keep validator happy; DCE will clean it up
  239. uint32_t val_id = inst->GetSingleWordInOperand(0);
  240. Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
  241. if (inst->type_id() == val_inst->type_id()) {
  242. context()->ReplaceAllUsesWith(inst->result_id(), val_id);
  243. inst->SetOpcode(SpvOpCopyObject);
  244. }
  245. return true; // modified
  246. }
  247. bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
  248. bool modified = false;
  249. // If image reference, only need to convert dref args back to float32
  250. if (dref_image_ops_.count(inst->opcode()) != 0) {
  251. uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
  252. Instruction* dref_inst = get_def_use_mgr()->GetDef(dref_id);
  253. if (IsFloat(dref_inst, 16) && IsRelaxed(dref_inst)) {
  254. InstructionBuilder builder(
  255. context(), inst,
  256. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  257. GenConvert(&dref_id, 32, &builder);
  258. inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
  259. get_def_use_mgr()->AnalyzeInstUse(inst);
  260. modified = true;
  261. }
  262. }
  263. return modified;
  264. }
  265. bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
  266. bool modified = false;
  267. // If non-relaxed instruction has changed operands, need to convert
  268. // them back to float32
  269. InstructionBuilder builder(
  270. context(), inst,
  271. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  272. inst->ForEachInId([&builder, &modified, this](uint32_t* idp) {
  273. Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
  274. if (!IsFloat(op_inst, 16)) return;
  275. if (!IsRelaxed(op_inst)) return;
  276. uint32_t old_id = *idp;
  277. GenConvert(idp, 32, &builder);
  278. if (*idp != old_id) modified = true;
  279. });
  280. if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
  281. return modified;
  282. }
  283. bool ConvertToHalfPass::GenHalfCode(Instruction* inst) {
  284. bool modified = false;
  285. // Remember id for later deletion of RelaxedPrecision decoration
  286. bool inst_relaxed = IsRelaxed(inst);
  287. if (inst_relaxed) relaxed_ids_.push_back(inst->result_id());
  288. if (IsArithmetic(inst) && inst_relaxed)
  289. modified = GenHalfArith(inst);
  290. else if (inst->opcode() == SpvOpPhi)
  291. modified = ProcessPhi(inst);
  292. else if (inst->opcode() == SpvOpCompositeExtract)
  293. modified = ProcessExtract(inst);
  294. else if (inst->opcode() == SpvOpFConvert)
  295. modified = ProcessConvert(inst);
  296. else if (image_ops_.count(inst->opcode()) != 0)
  297. modified = ProcessImageRef(inst);
  298. else
  299. modified = ProcessDefault(inst);
  300. return modified;
  301. }
  302. bool ConvertToHalfPass::ProcessFunction(Function* func) {
  303. bool modified = false;
  304. cfg()->ForEachBlockInReversePostOrder(
  305. func->entry().get(), [&modified, this](BasicBlock* bb) {
  306. for (auto ii = bb->begin(); ii != bb->end(); ++ii)
  307. modified |= GenHalfCode(&*ii);
  308. });
  309. cfg()->ForEachBlockInReversePostOrder(
  310. func->entry().get(), [&modified, this](BasicBlock* bb) {
  311. for (auto ii = bb->begin(); ii != bb->end(); ++ii)
  312. modified |= MatConvertCleanup(&*ii);
  313. });
  314. return modified;
  315. }
  316. Pass::Status ConvertToHalfPass::ProcessImpl() {
  317. Pass::ProcessFunction pfn = [this](Function* fp) {
  318. return ProcessFunction(fp);
  319. };
  320. bool modified = context()->ProcessEntryPointCallTree(pfn);
  321. // If modified, make sure module has Float16 capability
  322. if (modified) context()->AddCapability(SpvCapabilityFloat16);
  323. // Remove all RelaxedPrecision decorations from instructions and globals
  324. for (auto c_id : relaxed_ids_) RemoveRelaxedDecoration(c_id);
  325. for (auto& val : get_module()->types_values()) {
  326. uint32_t v_id = val.result_id();
  327. if (v_id != 0) RemoveRelaxedDecoration(v_id);
  328. }
  329. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  330. }
  331. Pass::Status ConvertToHalfPass::Process() {
  332. Initialize();
  333. return ProcessImpl();
  334. }
  335. void ConvertToHalfPass::Initialize() {
  336. target_ops_core_ = {
  337. SpvOpVectorExtractDynamic,
  338. SpvOpVectorInsertDynamic,
  339. SpvOpVectorShuffle,
  340. SpvOpCompositeConstruct,
  341. SpvOpCompositeInsert,
  342. SpvOpCopyObject,
  343. SpvOpTranspose,
  344. SpvOpConvertSToF,
  345. SpvOpConvertUToF,
  346. // SpvOpFConvert,
  347. // SpvOpQuantizeToF16,
  348. SpvOpFNegate,
  349. SpvOpFAdd,
  350. SpvOpFSub,
  351. SpvOpFMul,
  352. SpvOpFDiv,
  353. SpvOpFMod,
  354. SpvOpVectorTimesScalar,
  355. SpvOpMatrixTimesScalar,
  356. SpvOpVectorTimesMatrix,
  357. SpvOpMatrixTimesVector,
  358. SpvOpMatrixTimesMatrix,
  359. SpvOpOuterProduct,
  360. SpvOpDot,
  361. SpvOpSelect,
  362. SpvOpFOrdEqual,
  363. SpvOpFUnordEqual,
  364. SpvOpFOrdNotEqual,
  365. SpvOpFUnordNotEqual,
  366. SpvOpFOrdLessThan,
  367. SpvOpFUnordLessThan,
  368. SpvOpFOrdGreaterThan,
  369. SpvOpFUnordGreaterThan,
  370. SpvOpFOrdLessThanEqual,
  371. SpvOpFUnordLessThanEqual,
  372. SpvOpFOrdGreaterThanEqual,
  373. SpvOpFUnordGreaterThanEqual,
  374. };
  375. target_ops_450_ = {
  376. GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
  377. GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
  378. GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
  379. GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
  380. GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
  381. GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
  382. GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
  383. GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
  384. GLSLstd450MatrixInverse,
  385. // TODO(greg-lunarg): GLSLstd450ModfStruct,
  386. GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
  387. GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
  388. // TODO(greg-lunarg): GLSLstd450FrexpStruct,
  389. GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
  390. GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
  391. GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
  392. image_ops_ = {SpvOpImageSampleImplicitLod,
  393. SpvOpImageSampleExplicitLod,
  394. SpvOpImageSampleDrefImplicitLod,
  395. SpvOpImageSampleDrefExplicitLod,
  396. SpvOpImageSampleProjImplicitLod,
  397. SpvOpImageSampleProjExplicitLod,
  398. SpvOpImageSampleProjDrefImplicitLod,
  399. SpvOpImageSampleProjDrefExplicitLod,
  400. SpvOpImageFetch,
  401. SpvOpImageGather,
  402. SpvOpImageDrefGather,
  403. SpvOpImageRead,
  404. SpvOpImageSparseSampleImplicitLod,
  405. SpvOpImageSparseSampleExplicitLod,
  406. SpvOpImageSparseSampleDrefImplicitLod,
  407. SpvOpImageSparseSampleDrefExplicitLod,
  408. SpvOpImageSparseSampleProjImplicitLod,
  409. SpvOpImageSparseSampleProjExplicitLod,
  410. SpvOpImageSparseSampleProjDrefImplicitLod,
  411. SpvOpImageSparseSampleProjDrefExplicitLod,
  412. SpvOpImageSparseFetch,
  413. SpvOpImageSparseGather,
  414. SpvOpImageSparseDrefGather,
  415. SpvOpImageSparseTexelsResident,
  416. SpvOpImageSparseRead};
  417. dref_image_ops_ = {
  418. SpvOpImageSampleDrefImplicitLod,
  419. SpvOpImageSampleDrefExplicitLod,
  420. SpvOpImageSampleProjDrefImplicitLod,
  421. SpvOpImageSampleProjDrefExplicitLod,
  422. SpvOpImageDrefGather,
  423. SpvOpImageSparseSampleDrefImplicitLod,
  424. SpvOpImageSparseSampleDrefExplicitLod,
  425. SpvOpImageSparseSampleProjDrefImplicitLod,
  426. SpvOpImageSparseSampleProjDrefExplicitLod,
  427. SpvOpImageSparseDrefGather,
  428. };
  429. relaxed_ids_.clear();
  430. }
  431. } // namespace opt
  432. } // namespace spvtools