Browse Source

Sundry code cleanups (#3073)

refactor usage of resourceparams to just use RP

create common way of extracting res properties

make getoverloadtype static

so it can be used to determine version
requirements

Disallow literals where not included

Previously, whether an intrinsic parameter type included an entry for
literals or not, a literal was permitted to be cast to it.

By moving up the explicit per-param check to before this cast takes
place, these implicit casts won't be added
Greg Roth 5 years ago
parent
commit
efc14a6c40

+ 1 - 1
include/dxc/DXIL/DxilOperations.h

@@ -47,7 +47,6 @@ public:
   llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
   const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &GetOpFuncList(OpCode OpCode) const;
   void RemoveFunction(llvm::Function *F);
-  llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
   llvm::LLVMContext &GetCtx() { return m_Ctx; }
   llvm::Type *GetHandleType() const;
   llvm::Type *GetResourcePropertiesType() const;
@@ -87,6 +86,7 @@ public:
   llvm::Constant *GetFloatConst(float v);
   llvm::Constant *GetDoubleConst(double v);
 
+  static llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
   static OpCode GetDxilOpFuncCallInst(const llvm::Instruction *I);
   static const char *GetOpCodeName(OpCode OpCode);
   static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);

+ 6 - 5
lib/DXIL/DxilOperations.cpp

@@ -1430,6 +1430,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   DXASSERT(F, "not work on nullptr");
   Type *Ty = F->getReturnType();
   FunctionType *FT = F->getFunctionType();
