Browse Source

[spirv] Support assigning to objects of composite types (#598)

For these objects, if the lhs and rhs are of different storage
class, we may need to do the assignment recursively at the
non-composite level since OpStore requires that the pointer's
type operand must be the same as the type of object.
Lei Zhang 8 years ago
parent
commit
f4e379db85

+ 75 - 29
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -94,7 +94,7 @@ uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) {
 uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
                                            const ParmVarDecl *param) {
   const uint32_t id = theBuilder.addFnParam(paramType, param->getName());
-  astDecls[param] = {id, spv::StorageClass::Function, -1};
+  astDecls[param] = {id, spv::StorageClass::Function};
 
   return id;
 }
@@ -102,7 +102,7 @@ uint32_t DeclResultIdMapper::createFnParam(uint32_t paramType,
 uint32_t DeclResultIdMapper::createFnVar(uint32_t varType, const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
   const uint32_t id = theBuilder.addFnVar(varType, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Function, -1};
+  astDecls[var] = {id, spv::StorageClass::Function};
 
   return id;
 }
@@ -111,7 +111,7 @@ uint32_t DeclResultIdMapper::createFileVar(uint32_t varType, const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
   const uint32_t id = theBuilder.addModuleVar(
       varType, spv::StorageClass::Private, var->getName(), init);
-  astDecls[var] = {id, spv::StorageClass::Private, -1};
+  astDecls[var] = {id, spv::StorageClass::Private};
 
   return id;
 }
@@ -136,7 +136,7 @@ uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
   const auto varType = typeTranslator.translateType(var->getType(), rule);
   const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
                                               var->getName(), llvm::None);
-  astDecls[var] = {id, storageClass, -1};
+  astDecls[var] = {id, storageClass, rule};
   resourceVars.emplace_back(id, getResourceBinding(var),
                             var->getAttr<VKBindingAttr>());
 
@@ -198,7 +198,9 @@ uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   int index = 0;
   for (const auto *subDecl : decl->decls()) {
     const auto *varDecl = cast<VarDecl>(subDecl);
-    astDecls[varDecl] = {bufferVar, spv::StorageClass::Uniform, index++};
+    // TODO: std140 rules may not suit tbuffers.
+    astDecls[varDecl] = {bufferVar, spv::StorageClass::Uniform,
+                         LayoutRule::GLSLStd140, index++};
   }
   resourceVars.emplace_back(bufferVar, getResourceBinding(decl),
                             decl->getAttr<VKBindingAttr>());
@@ -219,7 +221,9 @@ uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
       recordType->getDecl(), structName, decl->getName());
 
   // We register the VarDecl here.
-  astDecls[decl] = {bufferVar, spv::StorageClass::Uniform, -1};
+  // TODO: std140 rules may not suit tbuffers.
+  astDecls[decl] = {bufferVar, spv::StorageClass::Uniform,
+                    LayoutRule::GLSLStd140};
   resourceVars.emplace_back(bufferVar, getResourceBinding(context),
                             decl->getAttr<VKBindingAttr>());
 
