소스 검색

[spirv] Add literal type hint for OpDot.

Ehsan Nasiri 6 년 전
부모
커밋
232dbd9981
3개의 변경된 파일56개의 추가작업 그리고 38개의 파일을 삭제
  1. 46 36
      tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp
  2. 5 2
      tools/clang/lib/SPIRV/LiteralTypeVisitor.h
  3. 5 0
      tools/clang/test/CodeGenSPIRV/intrinsics.dot.hlsl

+ 46 - 36
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -64,8 +64,8 @@ bool LiteralTypeVisitor::canDeduceTypeFromLitType(QualType litType,
   return false;
   return false;
 }
 }
 
 
-void LiteralTypeVisitor::updateTypeForInstruction(SpirvInstruction *inst,
-                                                  QualType newType) {
+void LiteralTypeVisitor::tryToUpdateInstLitType(SpirvInstruction *inst,
+                                                QualType newType) {
   if (!inst)
   if (!inst)
     return;
     return;
 
 
@@ -88,14 +88,14 @@ bool LiteralTypeVisitor::visitInstruction(SpirvInstruction *instr) {
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvVariable *var) {
 bool LiteralTypeVisitor::visit(SpirvVariable *var) {
-  updateTypeForInstruction(var->getInitializer(), var->getAstResultType());
+  tryToUpdateInstLitType(var->getInitializer(), var->getAstResultType());
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvAtomic *inst) {
 bool LiteralTypeVisitor::visit(SpirvAtomic *inst) {
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
-  updateTypeForInstruction(inst->getValue(), resultType);
-  updateTypeForInstruction(inst->getComparator(), resultType);
+  tryToUpdateInstLitType(inst->getValue(), resultType);
+  tryToUpdateInstLitType(inst->getComparator(), resultType);
   return true;
   return true;
 }
 }
 
 
@@ -127,13 +127,13 @@ bool LiteralTypeVisitor::visit(SpirvUnaryOp *inst) {
           astContext, resultType, spvOptions.enable16BitTypes);
           astContext, resultType, spvOptions.enable16BitTypes);
       const QualType newType =
       const QualType newType =
           getTypeWithCustomBitwidth(astContext, argType, resultTypeBitwidth);
           getTypeWithCustomBitwidth(astContext, argType, resultTypeBitwidth);
-      updateTypeForInstruction(arg, newType);
+      tryToUpdateInstLitType(arg, newType);
       return true;
       return true;
     }
     }
   }
   }
 
 
   // In all other cases, try to use the result type as a hint.
   // In all other cases, try to use the result type as a hint.
-  updateTypeForInstruction(arg, resultType);
+  tryToUpdateInstLitType(arg, resultType);
   return true;
   return true;
 }
 }
 
 
@@ -148,9 +148,9 @@ bool LiteralTypeVisitor::visit(SpirvBinaryOp *inst) {
       op == spv::Op::OpShiftRightArithmetic ||
       op == spv::Op::OpShiftRightArithmetic ||
       op == spv::Op::OpShiftLeftLogical) {
       op == spv::Op::OpShiftLeftLogical) {
     // Base (arg1) should have the same type as result type
     // Base (arg1) should have the same type as result type
-    updateTypeForInstruction(inst->getOperand1(), resultType);
+    tryToUpdateInstLitType(inst->getOperand1(), resultType);
     // The shitf amount (arg2) cannot be a 64-bit type for a 32-bit base!
     // The shitf amount (arg2) cannot be a 64-bit type for a 32-bit base!
-    updateTypeForInstruction(inst->getOperand2(), resultType);
+    tryToUpdateInstLitType(inst->getOperand2(), resultType);
     return true;
     return true;
   }
   }
 
 
@@ -182,7 +182,7 @@ bool LiteralTypeVisitor::visit(SpirvBinaryOp *inst) {
             astContext, operand2Type, spvOptions.enable16BitTypes);
             astContext, operand2Type, spvOptions.enable16BitTypes);
         const QualType newType = getTypeWithCustomBitwidth(
         const QualType newType = getTypeWithCustomBitwidth(
             astContext, operand1Type, operand2Bitwidth);
             astContext, operand1Type, operand2Bitwidth);
