Browse Source

Translate AddUint64. (#47)

Xiang Li 8 years ago
parent
commit
ac228c5aa4

+ 1 - 0
include/dxc/HLSL/DxilOperations.h

@@ -37,6 +37,7 @@ public:
   OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
 
   llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
+  // Returns the overload type of the DXIL operation function F for the given
+  // OpCode. F must not be null.
+  llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
   llvm::LLVMContext &GetCtx() { return m_Ctx; }
   llvm::Type *GetHandleType() const;
   llvm::Type *GetDimensionsType() const;

+ 105 - 0
lib/HLSL/DxilOperations.cpp

@@ -688,6 +688,111 @@ Function *OP::GetOpFunc(OpCode OpCode, Type *pOverloadType) {
   return F;
 }
 
+// Returns the overload type for the DXIL operation OpCode as declared by
+// function F. F must not be null (checked with DXASSERT).
+//
+// Depending on the opcode the result is the type of one of F's parameters, a
+// fixed type built from m_Ctx, or element 0 of F's struct return type. Opcodes
+// that are not listed in the switch yield F's return type.
+//
+// The cases between OPCODE-OLOAD-TYPES:BEGIN and OPCODE-OLOAD-TYPES:END are
+// tied to hctdb_instrhelp.get_funcs_oload_type() by the py::lines marker
+// below; change the generator rather than editing those cases by hand.
+//
+// NOTE(review): the parameter indices are not checked against
+// FT->getNumParams(), and cast<StructType> assumes a struct return type.
+// Confirm that callers only pass well-formed DXIL operation functions.
+llvm::Type *OP::GetOverloadType(OpCode OpCode, llvm::Function *F) {
+  DXASSERT(F, "not work on nullptr");
+  Type *Ty = F->getReturnType();
+  FunctionType *FT = F->getFunctionType();
+/* <py::lines('OPCODE-OLOAD-TYPES')>hctdb_instrhelp.get_funcs_oload_type()</py>*/
+  switch (OpCode) {            // return     OpCode
+  // OPCODE-OLOAD-TYPES:BEGIN
+  case OpCode::IsNaN:
+  case OpCode::IsInf:
+  case OpCode::IsFinite:
+  case OpCode::IsNormal:
+  case OpCode::Countbits:
+  case OpCode::FirstbitLo:
+  case OpCode::FirstbitHi:
+  case OpCode::FirstbitSHi:
+  case OpCode::IMul:
+  case OpCode::UMul:
+  case OpCode::UDiv:
+  case OpCode::IAddc:
+  case OpCode::UAddc:
+  case OpCode::ISubc:
+  case OpCode::USubc:
+  case OpCode::WaveActiveAllEqual:
+    return FT->getParamType(1);
+  case OpCode::TempRegStore:
+    return FT->getParamType(2);
+  case OpCode::MinPrecXRegStore:
+  case OpCode::StoreOutput:
+  case OpCode::BufferStore:
+  case OpCode::StorePatchConstant:
+    return FT->getParamType(4);
+  case OpCode::TextureStore:
+    return FT->getParamType(5);
+  case OpCode::MakeDouble:
+  case OpCode::SplitDouble:
+    return Type::getDoubleTy(m_Ctx);
+  case OpCode::CheckAccessFullyMapped:
+  case OpCode::AtomicBinOp:
+  case OpCode::AtomicCompareExchange:
+  case OpCode::SampleIndex:
+  case OpCode::Coverage:
+  case OpCode::InnerCoverage:
+  case OpCode::ThreadId:
+  case OpCode::GroupId:
+  case OpCode::ThreadIdInGroup:
+  case OpCode::FlattenedThreadIdInGroup:
+  case OpCode::GSInstanceID:
+  case OpCode::OutputControlPointID:
+  case OpCode::PrimitiveID:
+    return IntegerType::get(m_Ctx, 32);
+  case OpCode::CalculateLOD:
+  case OpCode::DomainLocation:
+    return Type::getFloatTy(m_Ctx);
+  case OpCode::CreateHandle:
+  case OpCode::BufferUpdateCounter:
+  case OpCode::GetDimensions:
+  case OpCode::Texture2DMSGetSamplePosition:
+  case OpCode::RenderTargetGetSamplePosition:
+  case OpCode::RenderTargetGetSampleCount:
+  case OpCode::Barrier:
+  case OpCode::Discard:
+  case OpCode::EmitStream:
+  case OpCode::CutStream:
+  case OpCode::EmitThenCutStream:
+  case OpCode::CycleCounterLegacy:
+  case OpCode::WaveIsFirstLane:
+  case OpCode::WaveGetLaneIndex:
+  case OpCode::WaveGetLaneCount:
+  case OpCode::WaveAnyTrue:
+  case OpCode::WaveAllTrue:
+  case OpCode::WaveActiveBallot:
+  case OpCode::BitcastI16toF16:
+  case OpCode::BitcastF16toI16:
+  case OpCode::BitcastI32toF32:
+  case OpCode::BitcastF32toI32:
+  case OpCode::BitcastI64toF64:
+  case OpCode::BitcastF64toI64:
+  case OpCode::LegacyF32ToF16:
+  case OpCode::LegacyF16ToF32:
+  case OpCode::LegacyDoubleToFloat:
+  case OpCode::LegacyDoubleToSInt32:
+  case OpCode::LegacyDoubleToUInt32:
+  case OpCode::WaveAllBitCount:
+  case OpCode::WavePrefixBitCount:
+    return Type::getVoidTy(m_Ctx);
+  case OpCode::CBufferLoadLegacy:
+  case OpCode::Sample:
+  case OpCode::SampleBias:
+  case OpCode::SampleLevel:
+  case OpCode::SampleGrad:
+  case OpCode::SampleCmp:
+  case OpCode::SampleCmpLevelZero:
+  case OpCode::TextureLoad:
+  case OpCode::BufferLoad:
+  case OpCode::TextureGather:
+  case OpCode::TextureGatherCmp:
+  {
+    StructType *ST = cast<StructType>(Ty);
+    return ST->getElementType(0);
+  }
+  // OPCODE-OLOAD-TYPES:END
+  default: return Ty;
+  }
+}
+
 // Returns the handle type stored in m_pHandleType.
 Type *OP::GetHandleType() const {
   return m_pHandleType;
 }

+ 1 - 71
lib/HLSL/DxilValidation.cpp

@@ -1774,76 +1774,6 @@ static void ValidateDxilOperationCallInProfile(CallInst *CI,
 
 }
 
-static Type *GetOverloadTyForDxilOperation(CallInst *CI, DXIL::OpCode opcode) {
-  Type *Ty = CI->getType();
-  if (Ty->isVoidTy()) {
-    switch (opcode) {
-    case DXIL::OpCode::StoreOutput:
-    case DXIL::OpCode::StorePatchConstant:
-      if (CI->getNumArgOperands() < DXIL::OperandIndex::kStoreOutputValOpIdx) {
-        // Will emit error later when cannot find valid dxil function.
-        return CI->getType();
-      }
-      return CI->getArgOperand(DXIL::OperandIndex::kStoreOutputValOpIdx)->getType();
-    case DXIL::OpCode::BufferStore:
-      return CI->getArgOperand(DXIL::OperandIndex::kBufferStoreVal0OpIdx)->getType();
-    case DXIL::OpCode::TextureStore:
-      return CI->getArgOperand(DXIL::OperandIndex::kTextureStoreVal0OpIdx)->getType();
-    default:
-      return Ty;
-    }
-  } else if (Ty->isAggregateType()) {
-    switch (opcode) {
-    case DXIL::OpCode::CreateHandle:
-    case DXIL::OpCode::GetDimensions:
-    case DXIL::OpCode::Texture2DMSGetSamplePosition:
-    case DXIL::OpCode::RenderTargetGetSamplePosition:
-      return Type::getVoidTy(CI->getContext());
-    case DXIL::OpCode::CBufferLoadLegacy:
-    case DXIL::OpCode::BufferLoad:
-    case DXIL::OpCode::TextureLoad:
-    case DXIL::OpCode::Sample:
-    case DXIL::OpCode::SampleBias:
-    case DXIL::OpCode::SampleCmp:
-    case DXIL::OpCode::SampleCmpLevelZero:
-    case DXIL::OpCode::SampleGrad:
-    case DXIL::OpCode::SampleLevel:
-    case DXIL::OpCode::TextureGather:
-    case DXIL::OpCode::TextureGatherCmp:
-    {
-      StructType *ST = cast<StructType>(CI->getType());
-      return ST->getElementType(0);
-    }
-    case DXIL::OpCode::SplitDouble:
-      return CI->getArgOperand(DXIL::OperandIndex::kUnarySrc0OpIdx)->getType();
-    default:
-      return Ty;
-    }
-  } else
-    switch (opcode) {
-    case DXIL::OpCode::BufferUpdateCounter:
-    case DXIL::OpCode::RenderTargetGetSampleCount:
-      return Type::getVoidTy(CI->getContext());
-    case DXIL::OpCode::CheckAccessFullyMapped:
-      return Type::getInt32Ty(CI->getContext());
-    case DXIL::OpCode::IsFinite:
-    case DXIL::OpCode::IsInf:
-    case DXIL::OpCode::IsNaN:
-    case DXIL::OpCode::IsNormal:
-      return CI->getArgOperand(DXIL::OperandIndex::kUnarySrc0OpIdx)->getType();
-    case DXIL::OpCode::WaveActiveAllEqual:
-      // TODO: build this whole function from hctdb.py
-      return CI->getArgOperand(1)->getType()->getScalarType();
-    case DXIL::OpCode::Countbits:
-    case DXIL::OpCode::FirstbitLo:
-    case DXIL::OpCode::FirstbitHi:
-    case DXIL::OpCode::FirstbitSHi:
-      return CI->getArgOperand(DXIL::OperandIndex::kUnarySrc0OpIdx)->getType()->getScalarType();
-    default:
-      return Ty;
-    }
-}
-
 static bool IsDxilFunction(llvm::Function *F) {
   unsigned argSize = F->getArgumentList().size();
   if (argSize < 1) {
@@ -1896,7 +1826,7 @@ static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) {
       dxilFunc = hlslOP->GetOpFunc(dxilOpcode, voidTy);
     }
     else {
-      Type *Ty = GetOverloadTyForDxilOperation(CI, dxilOpcode);
+      Type *Ty = hlslOP->GetOverloadType(dxilOpcode, CI->getCalledFunction());
       try {
         if (!hlslOP->IsOverloadLegal(dxilOpcode, Ty)) {
           ValCtx.EmitInstrError(CI, ValidationRule::InstrOload);

+ 49 - 1
lib/HLSL/HLOperationLower.cpp

@@ -354,6 +354,54 @@ Value *TranslateD3DColorToUByte4(CallInst *CI, IntrinsicOp IOP,
   return Builder.CreateBitCast(byte4, CI->getType());
 }
 
+// Lowers a call to the HLSL AddUint64 intrinsic.
+//
+// The operands are treated as vectors of (low, high) 32-bit pairs. For every
+// pair the low halves are added with the DXIL UAddc operation, and the carry
+// it reports is added into the sum of the high halves.
+//
+// Operands that are not 2- or 4-element vectors produce a diagnostic on CI's
+// context and an undef value of the operand type is returned. IOP, opcode,
+// pObjHelper and Translated are not used by this lowering.
+Value *TranslateAddUint64(CallInst *CI, IntrinsicOp IOP,
+                          OP::OpCode opcode,
+                          HLOperationLowerHelper &helper,
+                          HLObjectOperationLowerHelper *pObjHelper,
+                          bool &Translated) {
+  Value *firstSrc = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
+  Type *srcTy = firstSrc->getType();
+
+  // Only vectors made of whole (low, high) pairs are accepted; a non-vector
+  // operand is folded into the same rejection path by giving it zero lanes.
+  VectorType *vecTy = dyn_cast<VectorType>(srcTy);
+  const unsigned laneCount = vecTy ? vecTy->getNumElements() : 0;
+  if (laneCount != 2 && laneCount != 4) {
+    CI->getContext().emitError(
+        CI, "AddUint64 can only be applied to uint2 and uint4 operands");
+    return UndefValue::get(srcTy);
+  }
+
+  hlsl::OP *dxilOps = &helper.hlslOP;
+  IRBuilder<> IRB(CI);
+  Value *lhs = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
+  Value *rhs = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+
+  // The UAddc declaration is requested only after validation succeeded.
+  Function *uaddcFn = dxilOps->GetOpFunc(DXIL::OpCode::UAddc, helper.i32Ty);
+  Value *uaddcOpcode =
+      IRB.getInt32(static_cast<unsigned>(DXIL::OpCode::UAddc));
+
+  Value *result = UndefValue::get(srcTy);
+  for (unsigned lo = 0, hi = 1; lo < laneCount; lo += 2, hi += 2) {
+    // Low halves: field 0 of the UAddc result is the sum, field 1 the carry.
+    Value *lhsLo = IRB.CreateExtractElement(lhs, lo);
+    Value *rhsLo = IRB.CreateExtractElement(rhs, lo);
+    Value *sumAndCarry = IRB.CreateCall(uaddcFn, {uaddcOpcode, lhsLo, rhsLo});
+    Value *loSum = IRB.CreateExtractValue(sumAndCarry, 0);
+    result = IRB.CreateInsertElement(result, loSum, lo);
+
+    // Widen the carry flag to i32 so it can be added to the high halves.
+    Value *carryFlag = IRB.CreateExtractValue(sumAndCarry, 1);
+    Value *carry32 = IRB.CreateZExt(carryFlag, helper.i32Ty);
+
+    // High halves: plain add, then fold in the carry from the low add.
+    Value *lhsHi = IRB.CreateExtractElement(lhs, hi);
+    Value *rhsHi = IRB.CreateExtractElement(rhs, hi);
+    Value *hiSum = IRB.CreateAdd(lhsHi, rhsHi);
+    Value *hiWithCarry = IRB.CreateAdd(hiSum, carry32);
+    result = IRB.CreateInsertElement(result, hiWithCarry, hi);
+  }
+  return result;
+}
+
+
 CallInst *ValidateLoadInput(Value *V) {
   // Must be load input.
   CallInst *CI = cast<CallInst>(V);
@@ -3637,7 +3685,7 @@ Value *StreamOutputLower(CallInst *CI, IntrinsicOp IOP, DXIL::OpCode opcode,
 }
 
 IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] = {
-    {IntrinsicOp::IOP_AddUint64,  EmptyLower,  DXIL::OpCode::NumOpCodes},
+    {IntrinsicOp::IOP_AddUint64,  TranslateAddUint64,  DXIL::OpCode::UAddc},
     {IntrinsicOp::IOP_AllMemoryBarrier, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync, TrivialBarrier, DXIL::OpCode::Barrier},
     {IntrinsicOp::IOP_CheckAccessFullyMapped, TrivialUnaryOperation, DXIL::OpCode::CheckAccessFullyMapped},

+ 10 - 0
tools/clang/test/CodeGenHLSL/AddUint64.hlsl

@@ -0,0 +1,10 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: binaryWithCarry
+// CHECK: binaryWithCarry
+
+// Calls AddUint64 with both operand shapes it supports: a uint2 pair (a.xy,
+// b.xy) and a full uint4. Both results feed the return value.
+float4 main(uint4 a : A, uint4 b :B) : SV_TARGET {
+  uint2 c2 = AddUint64(a.xy, b.xy);
+  uint4 c4 = AddUint64(a, b);
+  return c2.xxyy + c4;
+}

+ 9 - 0
tools/clang/test/CodeGenHLSL/AddUint64Odd.hlsl

@@ -0,0 +1,9 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: AddUint64 can only be applied to uint2 and uint4 operands
+
+// Calls AddUint64 with a scalar and with a uint3 operand, i.e. operand shapes
+// other than uint2/uint4, to trigger the diagnostic this file CHECKs for.
+float4 main(uint4 a : A, uint4 b :B) : SV_TARGET {
+  uint c = AddUint64(a.x, b.x);
+  uint3 c3 = AddUint64(a.xyz, b.xyz);
+  return c + c3.xyzz;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -294,6 +294,7 @@ public:
 
   TEST_METHOD(CodeGenAbs1)
   TEST_METHOD(CodeGenAbs2)
+  TEST_METHOD(CodeGenAddUint64)
   TEST_METHOD(CodeGenArrayArg)
   TEST_METHOD(CodeGenArrayOfStruct)
   TEST_METHOD(CodeGenAsUint)
@@ -1694,6 +1695,10 @@ TEST_F(CompilerTest, CodeGenAbs2) {
   CodeGenTest(L"..\\CodeGenHLSL\\abs2.hlsl");
 }
 
+// Runs CodeGenTestCheck on the AddUint64.hlsl shader.
+TEST_F(CompilerTest, CodeGenAddUint64) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\AddUint64.hlsl");
+}
+
 // Runs CodeGenTest on the arrayArg.hlsl shader.
 TEST_F(CompilerTest, CodeGenArrayArg){
   CodeGenTest(L"..\\CodeGenHLSL\\arrayArg.hlsl");
 }

+ 5 - 0
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -125,6 +125,7 @@ public:
   TEST_METHOD(I8Type)
   TEST_METHOD(EmptyStructInBuffer)
   TEST_METHOD(BigStructInBuffer)
+  TEST_METHOD(AddUint64Odd)
 
   TEST_METHOD(ClipCullMaxComponents)
   TEST_METHOD(ClipCullMaxRows)
@@ -1299,6 +1300,10 @@ TEST_F(ValidationTest, BigStructInBuffer) {
   TestCheck(L"..\\CodeGenHLSL\\BigStructInBuffer.hlsl");
 }
 
+// Runs TestCheck on AddUint64Odd.hlsl, which applies AddUint64 to scalar and
+// uint3 operands.
+TEST_F(ValidationTest, AddUint64Odd) {
+  TestCheck(L"..\\CodeGenHLSL\\AddUint64Odd.hlsl");
+}
+
 // Runs TestCheck on val-wave-failures-ps.hlsl.
 TEST_F(ValidationTest, WhenWaveAffectsGradientThenFail) {
   TestCheck(L"val-wave-failures-ps.hlsl");
 }

+ 102 - 1
utils/hct/hctdb_instrhelp.py

@@ -394,6 +394,102 @@ class db_oload_gen:
             line = line + "break;"
             print(line)
     
+    def print_opfunc_oload_type(self):
+        # Print the function for OP::GetOverloadType
+        elt_ty = "$o"
+        res_ret_ty = "$r"
+        cb_ret_ty = "$cb"
+
+        last_category = None
+
+        index_dict = {}
+        single_dict = {}
+        struct_list = []
+
+        for instr in self.instrs:
+            ret_ty = instr.ops[0].llvm_type
+            # Skip case return type is overload type
+            if (ret_ty == elt_ty):
+                continue
+
+            if ret_ty == res_ret_ty:
+                struct_list.append(instr.name)
+                continue
+
+            if ret_ty == cb_ret_ty:
+                struct_list.append(instr.name)
+                continue
+
+            in_param_ty = False
+            # Try to find elt_ty in parameter types.
+            for index, op in enumerate(instr.ops):
+                # Skip return type.
+                if (op.pos == 0):
+                    continue
+                # Skip dxil opcode.
+                if (op.pos == 1):
+                    continue
+
+                op_type = op.llvm_type
+                if (op_type == elt_ty):
+                    # Skip return op
+                    index = index - 1
+                    if index not in index_dict:
+                        index_dict[index] = [instr.name]
+                    else:
+                        index_dict[index].append(instr.name)
+                    in_param_ty = True
+                    break
+
+            if in_param_ty:
+                continue
+
+            # No overload, just return the single oload_type.
+            assert len(instr.oload_types)==1, "overload no elt_ty %s" % (instr.name)
+            ty = instr.oload_types[0]
+            type_code_texts = {
+            "d": "Type::getDoubleTy(m_Ctx)",
+            "f": "Type::getFloatTy(m_Ctx)",
+            "h": "Type::getHalfTy",
+            "1": "IntegerType::get(m_Ctx, 1)",
+			"8": "IntegerType::get(m_Ctx, 8)",
+            "w": "IntegerType::get(m_Ctx, 16)",
+            "i": "IntegerType::get(m_Ctx, 32)",
+            "l": "IntegerType::get(m_Ctx, 64)",
+            "v": "Type::getVoidTy(m_Ctx)",
+            }
+            assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
+            ty_code = type_code_texts[ty]
+
+            if ty_code not in single_dict:
+                single_dict[ty_code] = [instr.name]
+            else:
+                single_dict[ty_code].append(instr.name)
+
+        for index, opcodes in index_dict.items():
+            line = ""
+            for opcode in opcodes:
+                line = line + "case OpCode::{name}".format(name = opcode + ":\n")
+
+            line = line + "  return FT->getParamType(" + str(index) + ");"
+            print(line)
+
+        for code, opcodes in single_dict.items():
+            line = ""
+            for opcode in opcodes:
+                line = line + "case OpCode::{name}".format(name = opcode + ":\n")
+            line = line + "  return " + code + ";"
+            print(line)
+
+        line = ""
+        for opcode in struct_list:
+            line = line + "case OpCode::{name}".format(name = opcode + ":\n")
+        line = line + "{\n"
+        line = line + "  StructType *ST = cast<StructType>(Ty);\n"
+        line = line + "  return ST->getElementType(0);\n"
+        line = line + "}"
+        print(line)
+
 
 class db_valfns_gen:
     "A generator of validation functions."
@@ -594,12 +690,17 @@ def get_oloads_props():
     db = get_db_dxil()
     gen = db_oload_gen(db)
     return run_with_stdout(lambda: gen.print_opfunc_props())
-        
+
 def get_oloads_funcs():
     "Return the text printed by db_oload_gen.print_opfunc_table for the DXIL database."
     table_gen = db_oload_gen(get_db_dxil())
     return run_with_stdout(lambda: table_gen.print_opfunc_table())
 
+def get_funcs_oload_type():
+    db = get_db_dxil()
+    gen = db_oload_gen(db)
+    return run_with_stdout(lambda: gen.print_opfunc_oload_type())
+
 def get_enum_decl(name, **kwargs):
     db = get_db_dxil()
     gen = db_enumhelp_gen(db)