@@ -231,24 +235,48 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
     return info->resultId;
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
-  astDecls[fn] = {id, spv::StorageClass::Function, -1};
+  astDecls[fn] = {id, spv::StorageClass::Function};
 
   return id;
 }
 
 namespace {
-/// A class for resolving the storage class of a given Decl or Expr.
-class StorageClassResolver : public RecursiveASTVisitor<StorageClassResolver> {
+/// A class for resolving the storage info (storage class and memory layout) of
+/// a given Decl or Expr.
+class StorageInfoResolver : public RecursiveASTVisitor<StorageInfoResolver> {
 public:
-  explicit StorageClassResolver(const DeclResultIdMapper &mapper)
-      : declIdMapper(mapper), storageClass(spv::StorageClass::Max) {}
+  explicit StorageInfoResolver(const DeclResultIdMapper &mapper)
+      : declIdMapper(mapper), baseDecl(nullptr) {}
+
+  bool TraverseMemberExpr(MemberExpr *expr) {
+    // For MemberExpr, the storage info should follow the base.
+    return TraverseStmt(expr->getBase());
+  }
+
+  bool TravereArraySubscriptExpr(ArraySubscriptExpr *expr) {
+    // For ArraySubscriptExpr, the storage info should follow the array object.
+    return TraverseStmt(expr->getBase());
+  }
+
+  bool TraverseCXXOperatorCallExpr(CXXOperatorCallExpr *expr) {
+    // For operator[], the storage info should follow the object.
+    if (expr->getOperator() == OverloadedOperatorKind::OO_Subscript)
+      return TraverseStmt(expr->getArg(0));
+
+    // TODO: the following may not be correct for all other operator calls.
+    for (uint32_t i = 0; i < expr->getNumArgs(); ++i)
+      if (!TraverseStmt(expr->getArg(i)))
+        return false;
+
+    return true;
+  }
 
   bool TraverseCXXMemberCallExpr(CXXMemberCallExpr *expr) {
-    // For method calls, the storage class should follow the object.
+    // For method calls, the storage info should follow the object.
     return TraverseStmt(expr->getImplicitObjectArgument());
   }
 
-  // For querying the storage class of a remapped decl
+  // For querying the storage info of a remapped decl
 
   // Semantics may be attached to FunctionDecl, ParmVarDecl, and FieldDecl.
   // We create stage variables for them and we may need to query the storage
@@ -257,7 +285,7 @@ public:
   bool VisitFieldDecl(FieldDecl *decl) { return processDecl(decl); }
   bool VisitParmVarDecl(ParmVarDecl *decl) { return processDecl(decl); }
 
-  // For querying the storage class of a normal decl
+  // For querying the storage info of a normal decl
 
   // Normal decls should be referred in expressions.
   bool VisitDeclRefExpr(DeclRefExpr *expr) {
@@ -269,39 +297,57 @@ public:
     if (isa<CXXMethodDecl>(decl))
       return true;
 
-    const auto *info = declIdMapper.getDeclSpirvInfo(decl);
-    assert(info);
-    if (storageClass == spv::StorageClass::Max) {
-      storageClass = info->storageClass;
+    if (!baseDecl) {
+      baseDecl = decl;
       return true;
     }
 
-    // Two decls with different storage classes are referenced in this
-    // expression. We should not visit such expression using this class.
-    assert(storageClass == info->storageClass);
+    // Two different decls referenced in the expression: this expression stands
+    // for a derived temporary object, for which case we use the Function
+    // storage class and no layout rules.
+    // Turn baseDecl to nullptr so we return the proper info and stop further
+    // traversing.
+    baseDecl = nullptr;
     return false;
   }
 
-  spv::StorageClass get() const { return storageClass; }
+  spv::StorageClass getStorageClass() const {
+    if (const auto *info = declIdMapper.getDeclSpirvInfo(baseDecl))
+      return info->storageClass;
+
+    // No Decl referenced. This is probably a temporary expression.
+    return spv::StorageClass::Function;
+  }
+
+  LayoutRule getLayoutRule() const {
+    if (const auto *info = declIdMapper.getDeclSpirvInfo(baseDecl))
+      return info->layoutRule;
+
+    // No Decl referenced. This is probably a temporary expression.
+    return LayoutRule::Void;
+  }
 
 private:
   const DeclResultIdMapper &declIdMapper;
-  spv::StorageClass storageClass;
+  const NamedDecl *baseDecl;
 };
 } // namespace
 
 spv::StorageClass
-DeclResultIdMapper::resolveStorageClass(const Expr *expr) const {
-  auto resolver = StorageClassResolver(*this);
+DeclResultIdMapper::resolveStorageInfo(const Expr *expr,
+                                       LayoutRule *rule) const {
+  auto resolver = StorageInfoResolver(*this);
   resolver.TraverseStmt(const_cast<Expr *>(expr));
-  return resolver.get();
+  if (rule)
+    *rule = resolver.getLayoutRule();
+  return resolver.getStorageClass();
 }
 
 spv::StorageClass
 DeclResultIdMapper::resolveStorageClass(const Decl *decl) const {
-  auto resolver = StorageClassResolver(*this);
+  auto resolver = StorageInfoResolver(*this);
   resolver.TraverseDecl(const_cast<Decl *>(decl));
-  return resolver.get();
+  return resolver.getStorageClass();
 }
 
 std::vector<uint32_t> DeclResultIdMapper::collectStageVars() const {
@@ -810,4 +856,4 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
 }
 
 } // end namespace spirv
-} // end namespace clang
+} // end namespace clang

+ 9 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -177,9 +177,11 @@ public:
   struct DeclSpirvInfo {
     uint32_t resultId;
     spv::StorageClass storageClass;
+    /// Layout rule for this decl.
+    LayoutRule layoutRule = LayoutRule::Void;
     /// Value >= 0 means that this decl is a VarDecl inside a cbuffer/tbuffer
     /// and this is the index; value < 0 means this is just a standalone decl.
-    int indexInCTBuffer;
+    int indexInCTBuffer = -1;
   };
 
   /// \brief Returns the SPIR-V information for the given decl.
