2
0
Эх сурвалжийг харах

[spirv] Fix handling of float1 in vector scaling and matrix scaling.

Ehsan Nasiri 7 жил өмнө
parent
commit
a79eeb2962

+ 11 - 6
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -5570,11 +5570,13 @@ SpirvEvalInfo
 SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
   const QualType type = expr->getType();
   const SourceRange range = expr->getSourceRange();
+  QualType elemType = {};
 
   // We can only translate floatN * float into OpVectorTimesScalar.
-  // So the result type must be floatN.
-  if (!hlsl::IsHLSLVecType(type) ||
-      !hlsl::GetHLSLVecElementType(type)->isFloatingType())
+  // So the result type must be floatN. Note that float1 is not a valid vector
+  // in SPIR-V.
+  if (!(TypeTranslator::isVectorType(type, &elemType) &&
+        elemType->isFloatingType()))
     return 0;
 
   const Expr *lhs = expr->getLHS();
@@ -5627,10 +5629,13 @@ SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
   const QualType type = expr->getType();
   const SourceRange range = expr->getSourceRange();
 
-  // We can only translate floatMxN * float into OpMatrixTimesScalar.
-  // So the result type must be floatMxN.
+  // We translate 'floatMxN * float' into OpMatrixTimesScalar.
+  // We translate 'floatMx1 * float' and 'float1xN * float' using
+  // OpVectorTimesScalar.
+  // So the result type can be floatMxN, floatMx1, or float1xN.
   if (!hlsl::IsHLSLMatType(type) ||
-      !hlsl::GetHLSLMatElementType(type)->isFloatingType())
+      !hlsl::GetHLSLMatElementType(type)->isFloatingType() ||
+      TypeTranslator::is1x1Matrix(type))
     return 0;
 
   const Expr *lhs = expr->getLHS();

+ 1 - 5
tools/clang/test/CodeGenSPIRV/binary-op.arith-assign.mixed.form.hlsl

@@ -50,13 +50,9 @@ void main() {
     i *= s;
 
     // Use OpVectorTimesScalar for float1xN * float
-    // Sadly, the AST is constructed differently for 'float1xN *= float' cases.
-    // So we are not able generate an OpVectorTimesScalar here.
-    // TODO: Minor issue. Fix this later maybe.
 // CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
-// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v3float [[s6]] [[s6]] [[s6]]
 // CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
-// CHECK-NEXT: [[mul10:%\d+]] = OpFMul %v3float [[k0]] [[cc1]]
+// CHECK-NEXT: [[mul10:%\d+]] = OpVectorTimesScalar %v3float [[k0]] [[s6]]
 // CHECK-NEXT: OpStore %k [[mul10]]
     k *= s;