2
0
Эх сурвалжийг харах

[SPIRV] Emit Offset decoration for PhysicalStorage structs (#5392)

When the SPV_EXT_physical_storage_buffer extensions is used, vulkan can
load raw addresses. For structs loaded through this mecanism, the
offsets must be explicit.

This commit fixes decoration emission by attaching the correct layout
when a struct is loaded using this extensions.
(Specifying a layout different than void forces explicit offsets).

Fixes #5327

---------

Signed-off-by: Nathan Gauër <[email protected]>
Nathan Gauër 2 жил өмнө
parent
commit
6289ba317e

+ 6 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -369,10 +369,13 @@ SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op, QualType resultType,
                                           SpirvInstruction *operand,
                                           SpirvInstruction *operand,
                                           SourceLocation loc,
                                           SourceLocation loc,
                                           SourceRange range) {
                                           SourceRange range) {
+  if (!operand)
+    return nullptr;
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   auto *instruction =
   auto *instruction =
       new (context) SpirvUnaryOp(op, resultType, loc, operand, range);
       new (context) SpirvUnaryOp(op, resultType, loc, operand, range);
   insertPoint->addInstruction(instruction);
   insertPoint->addInstruction(instruction);
+  instruction->setLayoutRule(operand->getLayoutRule());
   return instruction;
   return instruction;
 }
 }
 
 
@@ -380,8 +383,11 @@ SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op,
                                           const SpirvType *resultType,
                                           const SpirvType *resultType,
                                           SpirvInstruction *operand,
                                           SpirvInstruction *operand,
                                           SourceLocation loc) {
                                           SourceLocation loc) {
+  if (!operand)
+    return nullptr;
   assert(insertPoint && "null insert point");
   assert(insertPoint && "null insert point");
   auto *instruction = new (context) SpirvUnaryOp(op, resultType, loc, operand);
   auto *instruction = new (context) SpirvUnaryOp(op, resultType, loc, operand);
+  instruction->setLayoutRule(operand->getLayoutRule());
   insertPoint->addInstruction(instruction);
   insertPoint->addInstruction(instruction);
   return instruction;
   return instruction;
 }
 }

+ 123 - 12
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -3067,6 +3067,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
 
     auto *value = castToInt(loadIfGLValue(subExpr), subExprType, toType,
     auto *value = castToInt(loadIfGLValue(subExpr), subExprType, toType,
                             subExpr->getLocStart(), range);
                             subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -3083,6 +3086,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
 
     auto *value = castToFloat(loadIfGLValue(subExpr), subExprType, toType,
     auto *value = castToFloat(loadIfGLValue(subExpr), subExprType, toType,
                               subExpr->getLocStart(), range);
                               subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -3098,6 +3104,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
 
 
     auto *value = castToBool(loadIfGLValue(subExpr), subExprType, toType,
     auto *value = castToBool(loadIfGLValue(subExpr), subExprType, toType,
                              subExpr->getLocStart(), range);
                              subExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -3123,6 +3132,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                   expr->getExprLoc(), range);
                                                   expr->getExprLoc(), range);
     }
     }
 
 
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -3153,6 +3165,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                    expr->getLocStart(), range);
                                                    expr->getLocStart(), range);
     auto *mat = spvBuilder.createCompositeConstruct(toType, {subVec1, subVec2},
     auto *mat = spvBuilder.createCompositeConstruct(toType, {subVec1, subVec2},
                                                     expr->getLocStart(), range);
                                                     expr->getLocStart(), range);
+    if (!mat)
+      return nullptr;
+
     mat->setRValue();
     mat->setRValue();
     return mat;
     return mat;
   }
   }
@@ -3176,6 +3191,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       llvm::SmallVector<SpirvConstant *, 4> vectors(
       llvm::SmallVector<SpirvConstant *, 4> vectors(
           size_t(rowCount), cast<SpirvConstant>(vecSplat));
           size_t(rowCount), cast<SpirvConstant>(vecSplat));
       auto *value = spvBuilder.getConstantComposite(toType, vectors);
       auto *value = spvBuilder.getConstantComposite(toType, vectors);