@@ -196,9 +198,12 @@ public:
   /// returns a newly assigned <result-id> for it.
   uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
 
-  /// Returns the storage class for the given expression. The expression is
-  /// expected to be an lvalue. Otherwise this method may panic.
-  spv::StorageClass resolveStorageClass(const Expr *expr) const;
+  /// Returns the storage class for the given expression. If rule is not
+  /// nullptr, also writes the layout rule into it.
+  /// The expression is expected to be an lvalue. Otherwise this method may
+  /// panic.
+  spv::StorageClass resolveStorageInfo(const Expr *expr,
+                                       LayoutRule *rule = nullptr) const;
   spv::StorageClass resolveStorageClass(const Decl *decl) const;
 
   /// \brief Returns all defined stage (builtin/input/ouput) variables in this

+ 105 - 29
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -371,10 +371,8 @@ uint32_t SPIRVEmitter::doExpr(const Expr *expr) {
 
 uint32_t SPIRVEmitter::loadIfGLValue(const Expr *expr) {
   const uint32_t result = doExpr(expr);
-  if (expr->isGLValue()) {
-    const uint32_t baseTyId = typeTranslator.translateType(expr->getType());
-    return theBuilder.createLoad(baseTyId, result);
-  }
+  if (expr->isGLValue())
+    return theBuilder.createLoad(getType(expr), result);
 
   return result;
 }
@@ -516,7 +514,10 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
           toInitGloalVars.push_back(decl);
         }
       } else {
-        theBuilder.createStore(varId, doExpr(decl->getInit()));
+        LayoutRule rhsLayout = LayoutRule::Void;
+        (void)declIdMapper.resolveStorageInfo(decl->getInit(), &rhsLayout);
+        storeValue(varId, loadIfGLValue(decl->getInit()), decl->getType(),
+                   spv::StorageClass::Function, LayoutRule::Void, rhsLayout);
       }
     }
   } else {
@@ -1040,9 +1041,10 @@ uint32_t SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
   const auto *base = collectArrayStructIndices(expr, &indices);
 
-  const uint32_t ptrType =
-      theBuilder.getPointerType(typeTranslator.translateType(expr->getType()),
-                                declIdMapper.resolveStorageClass(base));
+  LayoutRule rule = LayoutRule::Void;
+  const auto sc = declIdMapper.resolveStorageInfo(base, &rule);
+  const uint32_t ptrType = theBuilder.getPointerType(
+      typeTranslator.translateType(expr->getType(), rule), sc);
 
   return theBuilder.createAccessChain(ptrType, doExpr(base), indices);
 }
@@ -1052,8 +1054,12 @@ uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
 
   // Handle assignment first since we need to evaluate rhs before lhs.
   // For other binary operations, we need to evaluate lhs before rhs.
-  if (opcode == BO_Assign)
-    return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
+  if (opcode == BO_Assign) {
+    LayoutRule rhsLayout = LayoutRule::Void;
+    (void)declIdMapper.resolveStorageInfo(expr->getRHS(), &rhsLayout);
+    return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
+                             false, /*lhsPtr*/ 0, rhsLayout);
+  }
 
   // Try to optimize floatMxN * float and floatN * float case
   if (opcode == BO_Mul) {
@@ -1165,8 +1171,7 @@ uint32_t SPIRVEmitter::doCastExpr(const CastExpr *expr) {
 
     // Using lvalue as rvalue means we need to OpLoad the contents from
     // the parameter/variable first.
-    const uint32_t resultType = typeTranslator.translateType(toType);
-    return theBuilder.createLoad(resultType, fromValue);
+    return theBuilder.createLoad(getType(expr), fromValue);
   }
   case CastKind::CK_NoOp:
     return doExpr(subExpr);
@@ -1433,7 +1438,7 @@ uint32_t SPIRVEmitter::processByteAddressBufferLoadStore(
   // the address.
   const uint32_t uintTypeId = theBuilder.getUint32Type();
   const uint32_t ptrType = theBuilder.getPointerType(
-      uintTypeId, declIdMapper.resolveStorageClass(object));
+      uintTypeId, declIdMapper.resolveStorageInfo(object));
   const uint32_t constUint0 = theBuilder.getConstantUint32(0);
 
   if (doStore) {
@@ -1495,7 +1500,7 @@ SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
       hlsl::GetHLSLResourceResultType(buffer->getType());
   const uint32_t ptrType = theBuilder.getPointerType(
       typeTranslator.translateType(structType, LayoutRule::GLSLStd430),
-      declIdMapper.resolveStorageClass(buffer));
+      declIdMapper.resolveStorageInfo(buffer));
 
   const uint32_t zero = theBuilder.getConstantInt32(0);
   const uint32_t index = doExpr(expr->getArg(0));
@@ -1724,9 +1729,10 @@ uint32_t SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
     base = tempVar;
   }
 
