Просмотр исходного кода

Fixed InterlockedCompareExchange failing validation (#2284)

The code handled instruction addrspacecasts but not constant expression ones. Refactored into a common method between two code paths.
Tristan Labelle 6 лет назад
Родитель
Сommit
9799235a6a

+ 13 - 11
lib/HLSL/HLOperationLower.cpp

@@ -4034,18 +4034,22 @@ void TranslateSharedMemAtomicBinOp(CallInst *CI, IntrinsicOp IOP, Value *addr) {
         CI->getArgOperand(HLOperandIndex::kInterlockedOriginalValueOpIndex));
         CI->getArgOperand(HLOperandIndex::kInterlockedOriginalValueOpIndex));
 }
 }
 
 
+static Value* SkipAddrSpaceCast(Value* Ptr) {
+  if (AddrSpaceCastInst *CastInst = dyn_cast<AddrSpaceCastInst>(Ptr))
+    return CastInst->getOperand(0);
+  else if (ConstantExpr *ConstExpr = dyn_cast<ConstantExpr>(Ptr)) {
+    if (ConstExpr->getOpcode() == Instruction::AddrSpaceCast) {
+      return ConstExpr->getOperand(0);
+    }
+  }
+  return Ptr;
+}
+
 Value *TranslateIopAtomicBinaryOperation(CallInst *CI, IntrinsicOp IOP,
 Value *TranslateIopAtomicBinaryOperation(CallInst *CI, IntrinsicOp IOP,
                                          DXIL::OpCode opcode,
                                          DXIL::OpCode opcode,
                                          HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
                                          HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
   Value *addr = CI->getArgOperand(HLOperandIndex::kInterlockedDestOpIndex);
   Value *addr = CI->getArgOperand(HLOperandIndex::kInterlockedDestOpIndex);
-  // Get the original addr from cast.
-  if (CastInst *castInst = dyn_cast<CastInst>(addr))
-    addr = castInst->getOperand(0);
-  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(addr)) {
-    if (CE->getOpcode() == Instruction::AddrSpaceCast) {
-      addr = CE->getOperand(0);
-    }
-  }
+  addr = SkipAddrSpaceCast(addr);
 
 
   unsigned addressSpace = addr->getType()->getPointerAddressSpace();
   unsigned addressSpace = addr->getType()->getPointerAddressSpace();
   if (addressSpace == DXIL::kTGSMAddrSpace)
   if (addressSpace == DXIL::kTGSMAddrSpace)
@@ -4081,9 +4085,7 @@ Value *TranslateIopAtomicCmpXChg(CallInst *CI, IntrinsicOp IOP,
                                  DXIL::OpCode opcode,
                                  DXIL::OpCode opcode,
                                  HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
                                  HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
   Value *addr = CI->getArgOperand(HLOperandIndex::kInterlockedDestOpIndex);
   Value *addr = CI->getArgOperand(HLOperandIndex::kInterlockedDestOpIndex);
-  // Get the original addr from cast.
-  if (CastInst *castInst = dyn_cast<CastInst>(addr))
-    addr = castInst->getOperand(0);
+  addr = SkipAddrSpaceCast(addr);
 
 
   unsigned addressSpace = addr->getType()->getPointerAddressSpace();
   unsigned addressSpace = addr->getType()->getPointerAddressSpace();
   if (addressSpace == DXIL::kTGSMAddrSpace)
   if (addressSpace == DXIL::kTGSMAddrSpace)

+ 30 - 0
tools/clang/test/CodeGenHLSL/batch/expressions/intrinsics/Interlocked_groupshared.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+groupshared int sharedInt[1];
+groupshared uint sharedUInt[1];
+
+[numthreads(8, 8, 1)]
+void main()
+{
+  int v;
+  // CHECK: cmpxchg
+  InterlockedCompareExchange(sharedInt[0], -1, -2, v);
+  // CHECK: atomicrmw xchg
+  InterlockedExchange(sharedInt[0], -1, v);
+  // CHECK: atomicrmw add
+  InterlockedAdd(sharedInt[0], -1);
+  // CHECK: atomicrmw and
+  InterlockedAnd(sharedInt[0], -1);
+  // CHECK: atomicrmw or
+  InterlockedOr(sharedInt[0], -1);
+  // CHECK: atomicrmw xor
+  InterlockedXor(sharedInt[0], -1);
+  // CHECK: atomicrmw max
+  InterlockedMax(sharedInt[0], -1); // -1 to workaround GitHub #2283
+  // CHECK: atomicrmw min
+  InterlockedMin(sharedInt[0], -1); // -1 to workaround GitHub #2283
+  // CHECK: atomicrmw umax
+  InterlockedMax(sharedUInt[0], (uint)1);
+  // CHECK: atomicrmw umin
+  InterlockedMin(sharedUInt[0], (uint)1);
+}