+      if (!value)
+        return nullptr;
+
       value->setRValue();
       value->setRValue();
       return value;
       return value;
     } else {
     } else {
@@ -3183,6 +3201,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
                                                        vecSplat);
                                                        vecSplat);
       auto *value = spvBuilder.createCompositeConstruct(
       auto *value = spvBuilder.createCompositeConstruct(
           toType, vectors, expr->getLocEnd(), range);
           toType, vectors, expr->getLocEnd(), range);
+      if (!value)
+        return nullptr;
+
       value->setRValue();
       value->setRValue();
       return value;
       return value;
     }
     }
@@ -3202,6 +3223,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) {
       if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) {
         auto *val = spvBuilder.createCompositeExtract(
         auto *val = spvBuilder.createCompositeExtract(
             toType, src, {0}, expr->getLocStart(), range);
             toType, src, {0}, expr->getLocStart(), range);
+        if (!val)
+          return nullptr;
+
         val->setRValue();
         val->setRValue();
         return val;
         return val;
       }
       }
@@ -3212,6 +3236,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
           indexes.push_back(i);
           indexes.push_back(i);
         auto *val = spvBuilder.createVectorShuffle(toType, src, src, indexes,
         auto *val = spvBuilder.createVectorShuffle(toType, src, src, indexes,
                                                    expr->getLocStart(), range);
                                                    expr->getLocStart(), range);
+        if (!val)
+          return nullptr;
+
         val->setRValue();
         val->setRValue();
         return val;
         return val;
       }
       }
@@ -3251,6 +3278,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
       val = spvBuilder.createCompositeConstruct(toType, extractedVecs,
       val = spvBuilder.createCompositeConstruct(toType, extractedVecs,
                                                 expr->getExprLoc(), range);
                                                 expr->getExprLoc(), range);
     }
     }
+    if (!val)
+      return nullptr;
+
     val->setRValue();
     val->setRValue();
     return val;
     return val;
   }
   }
@@ -3286,6 +3316,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr,
         spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc, range);
         spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc, range);
     auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3},
     auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3},
                                                srcLoc, range);
                                                srcLoc, range);
+    if (!vec)
+      return nullptr;
+
     vec->setRValue();
     vec->setRValue();
     return vec;
     return vec;
   }
   }
@@ -3631,6 +3664,8 @@ SpirvInstruction *SpirvEmitter::doShortCircuitedConditionalOperator(
   // From now on, emit instructions into the merge block.
   // From now on, emit instructions into the merge block.
   spvBuilder.setInsertPoint(mergeBB);
   spvBuilder.setInsertPoint(mergeBB);
   SpirvInstruction *result = spvBuilder.createLoad(type, tempVar, loc, range);
   SpirvInstruction *result = spvBuilder.createLoad(type, tempVar, loc, range);
+  if (!result)
+    return nullptr;
   result->setRValue();
   result->setRValue();
   return result;
   return result;
 }
 }
@@ -3684,6 +3719,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
       }
       }
       auto *result =
       auto *result =
           spvBuilder.createCompositeConstruct(type, rows, loc, range);
           spvBuilder.createCompositeConstruct(type, rows, loc, range);
+      if (!result)
+        return nullptr;
+
       result->setRValue();
       result->setRValue();
       return result;
       return result;
     }
     }
@@ -3708,6 +3746,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
 
 
     auto *value = spvBuilder.createSelect(type, condition, trueBranch,
     auto *value = spvBuilder.createSelect(type, condition, trueBranch,
                                           falseBranch, loc, range);
                                           falseBranch, loc, range);
+    if (!value)
+      return nullptr;
+
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -3746,6 +3787,9 @@ SpirvInstruction *SpirvEmitter::doConditional(const Expr *expr,
   // From now on, emit instructions into the merge block.
   // From now on, emit instructions into the merge block.
   spvBuilder.setInsertPoint(mergeBB);
   spvBuilder.setInsertPoint(mergeBB);
   auto *result = spvBuilder.createLoad(type, tempVar, expr->getLocEnd(), range);
   auto *result = spvBuilder.createLoad(type, tempVar, expr->getLocEnd(), range);
+  if (!result)
+    return nullptr;
+
   result->setRValue();
   result->setRValue();
   return result;
   return result;
 }
 }
@@ -4315,6 +4359,9 @@ SpirvInstruction *SpirvEmitter::processBufferTextureLoad(
                       : elemType;
                       : elemType;
     retVal = castToBool(retVal, toType, sampledType, loc);
     retVal = castToBool(retVal, toType, sampledType, loc);
   }
   }