-  const uint32_t ptrType =
-      theBuilder.getPointerType(typeTranslator.translateType(expr->getType()),
-                                declIdMapper.resolveStorageClass(baseExpr));
+  LayoutRule rule = LayoutRule::Void;
+  const auto sc = declIdMapper.resolveStorageInfo(baseExpr, &rule);
+  const uint32_t ptrType = theBuilder.getPointerType(
+      typeTranslator.translateType(expr->getType(), rule), sc);
 
   return theBuilder.createAccessChain(ptrType, base, indices);
 }
@@ -1765,7 +1771,7 @@ SPIRVEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
         indices[i] = theBuilder.getConstantInt32(indices[i]);
 
       const uint32_t ptrType = theBuilder.getPointerType(
-          elemType, declIdMapper.resolveStorageClass(baseExpr));
+          elemType, declIdMapper.resolveStorageInfo(baseExpr));
       if (!indices.empty()) {
         // Load the element via access chain
         elem = theBuilder.createAccessChain(ptrType, base, indices);
@@ -1824,7 +1830,7 @@ SPIRVEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
     // v.xyyz to turn a lvalue v into rvalue.
     if (expr->getBase()->isGLValue()) { // E.g., v.x;
       const uint32_t ptrType = theBuilder.getPointerType(
-          type, declIdMapper.resolveStorageClass(baseExpr));
+          type, declIdMapper.resolveStorageInfo(baseExpr));
       const uint32_t index = theBuilder.getConstantInt32(accessor.Swz0);
       // We need a lvalue here. Do not try to load.
       return theBuilder.createAccessChain(ptrType, doExpr(baseExpr), {index});
@@ -1875,9 +1881,10 @@ uint32_t SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
   const Expr *base = collectArrayStructIndices(expr, &indices);
 
-  const uint32_t ptrType =
-      theBuilder.getPointerType(typeTranslator.translateType(expr->getType()),
-                                declIdMapper.resolveStorageClass(base));
+  LayoutRule rule = LayoutRule::Void;
+  const auto sc = declIdMapper.resolveStorageInfo(base, &rule);
+  const uint32_t ptrType = theBuilder.getPointerType(
+      typeTranslator.translateType(expr->getType(), rule), sc);
 
   return theBuilder.createAccessChain(ptrType, doExpr(base), indices);
 }
@@ -1944,6 +1951,12 @@ uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
   return 0;
 }
 
+uint32_t SPIRVEmitter::getType(const Expr *expr) {
+  LayoutRule rule = LayoutRule::Void;
+  (void)declIdMapper.resolveStorageInfo(expr, &rule);
+  return typeTranslator.translateType(expr->getType(), rule);
+}
+
 spv::Op SPIRVEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
   const bool isSintType = isSintOrVecMatOfSintType(type);
   const bool isUintType = isUintOrVecMatOfUintType(type);
