Sfoglia il codice sorgente

[SPIRV] Emit Offset decoration for PhysicalStorage structs (#5392)

When the SPV_EXT_physical_storage_buffer extensions is used, vulkan can
load raw addresses. For structs loaded through this mecanism, the
offsets must be explicit.

This commit fixes decoration emission by attaching the correct layout
when a struct is loaded using this extensions.
(Specifying a layout different than void forces explicit offsets).

Fixes #5327

---------

Signed-off-by: Nathan Gauër <[email protected]>
Nathan Gauër 2 anni fa
parent
commit
6289ba317e

+ 6 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -369,10 +369,13 @@ SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op, QualType resultType,
                                           SpirvInstruction *operand,
                                           SourceLocation loc,
                                           SourceRange range) {
+  if (!operand)
+    return nullptr;
   assert(insertPoint && "null insert point");
   auto *instruction =
       new (context) SpirvUnaryOp(op, resultType, loc, operand, range);
   insertPoint->addInstruction(instruction);
+  instruction->setLayoutRule(operand->getLayoutRule());
   return instruction;
 }
 
@@ -380,8 +383,11 @@ SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op,
                                           const SpirvType *resultType,
                                           SpirvInstruction *operand,
                                           SourceLocation loc) {
+  if (!operand)
+    return nullptr;
   assert(insertPoint && "null insert point");
   auto *instruction = new (context) SpirvUnaryOp(op, resultType, loc, operand);
+  instruction->setLayoutRule(operand->getLayoutRule());
   insertPoint->addInstruction(instruction);
   return instruction;
 }

+ 123 - 12
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -3067,6 +3067,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
     auto *value = castToInt(loadIfGLValue(subExpr), subExprType, toType,
                             subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     return value;
   }
@@ -3083,6 +3086,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
     auto *value = castToFloat(loadIfGLValue(subExpr), subExprType, toType,
                               subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     return value;
   }
@@ -3098,6 +3104,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
     auto *value = castToBool(loadIfGLValue(subExpr), subExprType, toType,
                              subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     return value;
   }
@@ -3123,6 +3132,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                   expr->getExprLoc(), range);
     }
 
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     return value;
   }
@@ -3153,6 +3165,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                    expr->getLocStart(), range);
     auto *mat = spvBuilder.createCompositeConstruct(toType, {subVec1, subVec2},
                                                     expr->getLocStart(), range);
+    if (!mat)
+      return nullptr;
+
     mat->setRValue();
     return mat;
   }
@@ -3176,6 +3191,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       llvm::SmallVector<SpirvConstant *, 4> vectors(
           size_t(rowCount), cast<SpirvConstant>(vecSplat));
       auto *value = spvBuilder.getConstantComposite(toType, vectors);
+      if (!value)
+        return nullptr;
+
       value->setRValue();
       return value;
     } else {
@@ -3183,6 +3201,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                        vecSplat);
       auto *value = spvBuilder.createCompositeConstruct(
           toType, vectors, expr->getLocEnd(), range);
+      if (!value)
+        return nullptr;
+
       value->setRValue();
       return value;
     }
@@ -3202,6 +3223,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) {
         auto *val = spvBuilder.createCompositeExtract(
             toType, src, {0}, expr->getLocStart(), range);
+        if (!val)
+          return nullptr;
+
         val->setRValue();
         return val;
       }
@@ -3212,6 +3236,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
           indexes.push_back(i);
         auto *val = spvBuilder.createVectorShuffle(toType, src, src, indexes,
                                                    expr->getLocStart(), range);
+        if (!val)
+          return nullptr;
+
         val->setRValue();
         return val;
       }
@@ -3251,6 +3278,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       val = spvBuilder.createCompositeConstruct(toType, extractedVecs,
                                                 expr->getExprLoc(), range);
     }
+    if (!val)
+      return nullptr;
+
     val->setRValue();
     return val;
   }
@@ -3286,6 +3316,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
         spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc, range);
     auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3},
                                                srcLoc, range);
+    if (!vec)
+      return nullptr;
+
     vec->setRValue();
     return vec;
   }
@@ -3631,6 +3664,8 @@ SpirvInstruction *SpirvEmitter::doShortCircuitedConditionalOperator(
   // From now on, emit instructions into the merge block.
   spvBuilder.setInsertPoint(mergeBB);
   SpirvInstruction *result = spvBuilder.createLoad(type, tempVar, loc, range);
+  if (!result)
+    return nullptr;
   result->setRValue();
   return result;
 }
@@ -3684,6 +3719,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
       }
       auto *result =
           spvBuilder.createCompositeConstruct(type, rows, loc, range);