+  if (!retVal)
+    return nullptr;
+
   retVal->setRValue();
   retVal->setRValue();
   return retVal;
   return retVal;
 }
 }
@@ -4438,6 +4485,9 @@ SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
           astContext.getExtVectorType(addressType, numWords);
           astContext.getExtVectorType(addressType, numWords);
       result = spvBuilder.createCompositeConstruct(resultType, values,
       result = spvBuilder.createCompositeConstruct(resultType, values,
                                                    expr->getLocStart(), range);
                                                    expr->getLocStart(), range);
+      if (!result)
+        return nullptr;
+
       result->setRValue();
       result->setRValue();
     }
     }
   }
   }
@@ -5789,6 +5839,8 @@ SpirvEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
                   : astContext.getExtVectorType(astContext.BoolTy, size);
                   : astContext.getExtVectorType(astContext.BoolTy, size);
     value = castToBool(value, fromType, toType, expr->getLocStart());
     value = castToBool(value, fromType, toType, expr->getLocStart());
   }
   }
+  if (!value)
+    return nullptr;
   value->setRValue();
   value->setRValue();
   return value;
   return value;
 }
 }
@@ -5862,6 +5914,9 @@ SpirvEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr,
     llvm::SmallVector<SpirvInstruction *, 4> components(accessorSize, info);
     llvm::SmallVector<SpirvInstruction *, 4> components(accessorSize, info);
     info = spvBuilder.createCompositeConstruct(type, components,
     info = spvBuilder.createCompositeConstruct(type, components,
                                                expr->getLocStart(), range);
                                                expr->getLocStart(), range);
+    if (!info)
+      return nullptr;
+
     info->setRValue();
     info->setRValue();
     return info;
     return info;
   }
   }
@@ -6004,7 +6059,8 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
                                         SpirvInstruction *lhsVec) {
                                         SpirvInstruction *lhsVec) {
         auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
         auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
                                               expr->getOperatorLoc(), range);
                                               expr->getOperatorLoc(), range);
-        val->setRValue();
+        if (val)
+          val->setRValue();
         return val;
         return val;
       };
       };
       incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec,
       incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec,
@@ -6017,6 +6073,8 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     // If this is a RWBuffer/RWTexture assignment, OpImageWrite will be used.
     // If this is a RWBuffer/RWTexture assignment, OpImageWrite will be used.
     // Otherwise, store using OpStore.
     // Otherwise, store using OpStore.
     if (tryToAssignToRWBufferRWTexture(subExpr, incValue, range)) {
     if (tryToAssignToRWBufferRWTexture(subExpr, incValue, range)) {
+      if (!incValue)
+        return nullptr;
       incValue->setRValue();
       incValue->setRValue();
       subValue = incValue;
       subValue = incValue;
     } else {
     } else {
@@ -6027,14 +6085,19 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     // increment/decrement returns a rvalue.
     // increment/decrement returns a rvalue.
     if (isPre) {
     if (isPre) {
       return subValue;
       return subValue;
-    } else {
-      originValue->setRValue();
-      return originValue;
     }
     }
+
+    if (!originValue)
+      return nullptr;
+    originValue->setRValue();
+    return originValue;
   }
   }
   case UO_Not: {
   case UO_Not: {
     subValue = spvBuilder.createUnaryOp(spv::Op::OpNot, subType, subValue,
     subValue = spvBuilder.createUnaryOp(spv::Op::OpNot, subType, subValue,
                                         expr->getOperatorLoc(), range);
                                         expr->getOperatorLoc(), range);
+    if (!subValue)
+      return nullptr;
+
     subValue->setRValue();
     subValue->setRValue();
     return subValue;
     return subValue;
   }
   }
@@ -6044,6 +6107,9 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     subValue =
     subValue =
         spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, subType, subValue,
         spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, subType, subValue,
                                  expr->getOperatorLoc(), range);
                                  expr->getOperatorLoc(), range);
+    if (!subValue)
+      return nullptr;
+
     subValue->setRValue();
     subValue->setRValue();
     return subValue;
     return subValue;
   }
   }