-        updateTypeForInstruction(operand1, newType);
+        tryToUpdateInstLitType(operand1, newType);
         return true;
         return true;
       }
       }
       if (isLitOp2 && !isLitOp1) {
       if (isLitOp2 && !isLitOp1) {
@@ -190,34 +190,44 @@ bool LiteralTypeVisitor::visit(SpirvBinaryOp *inst) {
             astContext, operand1Type, spvOptions.enable16BitTypes);
             astContext, operand1Type, spvOptions.enable16BitTypes);
         const QualType newType = getTypeWithCustomBitwidth(
         const QualType newType = getTypeWithCustomBitwidth(
             astContext, operand2Type, operand1Bitwidth);
             astContext, operand2Type, operand1Bitwidth);
-        updateTypeForInstruction(operand2, newType);
+        tryToUpdateInstLitType(operand2, newType);
         return true;
         return true;
       }
       }
     }
     }
   }
   }
 
 
-  updateTypeForInstruction(operand1, resultType);
-  updateTypeForInstruction(operand2, resultType);
+  // The result type of dot product is scalar but operands should be vector of
+  // the same type.
+  if (op == spv::Op::OpDot) {
+    tryToUpdateInstLitType(inst->getOperand1(),
+                           inst->getOperand2()->getAstResultType());
+    tryToUpdateInstLitType(inst->getOperand2(),
+                           inst->getOperand1()->getAstResultType());
+    return true;
+  }
+
+  tryToUpdateInstLitType(operand1, resultType);
+  tryToUpdateInstLitType(operand2, resultType);
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvBitFieldInsert *inst) {
 bool LiteralTypeVisitor::visit(SpirvBitFieldInsert *inst) {
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
-  updateTypeForInstruction(inst->getBase(), resultType);
-  updateTypeForInstruction(inst->getInsert(), resultType);
+  tryToUpdateInstLitType(inst->getBase(), resultType);
+  tryToUpdateInstLitType(inst->getInsert(), resultType);
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvBitFieldExtract *inst) {
 bool LiteralTypeVisitor::visit(SpirvBitFieldExtract *inst) {
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
-  updateTypeForInstruction(inst->getBase(), resultType);
+  tryToUpdateInstLitType(inst->getBase(), resultType);
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvSelect *inst) {
 bool LiteralTypeVisitor::visit(SpirvSelect *inst) {
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
-  updateTypeForInstruction(inst->getTrueObject(), resultType);
-  updateTypeForInstruction(inst->getFalseObject(), resultType);
+  tryToUpdateInstLitType(inst->getTrueObject(), resultType);
+  tryToUpdateInstLitType(inst->getFalseObject(), resultType);
   return true;
   return true;
 }
 }
 
 
@@ -237,11 +247,11 @@ bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
     (void)isVectorType(vec1->getAstResultType(), &vec1ElemType, &vec1ElemCount);
     (void)isVectorType(vec1->getAstResultType(), &vec1ElemType, &vec1ElemCount);
     (void)isVectorType(vec2->getAstResultType(), &vec2ElemType, &vec2ElemCount);
     (void)isVectorType(vec2->getAstResultType(), &vec2ElemType, &vec2ElemCount);
     if (isLitTypeOrVecOfLitType(vec1ElemType)) {
     if (isLitTypeOrVecOfLitType(vec1ElemType)) {
-      updateTypeForInstruction(
+      tryToUpdateInstLitType(
           vec1, astContext.getExtVectorType(resultElemType, vec1ElemCount));
           vec1, astContext.getExtVectorType(resultElemType, vec1ElemCount));
     }
     }
     if (isLitTypeOrVecOfLitType(vec2ElemType)) {
     if (isLitTypeOrVecOfLitType(vec2ElemType)) {
-      updateTypeForInstruction(
+      tryToUpdateInstLitType(
           vec2, astContext.getExtVectorType(resultElemType, vec2ElemCount));
           vec2, astContext.getExtVectorType(resultElemType, vec2ElemCount));
     }
     }
   }
   }
@@ -251,14 +261,14 @@ bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
 bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
 bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
   // Went through each non-uniform binary operation and made sure the following
   // Went through each non-uniform binary operation and made sure the following
   // does not result in a wrong type deduction.
   // does not result in a wrong type deduction.
-  updateTypeForInstruction(inst->getArg(), inst->getAstResultType());
+  tryToUpdateInstLitType(inst->getArg(), inst->getAstResultType());
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
 bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
   // Went through each non-uniform unary operation and made sure the following
   // Went through each non-uniform unary operation and made sure the following
   // does not result in a wrong type deduction.
   // does not result in a wrong type deduction.