@@ -2066,8 +2079,9 @@ case BO_##kind : {                                                             \
 }
 
 uint32_t SPIRVEmitter::processAssignment(const Expr *lhs, const uint32_t rhs,
-                                         bool isCompoundAssignment,
-                                         uint32_t lhsPtr) {
+                                         const bool isCompoundAssignment,
+                                         uint32_t lhsPtr,
+                                         LayoutRule rhsLayout) {
   // Assigning to vector swizzling should be handled differently.
   if (const uint32_t result = tryToAssignToVectorElements(lhs, rhs)) {
     return result;
@@ -2082,15 +2096,77 @@ uint32_t SPIRVEmitter::processAssignment(const Expr *lhs, const uint32_t rhs,
   }
 
   // Normal assignment procedure
-  if (lhsPtr == 0)
+  if (!lhsPtr)
     lhsPtr = doExpr(lhs);
 
-  theBuilder.createStore(lhsPtr, rhs);
+  LayoutRule lhsLayout = LayoutRule::Void;
+  const auto lhsSc = declIdMapper.resolveStorageInfo(lhs, &lhsLayout);
+  storeValue(lhsPtr, rhs, lhs->getType(), lhsSc, lhsLayout, rhsLayout);
+
   // Plain assignment returns a rvalue, while compound assignment returns
   // lvalue.
   return isCompoundAssignment ? lhsPtr : rhs;
 }
 
+void SPIRVEmitter::storeValue(const uint32_t lhsPtr, const uint32_t rhsVal,
+                              const QualType valType,
+                              const spv::StorageClass lhsSc,
+                              const LayoutRule lhsLayout,
+                              const LayoutRule rhsLayout) {
+  // If lhs and rhs has the same memory layout, we should be safe to load
+  // from rhs and directly store into lhs and avoid decomposing rhs.
+  // TODO: is this optimization always correct?
+  if (lhsLayout == rhsLayout || typeTranslator.isScalarType(valType) ||
+      typeTranslator.isVectorType(valType) ||
+      typeTranslator.isMxNMatrix(valType)) {
+    theBuilder.createStore(lhsPtr, rhsVal);
+  } else if (const auto *recordType = valType->getAs<RecordType>()) {
+    uint32_t index = 0;
+    for (const auto *decl : recordType->getDecl()->decls()) {
+      // Implicit generated struct declarations should be ignored.
+      if (isa<CXXRecordDecl>(decl) && decl->isImplicit())
+        continue;
+
+      const auto *field = cast<FieldDecl>(decl);
+      assert(field);
+
+      const auto subRhsValType =
+          typeTranslator.translateType(field->getType(), rhsLayout);
+      const auto subRhsVal =
+          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
+      const auto subLhsPtrType = theBuilder.getPointerType(
+          typeTranslator.translateType(field->getType(), lhsLayout), lhsSc);
+      const auto subLhsPtr = theBuilder.createAccessChain(
+          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});
+
+      storeValue(subLhsPtr, subRhsVal, field->getType(), lhsSc, lhsLayout,
+                 rhsLayout);
+      ++index;
+    }
+  } else if (const auto *arrayType =
+                 astContext.getAsConstantArrayType(valType)) {
+    const auto elemType = arrayType->getElementType();
+    // TODO: handle extra large array size?
+    const auto size =
+        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
+
+    for (uint32_t i = 0; i < size; ++i) {
+      const auto subRhsValType =
+          typeTranslator.translateType(elemType, rhsLayout);
+      const auto subRhsVal =
+          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {i});
+      const auto subLhsPtrType = theBuilder.getPointerType(
+          typeTranslator.translateType(elemType, lhsLayout), lhsSc);
+      const auto subLhsPtr = theBuilder.createAccessChain(
+          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(i)});
+
+      storeValue(subLhsPtr, subRhsVal, elemType, lhsSc, lhsLayout, rhsLayout);
+    }
+  } else {
+    emitError("storing value of type %0 unimplemented") << valType;
+  }
+}
+
 uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                        const BinaryOperatorKind opcode,
                                        const uint32_t resultType,
@@ -2526,8 +2602,8 @@ uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
 }
 
 uint32_t SPIRVEmitter::tryToAssignToRWBuffer(const Expr *lhs, uint32_t rhs) {
-  const Expr* baseExpr = nullptr;
-  const Expr* indexExpr = nullptr;
+  const Expr *baseExpr = nullptr;
+  const Expr *indexExpr = nullptr;
   if (isBufferIndexing(dyn_cast<CXXOperatorCallExpr>(lhs), &baseExpr,
                        &indexExpr)) {
     const uint32_t locId = doExpr(indexExpr);
@@ -2581,7 +2657,7 @@ uint32_t SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
       rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});
 
     const uint32_t ptrType = theBuilder.getPointerType(
-        elemTypeId, declIdMapper.resolveStorageClass(baseMat));
+        elemTypeId, declIdMapper.resolveStorageInfo(baseMat));
 
     // If the lhs is actually a matrix of size 1x1, we don't need the access
     // chain. base is already the dest pointer.

+ 16 - 5
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -101,6 +101,11 @@ private:
   uint32_t doUnaryOperator(const UnaryOperator *expr);
 
 private:
+  /// Returns the proper type for the given expr. This method is used to please
+  /// expressions derived from global resource variables, when we must construct
+  /// the type with correct layout decorations.
+  uint32_t getType(const Expr *expr);
+
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.
   spv::Op translateOp(BinaryOperator::Opcode op, QualType type);
@@ -108,8 +113,15 @@ private:
   /// Generates the necessary instructions for assigning rhs to lhs. If lhsPtr
   /// is not zero, it will be used as the pointer from lhs instead of evaluating
   /// lhs again.