@@ -6069,6 +6135,9 @@ SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
     } else {
     } else {
       subValue = spvBuilder.createUnaryOp(spvOp, subType, subValue,
       subValue = spvBuilder.createUnaryOp(spvOp, subType, subValue,
                                           expr->getOperatorLoc(), range);
                                           expr->getOperatorLoc(), range);
+      if (!subValue)
+        return nullptr;
+
       subValue->setRValue();
       subValue->setRValue();
       return subValue;
       return subValue;
     }
     }
@@ -6569,6 +6638,9 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
     SpirvInstruction *result =
     SpirvInstruction *result =
         castToType(tempVar, astContext.BoolTy, resultType, loc, sourceRange);
         castToType(tempVar, astContext.BoolTy, resultType, loc, sourceRange);
     result = spvBuilder.createLoad(resultType, tempVar, loc, sourceRange);
     result = spvBuilder.createLoad(resultType, tempVar, loc, sourceRange);
+    if (!result)
+      return nullptr;
+
     result->setRValue();
     result->setRValue();
     return result;
     return result;
   }
   }
@@ -6653,6 +6725,9 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
               rhsValConstant->isSpecConstant()) {
               rhsValConstant->isSpecConstant()) {
             auto *val = spvBuilder.createSpecConstantBinaryOp(
             auto *val = spvBuilder.createSpecConstantBinaryOp(
                 spvOp, resultType, lhsVal, rhsVal, loc);
                 spvOp, resultType, lhsVal, rhsVal, loc);
+            if (!val)
+              return nullptr;
+
             val->setRValue();
             val->setRValue();
             return val;
             return val;
           }
           }
@@ -6675,6 +6750,8 @@ SpirvInstruction *SpirvEmitter::processBinaryOp(
                                       sourceRange);
                                       sourceRange);
     }
     }
 
 
+    if (!val)
+      return nullptr;
     val->setRValue();
     val->setRValue();
 
 
     // Propagate RelaxedPrecision
     // Propagate RelaxedPrecision
@@ -6929,6 +7006,8 @@ SpirvInstruction *SpirvEmitter::createVectorSplat(const Expr *scalarExpr,
   // Try to evaluate the element as constant first. If successful, then we
   // Try to evaluate the element as constant first. If successful, then we
   // can generate constant instructions for this vector splat.
   // can generate constant instructions for this vector splat.
   if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
   if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
+    if (!scalarVal)
+      return nullptr;
     scalarVal->setRValue();
     scalarVal->setRValue();
   } else {
   } else {
     scalarVal = loadIfGLValue(scalarExpr, range);
     scalarVal = loadIfGLValue(scalarExpr, range);
@@ -6949,12 +7028,16 @@ SpirvInstruction *SpirvEmitter::createVectorSplat(const Expr *scalarExpr,
     llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), constVal);
     llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), constVal);
     const bool isSpecConst = constVal->getopcode() == spv::Op::OpSpecConstant;
     const bool isSpecConst = constVal->getopcode() == spv::Op::OpSpecConstant;
     auto *value = spvBuilder.getConstantComposite(vecType, elements, isSpecConst);
     auto *value = spvBuilder.getConstantComposite(vecType, elements, isSpecConst);
+    if (!value)
+      return nullptr;
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   } else {
   } else {
     llvm::SmallVector<SpirvInstruction *, 4> elements(size_t(size), scalarVal);
     llvm::SmallVector<SpirvInstruction *, 4> elements(size_t(size), scalarVal);
     auto *value = spvBuilder.createCompositeConstruct(
     auto *value = spvBuilder.createCompositeConstruct(
         vecType, elements, scalarExpr->getLocStart(), range);
         vecType, elements, scalarExpr->getLocStart(), range);
+    if (!value)
+      return nullptr;
     value->setRValue();
     value->setRValue();
     return value;
     return value;
   }
   }
@@ -7611,6 +7694,8 @@ SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
 
 
   // Construct the result matrix
   // Construct the result matrix
   auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
   auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc, range);
+  if (!val)
+    return nullptr;
   val->setRValue();
   val->setRValue();
   return val;
   return val;
 }
 }
@@ -7723,7 +7808,8 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                                        rhs->getLocStart());
                                                        rhs->getLocStart());
       auto *val =
       auto *val =
           spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc, range);
           spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc, range);
-      val->setRValue();
+      if (val)
+        val->setRValue();
       return val;
       return val;
     };
     };
     return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec,
     return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec,