-  updateTypeForInstruction(inst->getArg1(), inst->getAstResultType());
+  tryToUpdateInstLitType(inst->getArg1(), inst->getAstResultType());
   return true;
   return true;
 }
 }
 
 
@@ -269,11 +279,11 @@ bool LiteralTypeVisitor::visit(SpirvStore *inst) {
     QualType type = pointer->getAstResultType();
     QualType type = pointer->getAstResultType();
     if (const auto *ptrType = type->getAs<PointerType>())
     if (const auto *ptrType = type->getAs<PointerType>())
       type = ptrType->getPointeeType();
       type = ptrType->getPointeeType();
-    updateTypeForInstruction(object, type);
+    tryToUpdateInstLitType(object, type);
   } else if (pointer->hasResultType()) {
   } else if (pointer->hasResultType()) {
     if (auto *ptrType = dyn_cast<HybridPointerType>(pointer->getResultType())) {
     if (auto *ptrType = dyn_cast<HybridPointerType>(pointer->getResultType())) {
       QualType type = ptrType->getPointeeType();
       QualType type = ptrType->getPointeeType();
-      updateTypeForInstruction(object, type);
+      tryToUpdateInstLitType(object, type);
     }
     }
   }
   }
   return true;
   return true;
@@ -303,7 +313,7 @@ bool LiteralTypeVisitor::visit(SpirvCompositeExtract *inst) {
         astContext, resultType, spvOptions.enable16BitTypes);
         astContext, resultType, spvOptions.enable16BitTypes);
     const QualType newType =
     const QualType newType =
         getTypeWithCustomBitwidth(astContext, baseType, resultTypeBitwidth);
         getTypeWithCustomBitwidth(astContext, baseType, resultTypeBitwidth);
-    updateTypeForInstruction(base, newType);
+    tryToUpdateInstLitType(base, newType);
   }
   }
 
 
   return true;
   return true;
@@ -326,7 +336,7 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
     QualType elemType = {};
     QualType elemType = {};
     if (isVectorType(compositeType, &elemType)) {
     if (isVectorType(compositeType, &elemType)) {
       for (auto *constituent : constituents)
       for (auto *constituent : constituents)
-        updateTypeForInstruction(constituent, elemType);
+        tryToUpdateInstLitType(constituent, elemType);
       return true;
       return true;
     }
     }
   }
   }
@@ -334,7 +344,7 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
   { // Array case
   { // Array case
     if (const auto *arrType = dyn_cast<ConstantArrayType>(compositeType)) {
     if (const auto *arrType = dyn_cast<ConstantArrayType>(compositeType)) {
       for (auto *constituent : constituents)
       for (auto *constituent : constituents)
-        updateTypeForInstruction(constituent, arrType->getElementType());
+        tryToUpdateInstLitType(constituent, arrType->getElementType());
       return true;
       return true;
     }
     }
   }
   }
@@ -347,7 +357,7 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
         uint32_t colSize = 0;
         uint32_t colSize = 0;
         if (isVectorType(constituent->getAstResultType(), nullptr, &colSize)) {
         if (isVectorType(constituent->getAstResultType(), nullptr, &colSize)) {
           QualType newType = astContext.getExtVectorType(elemType, colSize);
           QualType newType = astContext.getExtVectorType(elemType, colSize);
-          updateTypeForInstruction(constituent, newType);
+          tryToUpdateInstLitType(constituent, newType);
         }
         }
       }
       }
       return true;
       return true;
