Browse Source

Fix matrix unary and binary operators and add a test. (#1782)

Various matrix operators had implementation issues and would either crash the compiler or produce incorrect results. This introduces a test invoking every unary and binary operator on matrices and fixes the misbehaving ones, with the exception of the increment and decrement operators (to be handled in a different change).
Tristan Labelle 6 years ago
parent
commit
23cf9fe2dd

+ 41 - 43
lib/HLSL/HLMatrixLowerPass.cpp

@@ -670,47 +670,53 @@ Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
   HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
   bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
 
-  auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
-    Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
-                            : ConstantInt::get(Ty->getScalarType(), v);
-    std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
-    return ConstantVector::get(vals);
-  };
-
-  Constant *one = GetVecConst(ResultTy, 1);
+  Constant *one = isFloat
+    ? ConstantFP::get(ResultTy->getVectorElementType(), 1)
+    : ConstantInt::get(ResultTy->getVectorElementType(), 1);
+  Constant *oneVec = ConstantVector::getSplat(ResultTy->getVectorNumElements(), one);
 
   Instruction *Result = nullptr;
   switch (opcode) {
+  case HLUnaryOpcode::Plus: {
+    // This is actually a no-op, but the structure of the code here requires
+    // that we create an instruction.
+    Constant *zero = Constant::getNullValue(ResultTy);
+    if (isFloat)
+      Result = BinaryOperator::CreateFAdd(tmp, zero);
+    else
+      Result = BinaryOperator::CreateAdd(tmp, zero);
+  } break;
   case HLUnaryOpcode::Minus: {
-    Constant *zero = GetVecConst(ResultTy, 0);
+    Constant *zero = Constant::getNullValue(ResultTy);
     if (isFloat)
       Result = BinaryOperator::CreateFSub(zero, tmp);
     else
       Result = BinaryOperator::CreateSub(zero, tmp);
   } break;
   case HLUnaryOpcode::LNot: {
-    Constant *zero = GetVecConst(ResultTy, 0);
+    Constant *zero = Constant::getNullValue(ResultTy);
     if (isFloat)
-      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
+      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UEQ, tmp, zero);
     else
-      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
+      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, zero);
+  } break;
+  case HLUnaryOpcode::Not: {
+    Constant *allOneBits = Constant::getAllOnesValue(ResultTy);
+    Result = BinaryOperator::CreateXor(tmp, allOneBits);
   } break;
-  case HLUnaryOpcode::Not:
-    Result = BinaryOperator::CreateXor(tmp, tmp);
-    break;
   case HLUnaryOpcode::PostInc:
   case HLUnaryOpcode::PreInc:
     if (isFloat)
-      Result = BinaryOperator::CreateFAdd(tmp, one);
+      Result = BinaryOperator::CreateFAdd(tmp, oneVec);
     else
-      Result = BinaryOperator::CreateAdd(tmp, one);
+      Result = BinaryOperator::CreateAdd(tmp, oneVec);
     break;
   case HLUnaryOpcode::PostDec:
   case HLUnaryOpcode::PreDec:
     if (isFloat)
-      Result = BinaryOperator::CreateFSub(tmp, one);
+      Result = BinaryOperator::CreateFSub(tmp, oneVec);
     else
-      Result = BinaryOperator::CreateSub(tmp, one);
+      Result = BinaryOperator::CreateSub(tmp, oneVec);
     break;
   default:
     DXASSERT(0, "not implement");
@@ -842,33 +848,25 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     break;
   case HLBinaryOpcode::LAnd:
   case HLBinaryOpcode::LOr: {
-    Constant *zero;
-    if (isFloat)
-      zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
-    else
-      zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
-
-    unsigned size = ResultTy->getVectorNumElements();
-    std::vector<Constant *> zeros(size, zero);
-    Value *vecZero = llvm::ConstantVector::get(zeros);
+    Value *vecZero = Constant::getNullValue(ResultTy);
     Instruction *cmpL;
     if (isFloat)
-      cmpL =
-          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
+      cmpL = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
     else
-      cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
+      cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
     Builder.Insert(cmpL);
 
     Instruction *cmpR;
     if (isFloat)
       cmpR =
-          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
+          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, vecZero);
     else