@@ -13525,6 +13611,8 @@ SpirvInstruction *SpirvEmitter::createSpirvIntrInstExt(
 
 
   SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
   SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
       op, retType, spvArgs, extensions, instSet, capbilities, loc);
       op, retType, spvArgs, extensions, instSet, capbilities, loc);
+  if (!retVal)
+    return nullptr;
 
 
   // TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
   // TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
   retVal->setRValue();
   retVal->setRValue();
@@ -13643,6 +13731,8 @@ SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) {
   SpirvInstruction *load =
   SpirvInstruction *load =
       loadDataFromRawAddress(address, bufferType, alignment, loc);
       loadDataFromRawAddress(address, bufferType, alignment, loc);
   auto *loadAsBool = castToBool(load, bufferType, boolType, loc);
   auto *loadAsBool = castToBool(load, bufferType, boolType, loc);
+  if (!loadAsBool)
+    return nullptr;
   loadAsBool->setRValue();
   loadAsBool->setRValue();
   return loadAsBool;
   return loadAsBool;
 }
 }
@@ -13670,23 +13760,39 @@ SpirvEmitter::loadDataFromRawAddress(SpirvInstruction *addressInUInt64,
   return loadInst;
   return loadInst;
 }
 }
 
 
-SpirvInstruction *SpirvEmitter::storeDataToRawAddress(
-    SpirvInstruction *addressInUInt64, SpirvInstruction *value,
-    QualType bufferType, uint32_t alignment, SourceLocation loc) {
-
+SpirvInstruction *
+SpirvEmitter::storeDataToRawAddress(SpirvInstruction *addressInUInt64,
+                                    SpirvInstruction *value,
+                                    QualType bufferType, uint32_t alignment,
+                                    SourceLocation loc, SourceRange range) {
   // Summary:
   // Summary:
   //   %address = OpBitcast %ptrTobufferType %addressInUInt64
   //   %address = OpBitcast %ptrTobufferType %addressInUInt64
   //   %storeInst = OpStore %address %value alignment %alignment
   //   %storeInst = OpStore %address %value alignment %alignment
+  if (!value || !addressInUInt64)
+    return nullptr;
 
 
   const HybridPointerType *bufferPtrType =
   const HybridPointerType *bufferPtrType =
       spvBuilder.getPhysicalStorageBufferType(bufferType);
       spvBuilder.getPhysicalStorageBufferType(bufferType);
 
 
   SpirvUnaryOp *address = spvBuilder.createUnaryOp(
   SpirvUnaryOp *address = spvBuilder.createUnaryOp(
       spv::Op::OpBitcast, bufferPtrType, addressInUInt64, loc);
       spv::Op::OpBitcast, bufferPtrType, addressInUInt64, loc);
+  if (!address)
+    return nullptr;
   address->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
   address->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
 
 
-  SpirvStore *storeInst = spvBuilder.createStore(address, value, loc);
+  // If the source value has a different layout, it is not safe to directly
+  // store it. It needs to be component-wise reconstructed to the new layout.
+  SpirvInstruction *source = value;
+  if (value->getStorageClass() != address->getStorageClass()) {
+    source = reconstructValue(value, bufferType, address->getLayoutRule(), loc,
+                              range);
+  }
+  if (!source)
+    return nullptr;
+
+  SpirvStore *storeInst = spvBuilder.createStore(address, source, loc);
   storeInst->setAlignment(alignment);
   storeInst->setAlignment(alignment);
+  storeInst->setStorageClass(spv::StorageClass::PhysicalStorageBuffer);
   return nullptr;
   return nullptr;
 }
 }
 
 