-  uint32_t processAssignment(const Expr *lhs, const uint32_t rhs,
-                             bool isCompoundAssignment, uint32_t lhsPtr = 0);
+  uint32_t processAssignment(const Expr *lhs, uint32_t rhs,
+                             bool isCompoundAssignment, uint32_t lhsPtr = 0,
+                             LayoutRule rhsLayout = LayoutRule::Void);
+
+  /// Generates SPIR-V instructions to store rhsVal into lhsPtr. This will be
+  /// recursive if valType is a composite type.
+  void storeValue(uint32_t lhsPtr, uint32_t rhsVal, QualType valType,
+                  spv::StorageClass lhsSc, LayoutRule lhsLayout,
+                  LayoutRule rhsLayout);
 
   /// Generates the necessary instructions for conducting the given binary
   /// operation on lhs and rhs. If lhsResultId is not nullptr, the evaluated
@@ -117,10 +129,9 @@ private:
   /// mandateGenOpcode is not spv::Op::Max, it will used as the SPIR-V opcode
   /// instead of deducing from Clang frontend opcode.
   uint32_t processBinaryOp(const Expr *lhs, const Expr *rhs,
-                           const BinaryOperatorKind opcode,
-                           const uint32_t resultType,
+                           BinaryOperatorKind opcode, uint32_t resultType,
                            uint32_t *lhsResultId = nullptr,
-                           const spv::Op mandateGenOpcode = spv::Op::Max);
+                           spv::Op mandateGenOpcode = spv::Op::Max);
 
   /// Generates SPIR-V instructions to initialize the given variable once.
   void initOnce(std::string varName, uint32_t varPtr, const Expr *varInit);

+ 101 - 0
tools/clang/test/CodeGenSPIRV/binary-op.assign.composite.hlsl