@@ -359,7 +369,7 @@ bool LiteralTypeVisitor::updateTypeForCompositeMembers(
       const auto *decl = structType->getDecl();
       const auto *decl = structType->getDecl();
       size_t i = 0;
       size_t i = 0;
       for (const auto *field : decl->fields()) {
       for (const auto *field : decl->fields()) {
-        updateTypeForInstruction(constituents[i], field->getType());
+        tryToUpdateInstLitType(constituents[i], field->getType());
         ++i;
         ++i;
       }
       }
       return true;
       return true;
@@ -373,7 +383,7 @@ bool LiteralTypeVisitor::visit(SpirvAccessChain *inst) {
   for (auto *index : inst->getIndexes()) {
   for (auto *index : inst->getIndexes()) {
     if (auto *constInt = dyn_cast<SpirvConstantInteger>(index)) {
     if (auto *constInt = dyn_cast<SpirvConstantInteger>(index)) {
       if (!isLiteralLargerThan32Bits(constInt)) {
       if (!isLiteralLargerThan32Bits(constInt)) {
-        updateTypeForInstruction(
+        tryToUpdateInstLitType(
             constInt, constInt->getAstResultType()->isSignedIntegerType()
             constInt, constInt->getAstResultType()->isSignedIntegerType()
                           ? astContext.IntTy
                           ? astContext.IntTy
                           : astContext.UnsignedIntTy);
                           : astContext.UnsignedIntTy);
@@ -390,22 +400,22 @@ bool LiteralTypeVisitor::visit(SpirvExtInst *inst) {
   // OpExtInst %float %glsl_set Pow %float_2 %float_12
   // OpExtInst %float %glsl_set Pow %float_2 %float_12
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
   for (auto *operand : inst->getOperands())
   for (auto *operand : inst->getOperands())
-    updateTypeForInstruction(operand, resultType);
+    tryToUpdateInstLitType(operand, resultType);
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvReturn *inst) {
 bool LiteralTypeVisitor::visit(SpirvReturn *inst) {
   if (inst->hasReturnValue()) {
   if (inst->hasReturnValue()) {
-    updateTypeForInstruction(inst->getReturnValue(), curFnAstReturnType);
+    tryToUpdateInstLitType(inst->getReturnValue(), curFnAstReturnType);
   }
   }
   return true;
   return true;
 }
 }
 
 
 bool LiteralTypeVisitor::visit(SpirvCompositeInsert *inst) {
 bool LiteralTypeVisitor::visit(SpirvCompositeInsert *inst) {
   const auto resultType = inst->getAstResultType();
   const auto resultType = inst->getAstResultType();
-  updateTypeForInstruction(inst->getComposite(), resultType);
-  updateTypeForInstruction(inst->getObject(),
-                           getElementType(astContext, resultType));
+  tryToUpdateInstLitType(inst->getComposite(), resultType);
+  tryToUpdateInstLitType(inst->getObject(),
+                         getElementType(astContext, resultType));
   return true;
   return true;
 }
 }
 
 
@@ -413,7 +423,7 @@ bool LiteralTypeVisitor::visit(SpirvImageOp *inst) {
   if (inst->isImageWrite() && inst->hasAstResultType()) {
   if (inst->isImageWrite() && inst->hasAstResultType()) {
     const auto sampledType =
     const auto sampledType =
         hlsl::GetHLSLResourceResultType(inst->getAstResultType());
         hlsl::GetHLSLResourceResultType(inst->getAstResultType());
-    updateTypeForInstruction(inst->getTexelToWrite(), sampledType);
+    tryToUpdateInstLitType(inst->getTexelToWrite(), sampledType);
   }
   }
   return true;
   return true;
 }
 }

+ 5 - 2
tools/clang/lib/SPIRV/LiteralTypeVisitor.h

@@ -62,8 +62,11 @@ public:
   bool visitInstruction(SpirvInstruction *instr);
   bool visitInstruction(SpirvInstruction *instr);
 
 
 private:
 private:
-  /// Updates the result type of the given instruction to the new type.
-  void updateTypeForInstruction(SpirvInstruction *, QualType newType);
+  /// If the given instruction's return type is a literal type and the given
+  /// 'newType' is not a literal type, and they are of the same kind (both
+  /// integer or both float), updates the instruction's result type to newType.
+  /// Does nothing otherwise.
+  void tryToUpdateInstLitType(SpirvInstruction *, QualType newType);
 
 
   /// returns true if the given literal type can be deduced to the given
   /// returns true if the given literal type can be deduced to the given
   /// newType. In order for that to be true,
   /// newType. In order for that to be true,

+ 5 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.dot.hlsl

@@ -159,4 +159,9 @@ void main() {
     uint4 uj, uk;
     uint4 uj, uk;
     uint ul;
     uint ul;
     ul = dot(uj, uk);
     ul = dot(uj, uk);
+
+    // CHECK:      OpCompositeConstruct %v3float %float_1 %float_1 %float_1
+    // CHECK-NEXT: OpDot %float
+    float3 f3;
+    float dotProductByLiteral = dot(f3, float3(1.0.xxx));
 }
 }