+      if (!result)
+        return nullptr;
+
       result->setRValue();
       return result;
     }
@@ -3708,6 +3746,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
 
     auto *value = spvBuilder.createSelect(type, condition, trueBranch,
                                           falseBranch, loc, range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     return value;
   }
@@ -3746,6 +3787,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
   // From now on, emit instructions into the merge block.
   spvBuilder.setInsertPoint(mergeBB);
   auto *result = spvBuilder.createLoad(type, tempVar, expr->getLocEnd(), range);
+  if (!result)
+    return nullptr;
+
   result->setRValue();
   return result;
 }
@@ -4315,6 +4359,9 @@ SpirvInstruction *SpirvEmitter::processBufferTextureLoad(
                       : elemType;
     retVal = castToBool(retVal, toType, sampledType, loc);
   }
+  if (!retVal)
+    return nullptr;
+
   retVal->setRValue();
   return retVal;
 }
@@ -4438,6 +4485,9 @@ SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
           astContext.getExtVectorType(addressType, numWords);
       result = spvBuilder.createCompositeConstruct(resultType, values,
                                                    expr->getLocStart(), range);
+      if (!result)
+        return nullptr;
+
       result->setRValue();
     }
   }
@@ -5789,6 +5839,8 @@ SpirvEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
                   : astContext.getExtVectorType(astContext.BoolTy, size);
     value = castToBool(value, fromType, toType, expr->getLocStart());
   }
+  if (!value)
+    return nullptr;
   value->setRValue();
   return value;
 }
@@ -5862,6 +5914,9 @@ SpirvEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr,
     llvm::SmallVector<SpirvInstruction *, 4> components(accessorSize, info);
     info = spvBuilder.createCompositeConstruct(type, components,
                                                expr->getLocStart(), range);
+    if (!info)
+      return nullptr;
+
     info->setRValue();
     return info;
   }
@@ -6004,7 +6059,8 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
                                         SpirvInstruction *lhsVec) {
         auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
                                               expr->getOperatorLoc(), range);
-        val->setRValue();
+        if (val)
+          val->setRValue();
         return val;
       };
       incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec,
@@ -6017,6 +6073,8 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     // If this is a RWBuffer/RWTexture assignment, OpImageWrite will be used.
     // Otherwise, store using OpStore.
     if (tryToAssignToRWBufferRWTexture(subExpr, incValue, range)) {
+      if (!incValue)
+        return nullptr;
       incValue->setRValue();
       subValue = incValue;
     } else {
@@ -6027,14 +6085,19 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     // increment/decrement returns a rvalue.
     if (isPre) {
       return subValue;
-    } else {
-      originValue->setRValue();
-      return originValue;
     }
+
+    if (!originValue)
+      return nullptr;
+    originValue->setRValue();
+    return originValue;
   }
   case UO_Not: {
     subValue = spvBuilder.createUnaryOp(spv::Op::OpNot, subType, subValue,
                                         expr->getOperatorLoc(), range);
+    if (!subValue)
+      return nullptr;
+
     subValue->setRValue();
     return subValue;
   }
@@ -6044,6 +6107,9 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     subValue =
         spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, subType, subValue,
                                  expr->getOperatorLoc(), range);
+    if (!subValue)
+      return nullptr;
+
     subValue->setRValue();
     return subValue;
   }
@@ -6069,6 +6135,9 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     } else {
       subValue = spvBuilder.createUnaryOp(spvOp, subType, subValue,
                                           expr->getOperatorLoc(), range);
+      if (!subValue)
+        return nullptr;
+
       subValue->setRValue();
       return subValue;
     }
@@ -6569,6 +6638,9 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
     SpirvInstruction *result =
         castToType(tempVar, astContext.BoolTy, resultType, loc, sourceRange);
     result = spvBuilder.createLoad(resultType, tempVar, loc, sourceRange);
+    if (!result)
+      return nullptr;
+
     result->setRValue();
     return result;
   }
@@ -6653,6 +6725,9 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
               rhsValConstant->isSpecConstant()) {
             auto *val = spvBuilder.createSpecConstantBinaryOp(
                 spvOp, resultType, lhsVal, rhsVal, loc);
+            if (!val)
+              return nullptr;
+
             val->setRValue();
             return val;
           }
@@ -6675,6 +6750,8 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
                                       sourceRange);
     }
 
+    if (!val)
+      return nullptr;
     val->setRValue();
 
     // Propagate RelaxedPrecision