+  LLVMContext &Ctx = F->getContext();
 /* <py::lines('OPCODE-OLOAD-TYPES')>hctdb_instrhelp.get_funcs_oload_type()</py>*/
   switch (opCode) {            // return     OpCode
   // OPCODE-OLOAD-TYPES:BEGIN
@@ -1521,7 +1522,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::RayQuery_CommitProceduralPrimitiveHit:
   case OpCode::CreateHandleFromHeap:
   case OpCode::AnnotateHandle:
-    return Type::getVoidTy(m_Ctx);
+    return Type::getVoidTy(Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:
   case OpCode::AtomicCompareExchange:
@@ -1559,7 +1560,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::GeometryIndex:
   case OpCode::RayQuery_CandidateInstanceContributionToHitGroupIndex:
   case OpCode::RayQuery_CommittedInstanceContributionToHitGroupIndex:
-    return IntegerType::get(m_Ctx, 32);
+    return IntegerType::get(Ctx, 32);
   case OpCode::CalculateLOD:
   case OpCode::DomainLocation:
   case OpCode::WorldRayOrigin:
@@ -1585,15 +1586,15 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::RayQuery_CandidateObjectRayDirection:
   case OpCode::RayQuery_CommittedObjectRayOrigin:
   case OpCode::RayQuery_CommittedObjectRayDirection:
-    return Type::getFloatTy(m_Ctx);
+    return Type::getFloatTy(Ctx);
   case OpCode::MakeDouble:
   case OpCode::SplitDouble:
-    return Type::getDoubleTy(m_Ctx);
+    return Type::getDoubleTy(Ctx);
   case OpCode::RayQuery_Proceed:
   case OpCode::RayQuery_CandidateProceduralPrimitiveNonOpaque:
   case OpCode::RayQuery_CandidateTriangleFrontFace:
   case OpCode::RayQuery_CommittedTriangleFrontFace:
-    return IntegerType::get(m_Ctx, 1);
+    return IntegerType::get(Ctx, 1);
   case OpCode::CBufferLoadLegacy:
   case OpCode::Sample:
   case OpCode::SampleBias:

+ 56 - 64
lib/DXIL/DxilShaderFlags.cpp

@@ -215,6 +215,54 @@ static CallInst *FindCallToCreateHandle(Value *handleType) {
   return CI;
 }
 
+DxilResourceProperties GetResourcePropertyFromHandleCall(const hlsl::DxilModule *M, CallInst *handleCall) {
+
+  DxilResourceProperties RP = {};
+  RP.Class = DXIL::ResourceClass::Invalid;
+  RP.Kind = DXIL::ResourceKind::Invalid;
+
+  ConstantInt *HandleOpCodeConst = cast<ConstantInt>(
+      handleCall->getArgOperand(DXIL::OperandIndex::kOpcodeIdx));
+  DXIL::OpCode handleOp = static_cast<DXIL::OpCode>(HandleOpCodeConst->getLimitedValue());
+  if (handleOp == DXIL::OpCode::CreateHandle) {
+    if (ConstantInt *resClassArg =
+      dyn_cast<ConstantInt>(handleCall->getArgOperand(
+        DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
+      DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
+        resClassArg->getLimitedValue());
+      ConstantInt *rangeID = GetArbitraryConstantRangeID(handleCall);
+      if (rangeID) {
+        DxilResource resource;
+        if (resClass == DXIL::ResourceClass::UAV)
+          resource = M->GetUAV(rangeID->getLimitedValue());
+        else if (resClass == DXIL::ResourceClass::SRV)
+          resource = M->GetSRV(rangeID->getLimitedValue());
+        RP = resource_helper::loadFromResourceBase(&resource);
+      }
+    }
+  }
+  else if (handleOp == DXIL::OpCode::CreateHandleForLib) {
+    // If library handle, find DxilResource by checking the name
+    if (LoadInst *LI = dyn_cast<LoadInst>(handleCall->getArgOperand(
+            DXIL::OperandIndex::kCreateHandleForLibResOpIdx))) {
+      Value *resType = LI->getOperand(0);
+      for (auto &&res : M->GetUAVs()) {
+        if (res->GetGlobalSymbol() == resType) {
+          RP = resource_helper::loadFromResourceBase(res.get());
+        }
+      }
+    }
+  } else if (handleOp == DXIL::OpCode::AnnotateHandle) {
+    DxilInst_AnnotateHandle annotateHandle(cast<Instruction>(handleCall));
+    Type *ResPropTy = M->GetOP()->GetResourcePropertiesType();
+
+    RP = resource_helper::loadFromAnnotateHandle(annotateHandle, ResPropTy, *M->GetShaderModel());
+  }
+
+  return RP;
+}
+
+
 ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
                                            const hlsl::DxilModule *M) {
   ShaderFlags flag;
@@ -315,71 +363,15 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F,
           CallInst *handleCall = FindCallToCreateHandle(resHandle);
           // Check if this is a library handle or general create handle
           if (handleCall) {
-            ConstantInt *HandleOpCodeConst = cast<ConstantInt>(
-                handleCall->getArgOperand(DXIL::OperandIndex::kOpcodeIdx));
-            DXIL::OpCode handleOp = static_cast<DXIL::OpCode>(HandleOpCodeConst->getLimitedValue());
-            if (handleOp == DXIL::OpCode::CreateHandle) {
-              if (ConstantInt *resClassArg =
-                dyn_cast<ConstantInt>(handleCall->getArgOperand(
-                  DXIL::OperandIndex::kCreateHandleResClassOpIdx))) {
-                DXIL::ResourceClass resClass = static_cast<DXIL::ResourceClass>(
-                  resClassArg->getLimitedValue());
-                if (resClass == DXIL::ResourceClass::UAV) {
-                  // Validator 1.0 assumes that all uav load is multi component load.
-                  if (hasMulticomponentUAVLoadsBackCompat) {
-                    hasMulticomponentUAVLoads = true;
-                    continue;
-                  }
-                  else {
-                    ConstantInt *rangeID = GetArbitraryConstantRangeID(handleCall);
-                    if (rangeID) {
-                      DxilResource resource = M->GetUAV(rangeID->getLimitedValue());
-                      if ((resource.IsTypedBuffer() ||
-                        resource.IsAnyTexture()) &&
-                        !dxilutil::IsResourceSingleComponent(resource.GetRetType())) {
-                        hasMulticomponentUAVLoads = true;
-                      }
-                    }
-                  }
-                }
-              }
-              else {
-                DXASSERT(false, "Resource class must be constant.");
-              }
-            }
-            else if (handleOp == DXIL::OpCode::CreateHandleForLib) {
-              // If library handle, find DxilResource by checking the name
-              if (LoadInst *LI = dyn_cast<LoadInst>(handleCall->getArgOperand(
-                      DXIL::OperandIndex::
-                          kCreateHandleForLibResOpIdx))) {
-                Value *resType = LI->getOperand(0);
-                for (auto &&res : M->GetUAVs()) {
-                  if (res->GetGlobalSymbol() == resType) {
-                    if ((res->IsTypedBuffer() || res->IsAnyTexture()) &&
-                        !dxilutil::IsResourceSingleComponent(res->GetRetType())) {
-                      hasMulticomponentUAVLoads = true;
-                    }
-                  }
-                }
-              }
-            } else if (handleOp == DXIL::OpCode::AnnotateHandle) {
-              DxilInst_AnnotateHandle annotateHandle(handleCall);
-              Type *ResPropTy = M->GetOP()->GetResourcePropertiesType();
-
-              DxilResourceProperties RP =
-                  resource_helper::loadFromAnnotateHandle(
-                      annotateHandle, ResPropTy, *M->GetShaderModel());
-              if (RP.Class == DXIL::ResourceClass::UAV) {
-                // Validator 1.0 assumes that all uav load is multi component
-                // load.
-                if (hasMulticomponentUAVLoadsBackCompat) {
+            DxilResourceProperties RP = GetResourcePropertyFromHandleCall(M, handleCall);
+            if (RP.Class == DXIL::ResourceClass::UAV) {
+              // Validator 1.0 assumes that all uav load is multi component load.
+              if (hasMulticomponentUAVLoadsBackCompat) {
+                hasMulticomponentUAVLoads = true;
+                continue;
+              } else {
+                if (DXIL::IsTyped(RP.Kind) && !RP.Typed.SingleComponent)
                   hasMulticomponentUAVLoads = true;
-                  continue;
-                } else {
-                  if (DXIL::IsTyped(RP.Kind) &&
-                      !RP.Typed.SingleComponent)
-                    hasMulticomponentUAVLoads = true;
-                }
               }
             }
           }

+ 2 - 2
lib/HLSL/DxilTranslateRawBuffer.cpp

@@ -52,7 +52,7 @@ public:
         if (hlslOP->GetOpCodeClass(func, opClass)) {
           if (opClass == DXIL::OpCodeClass::RawBufferLoad) {
             Type *ETy =
-                hlslOP->GetOverloadType(DXIL::OpCode::RawBufferLoad, func);
+                OP::GetOverloadType(DXIL::OpCode::RawBufferLoad, func);
 
             bool is64 =
                 ETy->isDoubleTy() || ETy == Type::getInt64Ty(ETy->getContext());
@@ -62,7 +62,7 @@ public:
             }
           } else if (opClass == DXIL::OpCodeClass::RawBufferStore) {
             Type *ETy =
-                hlslOP->GetOverloadType(DXIL::OpCode::RawBufferStore, func);
+                OP::GetOverloadType(DXIL::OpCode::RawBufferStore, func);
 
             bool is64 =
                 ETy->isDoubleTy() || ETy == Type::getInt64Ty(ETy->getContext());

+ 5 - 7
lib/HLSL/DxilValidation.cpp

@@ -2214,10 +2214,9 @@ static void ValidateResourceDxilOp(CallInst *CI, DXIL::OpCode opcode,
     }
   } break;
   case DXIL::OpCode::RawBufferLoad: {
-    hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();
     if (!ValCtx.DxilMod.GetShaderModel()->IsSM63Plus()) {
-      Type *Ty = hlslOP->GetOverloadType(DXIL::OpCode::RawBufferLoad,
-                                         CI->getCalledFunction());
+      Type *Ty = OP::GetOverloadType(DXIL::OpCode::RawBufferLoad,
+                                 CI->getCalledFunction());
       if (ValCtx.DL.getTypeAllocSizeInBits(Ty) > 32) {
         ValCtx.EmitInstrError(CI, ValidationRule::Sm64bitRawBufferLoadStore);
       }
@@ -2263,10 +2262,9 @@ static void ValidateResourceDxilOp(CallInst *CI, DXIL::OpCode opcode,
     }
   } break;
   case DXIL::OpCode::RawBufferStore: {
-    hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP();
     if (!ValCtx.DxilMod.GetShaderModel()->IsSM63Plus()) {
-      Type *Ty = hlslOP->GetOverloadType(DXIL::OpCode::RawBufferStore,
-                                         CI->getCalledFunction());
+      Type *Ty = OP::GetOverloadType(DXIL::OpCode::RawBufferStore,
+                                 CI->getCalledFunction());
       if (ValCtx.DL.getTypeAllocSizeInBits(Ty) > 32) {
         ValCtx.EmitInstrError(CI, ValidationRule::Sm64bitRawBufferLoadStore);
       }
@@ -2560,7 +2558,7 @@ static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) {
       dxilFunc = hlslOP->GetOpFunc(dxilOpcode, voidTy);
     }
     else {
-      Type *Ty = hlslOP->GetOverloadType(dxilOpcode, CI->getCalledFunction());
+      Type *Ty = OP::GetOverloadType(dxilOpcode, CI->getCalledFunction());
       try {
         if (!hlslOP->IsOverloadLegal(dxilOpcode, Ty)) {
           ValCtx.EmitInstrError(CI, ValidationRule::InstrOload);

+ 1 - 1
lib/HLSL/HLOperationLowerExtension.cpp

@@ -622,7 +622,7 @@ Value *ExtensionLowering::Dxil(CallInst *CI) {
     return nullptr;
 
   // Find the dxil function based on the overload type.
-  Type *overloadTy = m_hlslOp.GetOverloadType(dxilOpcode, CI->getCalledFunction());
+  Type *overloadTy = OP::GetOverloadType(dxilOpcode, CI->getCalledFunction());
   Function *F = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());
 
   // Update the opcode in the original call so we can just copy it below.

+ 16 - 12
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -1000,6 +1000,7 @@ static const ArBasicKind g_DoubleCT[] =
 static const ArBasicKind g_DoubleOnlyCT[] =
 {
   AR_BASIC_FLOAT64,
+  AR_BASIC_LITERAL_FLOAT,
   AR_BASIC_NOCAST,
   AR_BASIC_UNKNOWN
 };
@@ -5493,6 +5494,21 @@ bool HLSLExternalSource::MatchArguments(
       continue;
     }
 
+    // Verify TypeInfoEltKind can be cast to something legal for this param
+    if (AR_BASIC_UNKNOWN != TypeInfoEltKind) {
+      for (const ArBasicKind *pCT = g_LegalIntrinsicCompTypes[pIntrinsicArg->uLegalComponentTypes];
+           AR_BASIC_UNKNOWN != *pCT; pCT++) {
+        if (TypeInfoEltKind == *pCT)
+          break;
+        else if ((TypeInfoEltKind == AR_BASIC_LITERAL_INT && *pCT == AR_BASIC_LITERAL_FLOAT) ||
+                 (TypeInfoEltKind == AR_BASIC_LITERAL_FLOAT && *pCT == AR_BASIC_LITERAL_INT))
+          break;
+        else if (*pCT == AR_BASIC_NOCAST) {
+          badArgIdx = std::min(badArgIdx, iArg);
+        }
+      }
+    }
+
     if (TypeInfoEltKind == AR_BASIC_LITERAL_INT ||
         TypeInfoEltKind == AR_BASIC_LITERAL_FLOAT) {
       bool affectRetType =
@@ -5555,18 +5571,6 @@ bool HLSLExternalSource::MatchArguments(
       pIntrinsicArg->uComponentTypeId < MaxIntrinsicArgs,
       "otherwise intrinsic table was modified and MaxIntrinsicArgs was not updated (or uComponentTypeId is out of bounds)");
 
-    // Verify TypeInfoEltKind can be cast to something legal for this param
-    if (AR_BASIC_UNKNOWN != TypeInfoEltKind) {
-      for (const ArBasicKind *pCT = g_LegalIntrinsicCompTypes[pIntrinsicArg->uLegalComponentTypes];
-           AR_BASIC_UNKNOWN != *pCT; pCT++) {
-        if (TypeInfoEltKind == *pCT)
-          break;
-        else if (*pCT == AR_BASIC_NOCAST) {
-          badArgIdx = std::min(badArgIdx, iArg);
-        }
-      }
-    }
-
     // Merge ComponentTypes
     if (AR_BASIC_UNKNOWN == ComponentType[pIntrinsicArg->uComponentTypeId]) {
       ComponentType[pIntrinsicArg->uComponentTypeId] = TypeInfoEltKind;

+ 10 - 10
utils/hct/hctdb_instrhelp.py

@@ -497,17 +497,17 @@ class db_oload_gen:
             assert len(instr.oload_types)==1, "overload no elt_ty %s" % (instr.name)
             ty = instr.oload_types[0]
             type_code_texts = {
-            "d": "Type::getDoubleTy(m_Ctx)",
-            "f": "Type::getFloatTy(m_Ctx)",
+            "d": "Type::getDoubleTy(Ctx)",
+            "f": "Type::getFloatTy(Ctx)",
             "h": "Type::getHalfTy",
-            "1": "IntegerType::get(m_Ctx, 1)",
-			"8": "IntegerType::get(m_Ctx, 8)",
-            "w": "IntegerType::get(m_Ctx, 16)",
-            "i": "IntegerType::get(m_Ctx, 32)",
-            "l": "IntegerType::get(m_Ctx, 64)",
-            "v": "Type::getVoidTy(m_Ctx)",
-            "u": "Type::getInt32PtrTy(m_Ctx)",
-            "o": "Type::getInt32PtrTy(m_Ctx)",
+            "1": "IntegerType::get(Ctx, 1)",
+            "8": "IntegerType::get(Ctx, 8)",
+            "w": "IntegerType::get(Ctx, 16)",
+            "i": "IntegerType::get(Ctx, 32)",
+            "l": "IntegerType::get(Ctx, 64)",
+            "v": "Type::getVoidTy(Ctx)",
+            "u": "Type::getInt32PtrTy(Ctx)",
+            "o": "Type::getInt32PtrTy(Ctx)",
             }
             assert ty in type_code_texts, "llvm type %s is unknown" % (ty)
             ty_code = type_code_texts[ty]