@@ -13698,10 +13804,14 @@ SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) {
 
 
   SpirvInstruction *address = doExpr(callExpr->getArg(0));
   SpirvInstruction *address = doExpr(callExpr->getArg(0));
   SpirvInstruction *value = doExpr(callExpr->getArg(1));
   SpirvInstruction *value = doExpr(callExpr->getArg(1));
+  if (!address || !value)
+    return nullptr;
+
   QualType bufferType = value->getAstResultType();
   QualType bufferType = value->getAstResultType();
   clang::SourceLocation loc = callExpr->getExprLoc();
   clang::SourceLocation loc = callExpr->getExprLoc();
   if (!isBoolOrVecMatOfBoolType(bufferType)) {
   if (!isBoolOrVecMatOfBoolType(bufferType)) {
-    return storeDataToRawAddress(address, value, bufferType, alignment, loc);
+    return storeDataToRawAddress(address, value, bufferType, alignment, loc,
+                                 callExpr->getLocStart());
   }
   }
 
 
   // If callExpr is `vk::RawBufferLoad<bool>(..)`, we have to load 'uint' and
   // If callExpr is `vk::RawBufferLoad<bool>(..)`, we have to load 'uint' and
@@ -13717,7 +13827,8 @@ SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) {
   QualType boolType = bufferType;
   QualType boolType = bufferType;
   bufferType = getUintTypeForBool(astContext, theCompilerInstance, boolType);
   bufferType = getUintTypeForBool(astContext, theCompilerInstance, boolType);
   auto *storeAsInt = castToInt(value, boolType, bufferType, loc);
   auto *storeAsInt = castToInt(value, boolType, bufferType, loc);
-  return storeDataToRawAddress(address, storeAsInt, bufferType, alignment, loc);
+  return storeDataToRawAddress(address, storeAsInt, bufferType, alignment, loc,
+                               callExpr->getLocStart());
 }
 }
 
 
 SpirvInstruction *
 SpirvInstruction *

+ 2 - 1
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -656,7 +656,8 @@ private:
                                           SpirvInstruction *value,
                                           SpirvInstruction *value,
                                           QualType bufferType,
                                           QualType bufferType,
                                           uint32_t alignment,
                                           uint32_t alignment,
-                                          SourceLocation loc);
+                                          SourceLocation loc,
+                                          SourceRange range);
 
 
   /// Returns the alignment of `vk::RawBufferLoad()`.
   /// Returns the alignment of `vk::RawBufferLoad()`.
   uint32_t getAlignmentForRawBufferLoad(const CallExpr *callExpr);
   uint32_t getAlignmentForRawBufferLoad(const CallExpr *callExpr);

+ 88 - 17
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferload.bitfield.hlsl