-      cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
+      cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, vecZero);
     Builder.Insert(cmpR);
+
     // How to map l, r back? Need check opcode
     if (opcode == HLBinaryOpcode::LOr)
-      Result = BinaryOperator::CreateAnd(cmpL, cmpR);
+      Result = BinaryOperator::CreateOr(cmpL, cmpR);
     else
       Result = BinaryOperator::CreateAnd(cmpL, cmpR);
     break;
@@ -996,23 +994,23 @@ void HLMatrixLowerPass::TrivialMatUnOpReplace(Value *matVal,
   HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
   Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
   switch (opcode) {
-  case HLUnaryOpcode::Not:
-    // Not is xor now
-    vecUseInst->setOperand(0, vecVal);
-    vecUseInst->setOperand(1, vecVal);
-    break;
-  case HLUnaryOpcode::LNot:
+  case HLUnaryOpcode::Plus: // add(x, 0)
+    // Ideally we'd get rid of the instruction for +mat completely,
+    // but matToVecMap needs to point to some instruction.
+  case HLUnaryOpcode::Not: // xor(x, -1)
+  case HLUnaryOpcode::LNot: // cmpeq(x, 0)
   case HLUnaryOpcode::PostInc:
   case HLUnaryOpcode::PreInc:
   case HLUnaryOpcode::PostDec:
   case HLUnaryOpcode::PreDec:
     vecUseInst->setOperand(0, vecVal);
     break;
+  case HLUnaryOpcode::Minus: // sub(0, x)
+    vecUseInst->setOperand(1, vecVal);
+    break;
   case HLUnaryOpcode::Invalid:
-  case HLUnaryOpcode::Plus:
-  case HLUnaryOpcode::Minus:
   case HLUnaryOpcode::NumOfUO:
-    // No VecInst replacements for these.
+    DXASSERT(false, "Unexpected HL unary opcode.");
     break;
   }
 }

+ 131 - 0
tools/clang/test/CodeGenHLSL/quick-test/matrix_operators.hlsl