@@ -6929,6 +7006,8 @@ SpirvInstruction *SpirvEmitter::createVectorSplat(const Expr *scalarExpr,
   // Try to evaluate the element as constant first. If successful, then we
   // can generate constant instructions for this vector splat.
   if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
+    if (!scalarVal)
+      return nullptr;
     scalarVal->setRValue();
   } else {
     scalarVal = loadIfGLValue(scalarExpr, range);
@@ -6949,12 +7028,16 @@ SpirvInstruction *SpirvEmitter::createVectorSplat(const Expr *scalarExpr,
     llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), constVal);
     const bool isSpecConst = constVal->getopcode() == spv::Op::OpSpecConstant;
     auto *value = spvBuilder.getConstantComposite(vecType, elements, isSpecConst);
+    if (!value)
+      return nullptr;
     value->setRValue();
     return value;
   } else {
     llvm::SmallVector<SpirvInstruction *, 4> elements(size_t(size), scalarVal);
     auto *value = spvBuilder.createCompositeConstruct(
         vecType, elements, scalarExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
     value->setRValue();
     return value;
   }
@@ -7611,6 +7694,8 @@ SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
 
   // Construct the result matrix
   auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
+  if (!val)
+    return nullptr;
   val->setRValue();
   return val;
 }
@@ -7723,7 +7808,8 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                                        rhs->getLocStart());
       auto *val =
           spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc, range);
-      val->setRValue();
+      if (val)
+        val->setRValue();
       return val;
     };
     return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec,
@@ -13525,6 +13611,8 @@ SpirvInstruction *SpirvEmitter::createSpirvIntrInstExt(
 
   SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
       op, retType, spvArgs, extensions, instSet, capbilities, loc);
+  if (!retVal)
+    return nullptr;
 
   // TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
   retVal->setRValue();
@@ -13643,6 +13731,8 @@ SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) {
   SpirvInstruction *load =
       loadDataFromRawAddress(address, bufferType, alignment, loc);
   auto *loadAsBool = castToBool(load, bufferType, boolType, loc);
+  if (!loadAsBool)
+    return nullptr;
   loadAsBool->setRValue();
   return loadAsBool;
 }
@@ -13670,23 +13760,39 @@ SpirvEmitter::loadDataFromRawAddress(SpirvInstruction *addressInUInt64,
   return loadInst;
 }
 
-SpirvInstruction *SpirvEmitter::storeDataToRawAddress(
-    SpirvInstruction *addressInUInt64, SpirvInstruction *value,
-    QualType bufferType, uint32_t alignment, SourceLocation loc) {
-
+SpirvInstruction *
+SpirvEmitter::storeDataToRawAddress(SpirvInstruction *addressInUInt64,
+                                    SpirvInstruction *value,
+                                    QualType bufferType, uint32_t alignment,
+                                    SourceLocation loc, SourceRange range) {
   // Summary:
   //   %address = OpBitcast %ptrTobufferType %addressInUInt64
   //   %storeInst = OpStore %address %value alignment %alignment
+  if (!value || !addressInUInt64)
+    return nullptr;
 
   const HybridPointerType *bufferPtrType =
       spvBuilder.getPhysicalStorageBufferType(bufferType);
 
   SpirvUnaryOp *address = spvBuilder.createUnaryOp(
       spv::Op::OpBitcast, bufferPtrType, addressInUInt64, loc);
+  if (!address)
+    return nullptr;
   address->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
 
-  SpirvStore *storeInst = spvBuilder.createStore(address, value, loc);
+  // If the source value has a different layout, it is not safe to directly
+  // store it. It needs to be component-wise reconstructed to the new layout.
+  SpirvInstruction *source = value;
+  if (value->getStorageClass() != address->getStorageClass()) {
+    source = reconstructValue(value, bufferType, address->getLayoutRule(), loc,
+                              range);
+  }
+  if (!source)
+    return nullptr;
+
+  SpirvStore *storeInst = spvBuilder.createStore(address, source, loc);
   storeInst->setAlignment(alignment);
+  storeInst->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
   return nullptr;
 }
 
@@ -13698,10 +13804,14 @@ SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) {
 
   SpirvInstruction *address = doExpr(callExpr->getArg(0));
   SpirvInstruction *value = doExpr(callExpr->getArg(1));
+  if (!address || !value)
+    return nullptr;
+
   QualType bufferType = value->getAstResultType();
   clang::SourceLocation loc = callExpr->getExprLoc();
   if (!isBoolOrVecMatOfBoolType(bufferType)) {
-    return storeDataToRawAddress(address, value, bufferType, alignment, loc);
+    return storeDataToRawAddress(address, value, bufferType, alignment, loc,
+                                 callExpr->getLocStart());
   }
 
   // If callExpr is `vk::RawBufferLoad<bool>(..)`, we have to load 'uint' and
@@ -13717,7 +13827,8 @@ SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) {
   QualType boolType = bufferType;
   bufferType = getUintTypeForBool(astContext, theCompilerInstance, boolType);
   auto *storeAsInt = castToInt(value, boolType, bufferType, loc);
-  return storeDataToRawAddress(address, storeAsInt, bufferType, alignment, loc);
+  return storeDataToRawAddress(address, storeAsInt, bufferType, alignment, loc,
+                               callExpr->getLocStart());
 }
 
 SpirvInstruction *