@@ -1,8 +1,26 @@
-// RUN: %dxc -T ps_6_0 -E main -HV 2021
+// RUN: %dxc -T cs_6_0 -E main -HV 2021
 
 
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %S 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S 3 Offset 12
+// CHECK: OpMemberDecorate %S_0 0 Offset 0
+// CHECK: OpMemberDecorate %S_0 1 Offset 4
+// CHECK: OpMemberDecorate %S_0 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S_0 3 Offset 12
+
+// CHECK: %S = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_Function_S = OpTypePointer Function %S
+// CHECK: %S_0 = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_PhysicalStorageBuffer_S_0 = OpTypePointer PhysicalStorageBuffer %S_0
+
+// CHECK: %temp_var_S = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_0 = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_1 = OpVariable %_ptr_Function_S Function
+// CHECK: %temp_var_S_2 = OpVariable %_ptr_Function_S Function
 
 
 struct S {
 struct S {
   uint f1;
   uint f1;
@@ -13,21 +31,74 @@ struct S {
 
 
 uint64_t Address;
 uint64_t Address;
 
 
-// CHECK: [[type_S:%\w+]] = OpTypeStruct %uint %uint %uint
-// CHECK: [[ptr_f_S:%\w+]] = OpTypePointer Function [[type_S]]
-// CHECK: [[ptr_p_S:%\w+]] = OpTypePointer PhysicalStorageBuffer [[type_S]]
-
-void main() : B {
-// CHECK: [[tmp_S:%\w+]] = OpVariable [[ptr_f_S]] Function
-// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
-// CHECK: [[value:%\d+]] = OpLoad %ulong [[value]]
-// CHECK: [[value:%\d+]] = OpBitcast [[ptr_p_S]] [[value]]
-// CHECK: [[value:%\d+]] = OpLoad [[type_S]] [[value]] Aligned 4
-// CHECK: OpStore [[tmp_S]] [[value]]
-// CHECK: [[value:%\d+]] = OpAccessChain %_ptr_Function_uint [[tmp_S]] %int_1
-// CHECK: [[value:%\d+]] = OpLoad %uint [[value]]
-// CHECK: [[value:%\d+]] = OpBitFieldUExtract %uint [[value]] %uint_1 %uint_1
-// CHECK: OpStore %tmp [[value]]
-  uint tmp = vk::RawBufferLoad<S>(Address).f3;
+[numthreads(1, 1, 1)]
+void main(uint3 tid : SV_DispatchThreadID) {
+
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: OpStore %tmp1 [[tmp]]
+    uint tmp1 = vk::RawBufferLoad<S>(Address).f1;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_0 %int_1
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpBitFieldUExtract %uint [[tmp]] %uint_0 %uint_1
+    // CHECK: OpStore %tmp2 [[tmp]]
+    uint tmp2 = vk::RawBufferLoad<S>(Address).f2;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_1 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_1 %int_1
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpBitFieldUExtract %uint [[tmp]] %uint_1 %uint_1
+    // CHECK: OpStore %tmp3 [[tmp]]
+    uint tmp3 = vk::RawBufferLoad<S>(Address).f3;
+  }
+
+  {
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Uniform_ulong %_Globals %int_0
+    // CHECK: [[tmp:%\d+]] = OpLoad %ulong [[tmp]]
+    // CHECK: [[ptr:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_S_0 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpLoad %S_0 [[ptr]] Aligned 4
+    // CHECK: [[member0:%\d+]] = OpCompositeExtract %uint [[tmp]] 0
+    // CHECK: [[member1:%\d+]] = OpCompositeExtract %uint [[tmp]] 1
+    // CHECK: [[member2:%\d+]] = OpCompositeExtract %uint [[tmp]] 2
+    // CHECK: [[tmp:%\d+]] = OpCompositeConstruct %S [[member0]] [[member1]] [[member2]]
+    // CHECK: OpStore %temp_var_S_2 [[tmp]]
+    // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %temp_var_S_2 %int_2
+    // CHECK: [[tmp:%\d+]] = OpLoad %uint [[tmp]]
+    // CHECK: OpStore %tmp4 [[tmp]]
+    uint tmp4 = vk::RawBufferLoad<S>(Address).f4;
+  }
 }
 }
 
 

+ 56 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferstore.bitfields.hlsl

@@ -0,0 +1,56 @@
+// RUN: %dxc -T cs_6_0 -E main -HV 2021
+
+// CHECK: OpCapability PhysicalStorageBufferAddresses
+// CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
+// CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %S 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %S 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %S 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S 3 Offset 12
+// CHECK: OpMemberDecorate %S_0 0 Offset 0
+// CHECK: OpMemberDecorate %S_0 1 Offset 4
+// CHECK: OpMemberDecorate %S_0 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %S_0 3 Offset 12
+
+// CHECK: %S = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_Function_S = OpTypePointer Function %S
+// CHECK: %S_0 = OpTypeStruct %uint %uint %uint
+// CHECK: %_ptr_PhysicalStorageBuffer_S_0 = OpTypePointer PhysicalStorageBuffer %S_0
+
+
+struct S {
+  uint f1;
+  uint f2 : 1;
+  uint f3 : 3;
+  uint f4;
+};
+
+uint64_t Address;
+
+[numthreads(1, 1, 1)]
+void main(uint3 tid : SV_DispatchThreadID) {
+  // CHECK: %tmp = OpVariable %_ptr_Function_S Function
+  S tmp;
+
+  // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_0
+  // CHECK: OpStore [[tmp]] %uint_2
+  tmp.f1 = 2;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_1
+  // CHECK: [[tmp:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[tmp:%\d+]] = OpBitFieldInsert %uint [[tmp]] %uint_1 %uint_0 %uint_1
+  // CHECK: OpStore [[ptr]] [[tmp]]
+  tmp.f2 = 1;
+
+  // CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_1
+  // CHECK: [[tmp:%\d+]] = OpLoad %uint [[ptr]]
+  // CHECK: [[tmp:%\d+]] = OpBitFieldInsert %uint [[tmp]] %uint_0 %uint_1 %uint_3
+  // CHECK: OpStore [[ptr]] [[tmp]]
+  tmp.f3 = 0;
+
+  // CHECK: [[tmp:%\d+]] = OpAccessChain %_ptr_Function_uint %tmp %int_2
+  // CHECK: OpStore [[tmp]] %uint_3
+  tmp.f4 = 3;
+  vk::RawBufferStore<S>(Address, tmp);
+}
+

+ 19 - 2
tools/clang/test/CodeGenSPIRV/intrinsics.vkrawbufferstore.hlsl

@@ -3,6 +3,18 @@
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpCapability PhysicalStorageBufferAddresses
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpExtension "SPV_KHR_physical_storage_buffer"
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
 // CHECK: OpMemoryModel PhysicalStorageBuffer64 GLSL450
+// CHECK-NOT: OpMemberDecorate %XYZW 0 Offset 0
+// CHECK-NOT: OpMemberDecorate %XYZW 1 Offset 4
+// CHECK-NOT: OpMemberDecorate %XYZW 2 Offset 8
+// CHECK-NOT: OpMemberDecorate %XYZW 3 Offset 12
+// CHECK: OpMemberDecorate %XYZW_0 0 Offset 0
+// CHECK: OpMemberDecorate %XYZW_0 1 Offset 4
+// CHECK: OpMemberDecorate %XYZW_0 2 Offset 8
+// CHECK: OpMemberDecorate %XYZW_0 3 Offset 12
+// CHECK: %XYZW = OpTypeStruct %int %int %int %int
+// CHECK: %_ptr_Function_XYZW = OpTypePointer Function %XYZW
+// CHECK: %XYZW_0 = OpTypeStruct %int %int %int %int
+// CHECK: %_ptr_PhysicalStorageBuffer_XYZW_0 = OpTypePointer PhysicalStorageBuffer %XYZW_0
 
 
 struct XYZW {
 struct XYZW {
   int x;
   int x;
@@ -50,8 +62,13 @@ void main(uint3 tid : SV_DispatchThreadID) {
 
 
   // CHECK:      [[addr:%\d+]] = OpLoad %ulong
   // CHECK:      [[addr:%\d+]] = OpLoad %ulong
   // CHECK-NEXT: [[xyzwval:%\d+]] = OpLoad %XYZW %xyzw
   // CHECK-NEXT: [[xyzwval:%\d+]] = OpLoad %XYZW %xyzw
-  // CHECK-NEXT: [[buf:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_XYZW [[addr]]
-  // CHECK-NEXT: OpStore [[buf]] [[xyzwval]] Aligned 4
+  // CHECK-NEXT: [[buf:%\d+]] = OpBitcast %_ptr_PhysicalStorageBuffer_XYZW_0 [[addr]]
+  // CHECK-NEXT: [[member1:%\d+]] = OpCompositeExtract %int [[xyzwval]] 0
+  // CHECK-NEXT: [[member2:%\d+]] = OpCompositeExtract %int [[xyzwval]] 1
+  // CHECK-NEXT: [[member3:%\d+]] = OpCompositeExtract %int [[xyzwval]] 2
+  // CHECK-NEXT: [[member4:%\d+]] = OpCompositeExtract %int [[xyzwval]] 3
+  // CHECK-NEXT: [[p_xyzwval:%\d+]] = OpCompositeConstruct %XYZW_0 [[member1]] [[member2]] [[member3]] [[member4]]
+  // CHECK-NEXT: OpStore [[buf]] [[p_xyzwval]] Aligned 4
   XYZW xyzw;
   XYZW xyzw;
   xyzw.x = 78;
   xyzw.x = 78;
   xyzw.y = 65;
   xyzw.y = 65;

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1461,6 +1461,9 @@ TEST_F(FileTest, IntrinsicsVkRawBufferLoadBitfield) {
 TEST_F(FileTest, IntrinsicsVkRawBufferStore) {
 TEST_F(FileTest, IntrinsicsVkRawBufferStore) {
   runFileTest("intrinsics.vkrawbufferstore.hlsl");
   runFileTest("intrinsics.vkrawbufferstore.hlsl");
 }
 }
+TEST_F(FileTest, IntrinsicsVkRawBufferStoreBitfields) {
+  runFileTest("intrinsics.vkrawbufferstore.bitfields.hlsl");
+}
 // Intrinsics added in SM 6.6
 // Intrinsics added in SM 6.6
 TEST_F(FileTest, IntrinsicsSM66PackU8S8) {
 TEST_F(FileTest, IntrinsicsSM66PackU8S8) {
   runFileTest("intrinsics.sm6_6.pack_s8u8.hlsl");
   runFileTest("intrinsics.sm6_6.pack_s8u8.hlsl");