@@ -0,0 +1,101 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct SubBuffer {
+    float    a[1];
+    float2   b[1];
+    float2x3 c[1];
+};
+
+struct BufferType {
+    float     a;
+    float3    b;
+    float3x2  c;
+    SubBuffer d[1];
+};
+
+RWStructuredBuffer<BufferType> sbuf;  // %BufferType                     & %SubBuffer
+    ConstantBuffer<BufferType> cbuf;  // %type_ConstantBuffer_BufferType & %SubBuffer_0
+
+void main(uint index: A) {
+    // Same storage class
+
+// CHECK:      [[sbuf0:%\d+]] = OpAccessChain %_ptr_Uniform_BufferType %sbuf %int_0 %uint_0
+// CHECK-NEXT: [[val:%\d+]] = OpLoad %BufferType [[sbuf0]]
+// CHECK-NEXT: [[sbuf8:%\d+]] = OpAccessChain %_ptr_Uniform_BufferType %sbuf %int_0 %uint_8
+// CHECK-NEXT: OpStore [[sbuf8]] [[val]]
+    sbuf[8] = sbuf[0];
+
+    // Different storage class
+
+
+// CHECK-NEXT: [[lbuf:%\d+]] = OpLoad %BufferType_0 %lbuf
+// CHECK-NEXT: [[sbuf5:%\d+]] = OpAccessChain %_ptr_Uniform_BufferType %sbuf %int_0 %uint_5
+
+    // sbuf[5].a <- lbuf.a
+// CHECK-NEXT: [[val:%\d+]] = OpCompositeExtract %float [[lbuf]] 0
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_float [[sbuf5]] %uint_0
+// CHECK-NEXT: OpStore [[ptr]] [[val]]
+
+    // sbuf[5].b <- lbuf.b
+// CHECK-NEXT: [[val:%\d+]] = OpCompositeExtract %v3float [[lbuf]] 1
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v3float [[sbuf5]] %uint_1
+// CHECK-NEXT: OpStore [[ptr]] [[val]]
+
+    // sbuf[5].c <- lbuf.c
+// CHECK-NEXT: [[val:%\d+]] = OpCompositeExtract %mat3v2float [[lbuf]] 2
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_mat3v2float [[sbuf5]] %uint_2
+// CHECK-NEXT: OpStore [[ptr]] [[val]]
+
+// CHECK-NEXT: [[lbuf_d:%\d+]] = OpCompositeExtract %_arr_SubBuffer_1_uint_1 [[lbuf]] 3
+// CHECK-NEXT: [[sbuf_d:%\d+]] = OpAccessChain %_ptr_Uniform__arr_SubBuffer_uint_1 [[sbuf5]] %uint_3
+// CHECK-NEXT: [[lbuf_d0:%\d+]] = OpCompositeExtract %SubBuffer_1 [[lbuf_d]] 0
+// CHECK-NEXT: [[sbuf_d0:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer [[sbuf_d]] %uint_0
+
+    // sbuf[5].d[0].a[0] <- lbuf.a[0]
+// CHECK-NEXT: [[lbuf_d0_a:%\d+]] = OpCompositeExtract %_arr_float_uint_1_1 [[lbuf_d0]] 0
+// CHECK-NEXT: [[sbuf_d0_a:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_1 [[sbuf_d0]] %uint_0
+// CHECK-NEXT: [[lbuf_d0_a0:%\d+]] = OpCompositeExtract %float [[lbuf_d0_a]] 0
+// CHECK-NEXT: [[sbuf_d0_a0:%\d+]] = OpAccessChain %_ptr_Uniform_float [[sbuf_d0_a]] %uint_0
+// CHECK-NEXT: OpStore [[sbuf_d0_a0]] [[lbuf_d0_a0]]
+
+    // sbuf[5].d[0].b[0] <- lbuf.b[0]
+// CHECK-NEXT: [[lbuf_d0_b:%\d+]] = OpCompositeExtract %_arr_v2float_uint_1_1 [[lbuf_d0]] 1
+// CHECK-NEXT: [[sbuf_d0_b:%\d+]] = OpAccessChain %_ptr_Uniform__arr_v2float_uint_1 [[sbuf_d0]] %uint_1
+// CHECK-NEXT: [[lbuf_d0_b0:%\d+]] = OpCompositeExtract %v2float [[lbuf_d0_b]] 0
+// CHECK-NEXT: [[sbuf_d0_b0:%\d+]] = OpAccessChain %_ptr_Uniform_v2float [[sbuf_d0_b]] %uint_0
+// CHECK-NEXT: OpStore [[sbuf_d0_b0]] [[lbuf_d0_b0]]
+
+    // sbuf[5].d[0].c[0] <- lbuf.c[0]
+// CHECK-NEXT: [[lbuf_d0_c:%\d+]] = OpCompositeExtract %_arr_mat2v3float_uint_1_1 [[lbuf_d0]] 2
+// CHECK-NEXT: [[sbuf_d0_c:%\d+]] = OpAccessChain %_ptr_Uniform__arr_mat2v3float_uint_1 [[sbuf_d0]] %uint_2
+// CHECK-NEXT: [[lbuf_d0_c0:%\d+]] = OpCompositeExtract %mat2v3float [[lbuf_d0_c]] 0
+// CHECK-NEXT: [[sbuf_d0_c0:%\d+]] = OpAccessChain %_ptr_Uniform_mat2v3float [[sbuf_d0_c]] %uint_0
+// CHECK-NEXT: OpStore [[sbuf_d0_c0]] [[lbuf_d0_c0]]
+    BufferType lbuf;                  // %BufferType_0                   & %SubBuffer_1
+    sbuf[5]  = lbuf;             // %BufferType <- %BufferType_0
+
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer_0 %cbuf %int_3 %int_0
+// CHECK-NEXT: [[cbuf_d0:%\d+]] = OpLoad %SubBuffer_0 [[ptr]]
+
+    // sub.a[0] <- cbuf.d[0].a[0]
+// CHECK-NEXT: [[cbuf_d0_a:%\d+]] = OpCompositeExtract %_arr_float_uint_1_0 [[cbuf_d0]] 0
+// CHECK-NEXT: [[sub_a:%\d+]] = OpAccessChain %_ptr_Function__arr_float_uint_1_1 %sub %uint_0
+// CHECK-NEXT: [[cbuf_d0_a0:%\d+]] = OpCompositeExtract %float [[cbuf_d0_a]] 0
+// CHECK-NEXT: [[sub_a0:%\d+]] = OpAccessChain %_ptr_Function_float [[sub_a]] %uint_0
+// CHECK-NEXT: OpStore [[sub_a0]] [[cbuf_d0_a0]]
+
+    // sub.b[0] <- cbuf.d[0].b[0]
+// CHECK-NEXT: [[cbuf_d0_b:%\d+]] = OpCompositeExtract %_arr_v2float_uint_1_0 [[cbuf_d0]] 1
+// CHECK-NEXT: [[sub_b:%\d+]] = OpAccessChain %_ptr_Function__arr_v2float_uint_1_1 %sub %uint_1
+// CHECK-NEXT: [[cbuf_d0_b0:%\d+]] = OpCompositeExtract %v2float [[cbuf_d0_b]] 0
+// CHECK-NEXT: [[sub_b0:%\d+]] = OpAccessChain %_ptr_Function_v2float [[sub_b]] %uint_0
+// CHECK-NEXT: OpStore [[sub_b0]] [[cbuf_d0_b0]]
+
+    // sub.c[0] <- cbuf.d[0].c[0]
+// CHECK-NEXT: [[cbuf_d0_c:%\d+]] = OpCompositeExtract %_arr_mat2v3float_uint_1_0 [[cbuf_d0]] 2
+// CHECK-NEXT: [[sub_c:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_1 %sub %uint_2
+// CHECK-NEXT: [[cbuf_d0_c0:%\d+]] = OpCompositeExtract %mat2v3float [[cbuf_d0_c]] 0
+// CHECK-NEXT: [[sub_c0:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float [[sub_c]] %uint_0
+// CHECK-NEXT: OpStore [[sub_c0]] [[cbuf_d0_c0]]
+    SubBuffer sub = cbuf.d[0];        // %SubBuffer_1 <- %SubBuffer_0
+}

