Bladeren bron

Support non-int32 types with SV_***ID semantics (#2286)

It seemed overkill to require backends to support a @dx.op.threadId.i16 intrinsic, so I'm always calling the i32 version and then truncating or zero-extending the result to the final type.
Tristan Labelle 6 jaren geleden
bovenliggende
commit
a0c95cd98b

+ 14 - 7
lib/HLSL/HLSignatureLower.cpp

@@ -1159,21 +1159,23 @@ void HLSignatureLower::GenerateDxilCSInputs() {
     }
 
     Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
-    Type *Ty = arg.getType();
-    if (Ty->isPointerTy())
-      Ty = Ty->getPointerElementType();
-    Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty->getScalarType());
+    Type *NumTy = arg.getType();
+    DXASSERT(!NumTy->isPointerTy(), "Unexpected byref value for CS SV_***ID semantic.");
+    DXASSERT(NumTy->getScalarType()->isIntegerTy(), "Unexpected non-integer value for CS SV_***ID semantic.");
+
+    // Always use the i32 overload of those intrinsics, and then cast as needed
+    Function *dxilFunc = hlslOP->GetOpFunc(opcode, Builder.getInt32Ty());
     Value *newArg = nullptr;
     if (opcode == OP::OpCode::FlattenedThreadIdInGroup) {
       newArg = Builder.CreateCall(dxilFunc, {OpArg});
     } else {
       unsigned vecSize = 1;
-      if (Ty->isVectorTy())
-        vecSize = Ty->getVectorNumElements();
+      if (NumTy->isVectorTy())
+        vecSize = NumTy->getVectorNumElements();
 
       newArg = Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU32Const(0)});
       if (vecSize > 1) {
-        Value *result = UndefValue::get(Ty);
+        Value *result = UndefValue::get(VectorType::get(Builder.getInt32Ty(), vecSize));
         result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
 
         for (unsigned i = 1; i < vecSize; i++) {
@@ -1184,6 +1186,11 @@ void HLSignatureLower::GenerateDxilCSInputs() {
         newArg = result;
       }
     }
+
+    // If the argument is of non-i32 type, convert here
+    if (newArg->getType() != NumTy)
+      newArg = Builder.CreateZExtOrTrunc(newArg, NumTy);
+
     if (newArg->getType() != arg.getType()) {
       DXASSERT_NOMSG(arg.getType()->isPointerTy());
       for (User *U : arg.users()) {

+ 20 - 0
tools/clang/test/CodeGenHLSL/batch/declarations/functions/entrypoints/semantics/cs_sv_id_int16.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -E main -T cs_6_2 -enable-16bit-types %s | FileCheck %s
+
+// Check that compute shader SV_***ID parameters can have 16-bit integer types.
+// Regression test for GitHub issue #2268
+
+RWStructuredBuffer<uint3> buf;
+
+[numthreads(1, 1, 1)]
+void main(uint16_t3 tid : SV_DispatchThreadID)
+{
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 0)
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 1)
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 2)
+    // Truncation honors uint16_t type
+    // CHECK: and i32 {{.*}}, 65535
+    // CHECK: and i32 {{.*}}, 65535
+    // CHECK: and i32 {{.*}}, 65535
+    // CHECK: call void @dx.op.rawBufferStore.i32
+    buf[0] = tid;
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/batch/declarations/functions/entrypoints/semantics/cs_sv_id_min12int.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Check that compute shader SV_***ID parameters can have min-integer types.
+// Regression test for GitHub issue #2268
+
+RWStructuredBuffer<uint3> buf;
+
+[numthreads(1, 1, 1)]
+void main(min12int3 tid : SV_DispatchThreadID)
+{
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 0)
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 1)
+    // CHECK: call i32 @dx.op.threadId.i32(i32 93, i32 2)
+    // Truncation to 16 bits followed by sign extension honors the signed min12int type
+    // CHECK: shl i32 %{{.*}}, 16
+    // CHECK: ashr exact i32 %{{.*}}, 16
+    // CHECK: shl i32 %{{.*}}, 16
+    // CHECK: ashr exact i32 %{{.*}}, 16
+    // CHECK: shl i32 %{{.*}}, 16
+    // CHECK: ashr exact i32 %{{.*}}, 16
+    // CHECK: call void @dx.op.bufferStore.i32
+    buf[0] = tid;
+}

+ 6 - 0
tools/clang/test/CodeGenHLSL/crashes/float_cs_sv_semantic.hlsl

@@ -0,0 +1,6 @@
+// RUN: %dxc -E main -T cs_6_2 %s | FileCheck %s
+
+// Repro of GitHub #2299
+
+[numthreads(1, 1, 1)]
+void main(float3 tid : SV_DispatchThreadID) {}