@@ -0,0 +1,131 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Tests the implementation of unary and binary matrix operators
+
+// Workaround for AppendStructuredBuffer<matrix>.Append bug
+#define Append(buf, val) buf[buf.IncrementCounter()] = (val)
+
+RWStructuredBuffer<int1x1> output_i;
+RWStructuredBuffer<uint1x1> output_u;
+RWStructuredBuffer<float1x1> output_f;
+RWStructuredBuffer<bool1x1> output_b;
+
+void main()
+{
+    int1x1 i1 = int1x1(1);
+    int1x1 i2 = int1x1(2);
+    int1x1 i3 = int1x1(3);
+    int1x1 im1 = int1x1(-1);
+    int1x1 im3 = int1x1(-3);
+    uint1x1 u1 = uint1x1(1);
+    uint1x1 u2 = uint1x1(2);
+    uint1x1 u3 = uint1x1(3);
+    uint1x1 um1 = uint1x1((uint)(-1));
+    float1x1 fm0_5 = float1x1(-0.5);
+    float1x1 f0_5 = float1x1(0.5);
+    float1x1 f1 = float1x1(1);
+    float1x1 f1_5 = float1x1(1.5);
+    float1x1 f2 = float1x1(2);
+
+    // Unary operators, except pre/post inc/dec
+    // CHECK: i32 3, i32 undef
+    Append(output_i, +i3); // Plus
+    // CHECK: i32 -3, i32 undef
+    Append(output_i, -i3); // Minus
+    // CHECK: i32 -4, i32 undef
+    Append(output_i, ~i3); // Not
+    // CHECK: i32 0, i32 undef
+    Append(output_b, !i3); // LNot
+    
+    // CHECK: float 5.000000e-01, float undef
+    Append(output_f, +f0_5); // Plus
+    // CHECK: float -5.000000e-01, float undef
+    Append(output_f, -f0_5); // Minus
+    // CHECK: i32 0, i32 undef
+    Append(output_b, !f0_5); // LNot
+
+    // Binary operators
+    // CHECK: i32 6, i32 undef
+    Append(output_i, i3 * i2); // Mul
+    // CHECK: i32 -1, i32 undef
+    Append(output_i, im3 / i2); // Div
+    // CHECK: i32 -1, i32 undef
+    Append(output_i, im3 % i2); // Rem
+    // CHECK: i32 3, i32 undef
+    Append(output_i, i1 + i2); // Add
+    // CHECK: i32 2, i32 undef
+    Append(output_i, i3 - i1); // Sub
+
+    // CHECK: float 1.000000e+00, float undef
+    Append(output_f, f0_5 * f2); // Mul
+    // CHECK: float 2.000000e+00, float undef
+    Append(output_f, f1 / f0_5); // Div
+    // CHECK: float 5.000000e-01, float undef
+    Append(output_f, f2 % f1_5); // Rem
+    // CHECK: float 2.000000e+00, float undef
+    Append(output_f, f0_5 + f1_5); // Add
+    // CHECK: float -1.000000e+00, float undef
+    Append(output_f, f0_5 - f1_5); // Sub
+
+    // CHECK: i32 6, i32 undef
+    Append(output_i, i3 << i1); // Shl
+    // CHECK: i32 -1, i32 undef
+    Append(output_i, im1 >> i1); // Shr
+    // CHECK: i32 2, i32 undef
+    Append(output_i, i3 & i2); // And
+    // CHECK: i32 2, i32 undef
+    Append(output_i, i3 ^ i1); // Xor
+    // CHECK: i32 3, i32 undef
+    Append(output_i, i2 | i1); // Or
+
+    // CHECK: i32 1, i32 undef
+    Append(output_b, i3 && i2); // LAnd
+    // CHECK: i32 1, i32 undef
+    Append(output_b, i3 || i2); // LOr
+    
+    // CHECK: i32 1, i32 undef
+    Append(output_b, f0_5 && f1_5); // LAnd
+    // CHECK: i32 1, i32 undef
+    Append(output_b, f0_5 || f1_5); // LOr
+
+    // CHECK: i32 1, i32 undef
+    Append(output_u, u3 / u2); // UDiv (3 / 2 truncates to 1)
+    // CHECK: i32 1, i32 undef
+    Append(output_u, u3 % u2); // URem
+    // CHECK: i32 2147483647, i32 undef
+    Append(output_u, um1 >> u1); // UShr
+
+    // CHECK: i32 1, i32 undef
+    Append(output_b, im1 < i1); // LT
+    // CHECK: i32 0, i32 undef
+    Append(output_b, im1 > i1); // GT
+    // CHECK: i32 1, i32 undef
+    Append(output_b, im1 <= i1); // LE
+    // CHECK: i32 0, i32 undef
+    Append(output_b, im1 >= i1); // GE
+    // CHECK: i32 0, i32 undef
+    Append(output_b, im1 == i1); // EQ
+    // CHECK: i32 1, i32 undef
+    Append(output_b, im1 != i1); // NE
+    // CHECK: i32 0, i32 undef
+    Append(output_b, um1 < u1); // ULT
+    // CHECK: i32 1, i32 undef
+    Append(output_b, um1 > u1); // UGT
+    // CHECK: i32 0, i32 undef
+    Append(output_b, um1 <= u1); // ULE
+    // CHECK: i32 1, i32 undef
+    Append(output_b, um1 >= u1); // UGE
+    
+    // CHECK: i32 1, i32 undef
+    Append(output_b, fm0_5 < f1_5); // LT
+    // CHECK: i32 0, i32 undef
+    Append(output_b, fm0_5 > f1_5); // GT
+    // CHECK: i32 1, i32 undef
+    Append(output_b, fm0_5 <= f1_5); // LE
+    // CHECK: i32 0, i32 undef
+    Append(output_b, fm0_5 >= f1_5); // GE
+    // CHECK: i32 0, i32 undef
+    Append(output_b, fm0_5 == f1_5); // EQ
+    // CHECK: i32 1, i32 undef
+    Append(output_b, fm0_5 != f1_5); // NE
+}