+ 12 - 14
tools/clang/test/CodeGenSPIRV/method.structured-buffer.load.hlsl

@@ -11,22 +11,20 @@ RWStructuredBuffer<SBuffer> mySBuffer2;
 float4 main(int index: A) : SV_Target {
     // b1 and b2's type does not need layout decorations. So it's a different
     // SBuffer definition.
-// XXXXX-NOT:  OpMemberDecorate %SBuffer_0 0 Offset 0
-// XXXXX:      %_ptr_Function_SBuffer_0 = OpTypePointer Function %SBuffer_0
+// CHECK-NOT:  OpMemberDecorate %SBuffer_0 0 Offset 0
+// CHECK:      %_ptr_Function_SBuffer_0 = OpTypePointer Function %SBuffer_0
 
-// XXXXX:      %b1 = OpVariable %_ptr_Function_SBuffer_0 Function
-// XXXXX-NEXT: %b2 = OpVariable %_ptr_Function_SBuffer_0 Function
+// CHECK:      %b1 = OpVariable %_ptr_Function_SBuffer_0 Function
+// CHECK-NEXT: %b2 = OpVariable %_ptr_Function_SBuffer_0 Function
 
-// TODO: wrong codegen right now: missing load the value from sb1 & sb2
-// TODO: need to make sure we have %SBuffer (not %SBuffer_0) as the loaded type
-// XXXXX:      [[index:%\d+]] = OpLoad %int %index
-// XXXXX:      [[sb1:%\d+]] = OpAccessChain %_ptr_Uniform_SBuffer %mySBuffer1 %int_0 [[index]]
-// XXXXX:      {{%\d+}} = OpLoad %SBuffer [[sb1]]
-// XXXXX:      [[sb2:%\d+]] = OpAccessChain %_ptr_Uniform_SBuffer %mySBuffer2 %int_0 %int_0
-// XXXXX:      {{%\d+}} = OpLoad %SBuffer [[sb2]]
-    //SBuffer b1 = mySBuffer1.Load(index);
-    //SBuffer b2;
-    //b2 = mySBuffer2.Load(0);
+// CHECK:      [[index:%\d+]] = OpLoad %int %index
+// CHECK:      [[sb1:%\d+]] = OpAccessChain %_ptr_Uniform_SBuffer %mySBuffer1 %int_0 [[index]]
+// CHECK:      {{%\d+}} = OpLoad %SBuffer [[sb1]]
+// CHECK:      [[sb2:%\d+]] = OpAccessChain %_ptr_Uniform_SBuffer %mySBuffer2 %int_0 %int_0
+// CHECK:      {{%\d+}} = OpLoad %SBuffer [[sb2]]
+    SBuffer b1 = mySBuffer1.Load(index);
+    SBuffer b2;
+    b2 = mySBuffer2.Load(0);
 
 // CHECK:      [[f1:%\d+]] = OpAccessChain %_ptr_Uniform_v4float %mySBuffer1 %int_0 %int_5 %int_0
 // CHECK-NEXT: [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[f1]] %int_0

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -112,6 +112,9 @@ TEST_F(FileTest, UnaryOpLogicalNot) {
 
 // For assignments
 TEST_F(FileTest, BinaryOpAssign) { runFileTest("binary-op.assign.hlsl"); }
+TEST_F(FileTest, BinaryOpAssignComposite) {
+  runFileTest("binary-op.assign.composite.hlsl");
+}
 
 // For arithmetic binary operators
 TEST_F(FileTest, BinaryOpScalarArithmetic) {