+ 2 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -656,7 +656,8 @@ private:
                                           SpirvInstruction *value,
                                           QualType bufferType,
                                           uint32_t alignment,
-                                          SourceLocation loc);
+                                          SourceLocation loc,
+                                          SourceRange range);
 
   /// Returns the alignment of `vk::RawBufferLoad()`.
   uint32_t getAlignmentForRawBufferLoad(const CallExpr *callExpr);

+ 88 - 17
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferload.bitfield.hlsl

@@ -1,8 +1,26 @@
-// RUN: %dxc -T ps_6_0 -E main -HV 2021
+// RUN: %dxc -T cs_6_0 -E main -HV 2021
 
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %S 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S 3 Offset 12
+// CHECK: OpMemberDecorate %S_0 0 Offset 0
+// CHECK: OpMemberDecorate %S_0 1 Offset 4
+// CHECK: OpMemberDecorate %S_0 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S_0 3 Offset 12
+
+// CHECK: %S = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_Function_S = OpTypePointer Function %S
+// CHECK: %S_0 = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_PhysicalStorageBuffer_S_0 = OpTypePointer PhysicalStorageBuffer %S_0
+
+// CHECK: %temp_var_S = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_0 = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_1 = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_2 = OpVariable %_ptr_Function_S Function
 
 struct S {
   uint f1;
@@ -13,21 +31,74 @@ struct S {
 
 uint64_t Address;
 
-// CHECK: [[type_S:%\w+]] = OpTypeStruct %uint %uint %uint
-// CHECK: [[ptr_f_S:%\w+]] = OpTypePointer Function [[type_S]]
-// CHECK: [[ptr_p_S:%\w+]] = OpTypePointer PhysicalStorageBuffer [[type_S]]
-
-void main() : B {
-// CHECK: [[tmp_S:%\w+]] = OpVariable [[ptr_f_S]] Function
-// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
-// CHECK: [[value:%\d+]] = OpLoad %ulong [[value]]
-// CHECK: [[value:%\d+]] = OpBitcast [[ptr_p_S]] [[value]]
-// CHECK: [[value:%\d+]] = OpLoad [[type_S]] [[value]] Aligned 4
-// CHECK: OpStore [[tmp_S]] [[value]]
-// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp_S]] %int_1
-// CHECK: [[value:%\d+]] = OpLoad %uint [[value]]
-// CHECK: [[value:%\d+]] = OpBitFieldUExtract %uint [[value]] %uint_1 %uint_1
-// CHECK: OpStore %tmp [[value]]
-  uint tmp = vk::RawBufferLoad<S>(Address).f3;
+[numthreads(1, 1, 1)]
+void main(uint3 tid : SV_DispatchThreadID) {
+
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: OpStore %tmp1 [[tmp]]
+    uint tmp1 = vk::RawBufferLoad<S>(Address).f1;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_0 %int_1
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpBitFieldUExtract %uint [[tmp]] %uint_0 %uint_1
+    // CHECK: OpStore %tmp2 [[tmp]]
+    uint tmp2 = vk::RawBufferLoad<S>(Address).f2;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_1 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_1 %int_1
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpBitFieldUExtract %uint [[tmp]] %uint_1 %uint_1
+    // CHECK: OpStore %tmp3 [[tmp]]
+    uint tmp3 = vk::RawBufferLoad<S>(Address).f3;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_2 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_2 %int_2
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: OpStore %tmp4 [[tmp]]
+    uint tmp4 = vk::RawBufferLoad<S>(Address).f4;
+  }
 }
 

+ 56 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferstore.bitfields.hlsl

@@ -0,0 +1,56 @@
+// RUN: %dxc -T cs_6_0 -E main -HV 2021
+
+// CHECK: OpCapability PhysicalStorageBufferAddresses
+// CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
+// CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %S 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S 3 Offset 12
+// CHECK: OpMemberDecorate %S_0 0 Offset 0
+// CHECK: OpMemberDecorate %S_0 1 Offset 4
+// CHECK: OpMemberDecorate %S_0 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S_0 3 Offset 12
+
+// CHECK: %S = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_Function_S = OpTypePointer Function %S
+// CHECK: %S_0 = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_PhysicalStorageBuffer_S_0 = OpTypePointer PhysicalStorageBuffer %S_0
+
+
+struct S {
+  uint f1;
+  uint f2 : 1;
+  uint f3 : 3;
+  uint f4;
+};
+
+uint64_t Address;
+
+[numthreads(1, 1, 1)]
+void main(uint3 tid : SV_DispatchThreadID) {
+  // CHECK: %tmp = OpVariable %_ptr_Function_S Function
+  S tmp;
+
+  // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_0
+  // CHECK: OpStore [[tmp]] %uint_2
+  tmp.f1 = 2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_1
+  // CHECK: [[tmp:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[tmp:%\d+]] = OpBitFieldInsert %uint [[tmp]] %uint_1 %uint_0 %uint_1
+  // CHECK: OpStore [[ptr]] [[tmp]]
+  tmp.f2 = 1;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_1
+  // CHECK: [[tmp:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[tmp:%\d+]] = OpBitFieldInsert %uint [[tmp]] %uint_0 %uint_1 %uint_3
+  // CHECK: OpStore [[ptr]] [[tmp]]
+  tmp.f3 = 0;
+
+  // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_2
+  // CHECK: OpStore [[tmp]] %uint_3
+  tmp.f4 = 3;
+  vk::RawBufferStore<S>(Address, tmp);
+}
+

+ 19 - 2
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferstore.hlsl

@@ -3,6 +3,18 @@
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %XYZW 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %XYZW 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %XYZW 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %XYZW 3 Offset 12
+// CHECK: OpMemberDecorate %XYZW_0 0 Offset 0
+// CHECK: OpMemberDecorate %XYZW_0 1 Offset 4
+// CHECK: OpMemberDecorate %XYZW_0 2 Offset 8
+// CHECK: OpMemberDecorate %XYZW_0 3 Offset 12
+// CHECK: %XYZW = OpTypeStruct %int %int %int %int
+// CHECK: %_ptr_Function_XYZW = OpTypePointer Function %XYZW
+// CHECK: %XYZW_0 = OpTypeStruct %int %int %int %int
+// CHECK: %_ptr_PhysicalStorageBuffer_XYZW_0 = OpTypePointer PhysicalStorageBuffer %XYZW_0
 
 struct XYZW {
   int x;
@@ -50,8 +62,13 @@ void main(uint3 tid : SV_DispatchThreadID) {
 
   // CHECK:      [[addr:%\d+]] = OpLoad %ulong
   // CHECK-NEXT: [[xyzwval:%\d+]] = OpLoad %XYZW %xyzw
-  // CHECK-NEXT: [[buf:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_XYZW [[addr]]
-  // CHECK-NEXT: OpStore [[buf]] [[xyzwval]] Aligned 4
+  // CHECK-NEXT: [[buf:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_XYZW_0 [[addr]]
+  // CHECK-NEXT: [[member1:%\d+]] = OpCompositeExtract %int [[xyzwval]] 0
+  // CHECK-NEXT: [[member2:%\d+]] = OpCompositeExtract %int [[xyzwval]] 1
+  // CHECK-NEXT: [[member3:%\d+]] = OpCompositeExtract %int [[xyzwval]] 2
+  // CHECK-NEXT: [[member4:%\d+]] = OpCompositeExtract %int [[xyzwval]] 3
+  // CHECK-NEXT: [[p_xyzwval:%\d+]] = OpCompositeConstruct %XYZW_0 [[member1]] [[member2]] [[member3]] [[member4]]
+  // CHECK-NEXT: OpStore [[buf]] [[p_xyzwval]] Aligned 4
   XYZW xyzw;
   xyzw.x = 78;
   xyzw.y = 65;

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1461,6 +1461,9 @@ TEST_F(FileTest, IntrinsicsVkRawBufferLoadBitfield) {
 TEST_F(FileTest, IntrinsicsVkRawBufferStore) {
   runFileTest("intrinsics.vkrawbufferstore.hlsl");
 }
+TEST_F(FileTest, IntrinsicsVkRawBufferStoreBitfields) {
+  runFileTest("intrinsics.vkrawbufferstore.bitfields.hlsl");
+}
 // Intrinsics added in SM 6.6
 TEST_F(FileTest, IntrinsicsSM66PackU8S8) {
   runFileTest("intrinsics.sm6_6.pack_s8u8.hlsl");