Selaa lähdekoodia

[spirv] Cleanups related to handling of storage classes (#2233)

* [spirv] Better handling of OpAccessChains.

* [spirv] Clean up usage of SpirvType in some GlPerVertex methods.

* [spirv] Remove createCompositeConstruct which uses SpirvType.

Also cleans up its usage in SpirvEmitter.cpp nicely.

* address code review comments.
Ehsan 6 vuotta sitten
vanhempi
commit
27538208c6

+ 3 - 8
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -114,10 +114,6 @@ public:
   createCompositeConstruct(QualType resultType,
                            llvm::ArrayRef<SpirvInstruction *> constituents,
                            SourceLocation loc);
-  SpirvCompositeConstruct *
-  createCompositeConstruct(const SpirvType *resultType,
-                           llvm::ArrayRef<SpirvInstruction *> constituents,
-                           SourceLocation loc);
 
   /// \brief Creates a composite extract instruction. The given composite is
   /// indexed using the given literal indexes to obtain the resulting element.
@@ -167,14 +163,13 @@ public:
   /// \brief Creates an access chain instruction to retrieve the element from
   /// the given base by walking through the given indexes. Returns the
   /// instruction pointer for the pointer to the element.
+  /// Note: The given 'resultType' should be the underlying value type, not the
+  /// pointer type. The type lowering pass automatically adds pointerness and
+  /// proper storage class (based on the access base) to the result type.
   SpirvAccessChain *
   createAccessChain(QualType resultType, SpirvInstruction *base,
                     llvm::ArrayRef<SpirvInstruction *> indexes,
                     SourceLocation loc);
-  SpirvAccessChain *
-  createAccessChain(const SpirvType *resultType, SpirvInstruction *base,
-                    llvm::ArrayRef<SpirvInstruction *> indexes,
-                    SourceLocation loc);
 
   /// \brief Creates a unary operation with the given SPIR-V opcode. Returns
   /// the instruction pointer for the result.

+ 6 - 16
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -529,11 +529,8 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
 
       // Should only have VarDecls in a HLSLBufferDecl.
       QualType valueType = cast<VarDecl>(decl)->getType();
-      const auto *ptrType =
-          spvContext.getPointerType(valueType, info->instr->getStorageClass());
-
       return spvBuilder.createAccessChain(
-          ptrType, info->instr,
+          valueType, info->instr,
           {spvBuilder.getConstantInt(
               astContext.IntTy, llvm::APInt(32, info->indexInCTBuffer, true))},
           loc);
@@ -1971,10 +1968,8 @@ bool DeclResultIdMapper::createStageVars(
           hlsl::GetArraySize(type) != 4) {
         const auto tessFactorSize = hlsl::GetArraySize(type);
         for (uint32_t i = 0; i < tessFactorSize; ++i) {
-          const auto ptrType = spvContext.getPointerType(
-              spvContext.getFloatType(32), spv::StorageClass::Output);
           ptr = spvBuilder.createAccessChain(
-              ptrType, varInstr,
+              astContext.FloatTy, varInstr,
               {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                          llvm::APInt(32, i))},
               thisSemantic.loc);
@@ -1994,9 +1989,7 @@ bool DeclResultIdMapper::createStageVars(
                // Some developers use float[1] instead of a scalar float.
                (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
         ptr = spvBuilder.createAccessChain(
-            spvContext.getPointerType(spvContext.getFloatType(32),
-                                      spv::StorageClass::Output),
-            varInstr,
+            astContext.FloatTy, varInstr,
             spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                       llvm::APInt(32, 0)),
             thisSemantic.loc);
@@ -2008,10 +2001,8 @@ bool DeclResultIdMapper::createStageVars(
       // Special handling of SV_Coverage, which is an unit value. We need to
       // write it to the first element in the SampleMask builtin.
       else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
-        const auto *ptrType =
-            spvContext.getPointerType(type, spv::StorageClass::Output);
         ptr = spvBuilder.createAccessChain(
-            ptrType, varInstr,
+            type, varInstr,
             spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                       llvm::APInt(32, 0)),
             thisSemantic.loc);
@@ -2027,9 +2018,8 @@ bool DeclResultIdMapper::createStageVars(
         const auto elementType =
             astContext.getAsArrayType(evalType)->getElementType();
         auto index = invocationId.getValue();
-        ptr = spvBuilder.createAccessChain(
-            spvContext.getPointerType(elementType, spv::StorageClass::Output),
-            varInstr, index, thisSemantic.loc);
+        ptr = spvBuilder.createAccessChain(elementType, varInstr, index,
+                                           thisSemantic.loc);
         ptr->setStorageClass(spv::StorageClass::Output);
         spvBuilder.createStore(ptr, *value, thisSemantic.loc);
       }

+ 26 - 32
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -399,9 +399,7 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(
 
   // The ClipDistance/CullDistance is always an float array. We are accessing
   // it using pointers, which should be of pointer to float type.
-  const FloatType *f32Type = spvContext.getFloatType(32);
-  const SpirvPointerType *ptrType =
-      spvContext.getPointerType(f32Type, spv::StorageClass::Input);
+  const QualType f32Type = astContext.FloatTy;
 
   if (inArraySize == 0) {
     // The input builtin does not have extra arrayness. Only need one index
@@ -413,9 +411,9 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(
     if (isScalarType(asType)) {
       auto *spirvConstant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                       llvm::APInt(32, offset));
-      auto *ptr = spvBuilder.createAccessChain(ptrType, clipCullVar,
+      auto *ptr = spvBuilder.createAccessChain(f32Type, clipCullVar,
                                                {spirvConstant}, loc);
-      return spvBuilder.createLoad(astContext.FloatTy, ptr, loc);
+      return spvBuilder.createLoad(f32Type, ptr, loc);
     }
 
     if (isVectorType(asType, &elemType, &count)) {
@@ -426,12 +424,12 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(
         // Read elements sequentially from the float array
         auto *spirvConstant = spvBuilder.getConstantInt(
             astContext.UnsignedIntTy, llvm::APInt(32, offset + i));
-        auto *ptr = spvBuilder.createAccessChain(ptrType, clipCullVar,
+        auto *ptr = spvBuilder.createAccessChain(f32Type, clipCullVar,
                                                  {spirvConstant}, loc);
-        elements.push_back(spvBuilder.createLoad(astContext.FloatTy, ptr, loc));
+        elements.push_back(spvBuilder.createLoad(f32Type, ptr, loc));
       }
       return spvBuilder.createCompositeConstruct(
-          spvContext.getVectorType(f32Type, count), elements, loc);
+          astContext.getExtVectorType(f32Type, count), elements, loc);
     }
 
     llvm_unreachable("SV_ClipDistance/SV_CullDistance not float or vector of "
@@ -448,33 +446,32 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(
   llvm::SmallVector<SpirvInstruction *, 8> arrayElements;
   QualType elemType = {};
   uint32_t count = {};
-  const ArrayType *arrayType = nullptr;
+  QualType arrayType = {};
 
   if (isScalarType(asType)) {
-    arrayType = spvContext.getArrayType(f32Type, inArraySize,
-                                        /*ArrayStride*/ llvm::None);
+    arrayType = astContext.getConstantArrayType(
+        f32Type, llvm::APInt(32, inArraySize), clang::ArrayType::Normal, 0);
     for (uint32_t i = 0; i < inArraySize; ++i) {
       auto *ptr = spvBuilder.createAccessChain(
-          ptrType, clipCullVar,
+          f32Type, clipCullVar,
           {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                      llvm::APInt(32, i)), // Block array index
            spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                      llvm::APInt(32, offset))},
           loc);
-      arrayElements.push_back(
-          spvBuilder.createLoad(astContext.FloatTy, ptr, loc));
+      arrayElements.push_back(spvBuilder.createLoad(f32Type, ptr, loc));
     }
   } else if (isVectorType(asType, &elemType, &count)) {
-    arrayType =
-        spvContext.getArrayType(spvContext.getVectorType(f32Type, count),
-                                inArraySize, /*ArrayStride*/ llvm::None);
+    arrayType = astContext.getConstantArrayType(
+        astContext.getExtVectorType(f32Type, count),
+        llvm::APInt(32, inArraySize), clang::ArrayType::Normal, 0);
 
     for (uint32_t i = 0; i < inArraySize; ++i) {
       // For each gl_PerVertex block, we need to read a vector from it.
       llvm::SmallVector<SpirvInstruction *, 4> vecElements;
       for (uint32_t j = 0; j < count; ++j) {
         auto *ptr = spvBuilder.createAccessChain(
-            ptrType, clipCullVar,
+            f32Type, clipCullVar,
             // Block array index
             {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                        llvm::APInt(32, i)),
@@ -482,11 +479,10 @@ SpirvInstruction *GlPerVertex::readClipCullArrayAsType(
              spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                        llvm::APInt(32, offset + j))},
             loc);
-        vecElements.push_back(
-            spvBuilder.createLoad(astContext.FloatTy, ptr, loc));
+        vecElements.push_back(spvBuilder.createLoad(f32Type, ptr, loc));
       }
       arrayElements.push_back(spvBuilder.createCompositeConstruct(
-          spvContext.getVectorType(f32Type, count), vecElements, loc));
+          astContext.getExtVectorType(f32Type, count), vecElements, loc));
     }
   } else {
     llvm_unreachable("SV_ClipDistance/SV_CullDistance not float or vector of "
@@ -536,9 +532,7 @@ void GlPerVertex::writeClipCullArrayFromType(
 
   // The ClipDistance/CullDistance is always an float array. We are accessing
   // it using pointers, which should be of pointer to float type.
-  const FloatType *f32Type = spvContext.getFloatType(32);
-  const SpirvPointerType *ptrType =
-      spvContext.getPointerType(f32Type, spv::StorageClass::Output);
+  const QualType f32Type = astContext.FloatTy;
 
   if (outArraySize == 0) {
     // The output builtin does not have extra arrayness. Only need one index
@@ -551,7 +545,7 @@ void GlPerVertex::writeClipCullArrayFromType(
       auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                  llvm::APInt(32, offset));
       auto *ptr =
-          spvBuilder.createAccessChain(ptrType, clipCullVar, {constant}, loc);
+          spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
       spvBuilder.createStore(ptr, fromValue, loc);
       return;
     }
@@ -564,9 +558,9 @@ void GlPerVertex::writeClipCullArrayFromType(
         auto *constant = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                    llvm::APInt(32, offset + i));
         auto *ptr =
-            spvBuilder.createAccessChain(ptrType, clipCullVar, {constant}, loc);
-        auto *subValue = spvBuilder.createCompositeExtract(astContext.FloatTy,
-                                                           fromValue, {i}, loc);
+            spvBuilder.createAccessChain(f32Type, clipCullVar, {constant}, loc);
+        auto *subValue =
+            spvBuilder.createCompositeExtract(f32Type, fromValue, {i}, loc);
         spvBuilder.createStore(ptr, subValue, loc);
       }
       return;
@@ -595,7 +589,7 @@ void GlPerVertex::writeClipCullArrayFromType(
 
   if (isScalarType(fromType)) {
     auto *ptr = spvBuilder.createAccessChain(
-        ptrType, clipCullVar,
+        f32Type, clipCullVar,
         {arrayIndex, spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                llvm::APInt(32, offset))},
         loc);
@@ -607,7 +601,7 @@ void GlPerVertex::writeClipCullArrayFromType(
     // For each gl_PerVertex block, we need to write a vector into it.
     for (uint32_t i = 0; i < count; ++i) {
       auto *ptr = spvBuilder.createAccessChain(
-          ptrType, clipCullVar,
+          f32Type, clipCullVar,
           // Block array index
           {arrayIndex,
            // Write elements sequentially into the float array
@@ -615,8 +609,8 @@ void GlPerVertex::writeClipCullArrayFromType(
                                      llvm::APInt(32, offset + i))},
           loc);
 
-      auto *subValue = spvBuilder.createCompositeExtract(astContext.FloatTy,
-                                                         fromValue, {i}, loc);
+      auto *subValue =
+          spvBuilder.createCompositeExtract(f32Type, fromValue, {i}, loc);
       spvBuilder.createStore(ptr, subValue, loc);
     }
     return;

+ 9 - 0
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -91,6 +91,15 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
     instr->setResultType(pointerType);
     break;
   }
+  // Access chains must have a pointer type. The storage class for the pointer
+  // is the same as the storage class of the access base.
+  case spv::Op::OpAccessChain: {
+    const auto *pointerType = spvContext.getPointerType(
+        resultType,
+        cast<SpirvAccessChain>(instr)->getBase()->getStorageClass());
+    instr->setResultType(pointerType);
+    break;
+  }
   // OpImageTexelPointer's result type must be a pointer with image storage
   // class.
   case spv::Op::OpImageTexelPointer: {

+ 11 - 39
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -120,20 +120,6 @@ SpirvCompositeConstruct *SpirvBuilder::createCompositeConstruct(
   return instruction;
 }
 
-SpirvCompositeConstruct *SpirvBuilder::createCompositeConstruct(
-    const SpirvType *resultType,
-    llvm::ArrayRef<SpirvInstruction *> constituents, SourceLocation loc) {
-  assert(insertPoint && "null insert point");
-  auto *instruction =
-      new (context) SpirvCompositeConstruct(/*QualType*/ {}, loc, constituents);
-  instruction->setResultType(resultType);
-  if (!constituents.empty()) {
-    instruction->setLayoutRule(constituents[0]->getLayoutRule());
-  }
-  insertPoint->addInstruction(instruction);
-  return instruction;
-}
-
 SpirvCompositeExtract *SpirvBuilder::createCompositeExtract(
     QualType resultType, SpirvInstruction *composite,
     llvm::ArrayRef<uint32_t> indexes, SourceLocation loc) {
@@ -198,6 +184,17 @@ SpirvLoad *SpirvBuilder::createLoad(const SpirvType *resultType,
   auto *instruction = new (context) SpirvLoad(/*QualType*/ {}, loc, pointer);
   instruction->setResultType(resultType);
   instruction->setStorageClass(pointer->getStorageClass());
+  // Special case for legalization. We could have point-to-pointer types.
+  // For example:
+  //
+  // %var = OpVariable %_ptr_Private__ptr_Uniform_type_X Private
+  // %1 = OpLoad %_ptr_Uniform_type_X %var
+  //
+  // Loading from %var should result in Uniform storage class, not Private.
+  if (const auto *ptrType = dyn_cast<SpirvPointerType>(resultType)) {
+    instruction->setStorageClass(ptrType->getStorageClass());
+  }
+
   instruction->setLayoutRule(pointer->getLayoutRule());
   instruction->setNonUniform(pointer->isNonUniform());
   instruction->setRValue(true);
@@ -261,31 +258,6 @@ SpirvBuilder::createAccessChain(QualType resultType, SpirvInstruction *base,
   return instruction;
 }
 
-SpirvAccessChain *SpirvBuilder::createAccessChain(
-    const SpirvType *resultType, SpirvInstruction *base,
-    llvm::ArrayRef<SpirvInstruction *> indexes, SourceLocation loc) {
-  assert(insertPoint && "null insert point");
-  auto *instruction =
-      new (context) SpirvAccessChain(/*QualType*/ {}, loc, base, indexes);
-  instruction->setResultType(resultType);
-  instruction->setStorageClass(base->getStorageClass());
-  instruction->setLayoutRule(base->getLayoutRule());
-  bool isNonUniform = base->isNonUniform();
-  for (auto *index : indexes)
-    isNonUniform = isNonUniform || index->isNonUniform();
-  instruction->setNonUniform(isNonUniform);
-  instruction->setContainsAliasComponent(base->containsAliasComponent());
-
-  // If doing an access chain into a structured or byte address buffer, make
-  // sure the layout rule is sBufferLayoutRule.
-  if (base->hasAstResultType() &&
-      isAKindOfStructuredOrByteBuffer(base->getAstResultType()))
-    instruction->setLayoutRule(spirvOptions.sBufferLayoutRule);
-
-  insertPoint->addInstruction(instruction);
-  return instruction;
-}
-
 SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op, QualType resultType,
                                           SpirvInstruction *operand,
                                           SourceLocation loc) {

+ 40 - 62
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -831,6 +831,8 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr) {
 
 SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
                                               SpirvInstruction *info) {
+  const auto exprType = expr->getType();
+
   // Do nothing if this is already rvalue
   if (!info || info->isRValue())
     return info;
@@ -839,7 +841,7 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
   // If true, we are likely to copy it as a whole. To assist per-element
   // copying, avoid the load here and return the pointer directly.
   // TODO: consider moving this hack into SPIRV-Tools as a transformation.
-  if (isOpaqueArrayType(expr->getType()))
+  if (isOpaqueArrayType(exprType))
     return info;
 
   // Check whether we are trying to load an externally visible structured/byte
@@ -872,8 +874,7 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
         declIdMapper.getCTBufferPushConstantType(declContext), info,
         expr->getExprLoc());
   } else {
-    loadedInstr =
-        spvBuilder.createLoad(expr->getType(), info, expr->getExprLoc());
+    loadedInstr = spvBuilder.createLoad(exprType, info, expr->getExprLoc());
   }
   assert(loadedInstr);
 
@@ -884,40 +885,29 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
   {
     uint32_t vecSize = 1, numRows = 0, numCols = 0;
     if (info->getLayoutRule() != SpirvLayoutRule::Void &&
-        isBoolOrVecMatOfBoolType(expr->getType())) {
-      const auto exprType = expr->getType();
+        isBoolOrVecMatOfBoolType(exprType)) {
       QualType uintType = astContext.UnsignedIntTy;
-      QualType boolType = astContext.BoolTy;
       if (isScalarType(exprType) || isVectorType(exprType, nullptr, &vecSize)) {
         const auto fromType =
             vecSize == 1 ? uintType
                          : astContext.getExtVectorType(uintType, vecSize);
-        const auto toType =
-            vecSize == 1 ? boolType
-                         : astContext.getExtVectorType(boolType, vecSize);
         loadedInstr =
-            castToBool(loadedInstr, fromType, toType, expr->getLocStart());
+            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
       } else {
         const bool isMat = isMxNMatrix(exprType, nullptr, &numRows, &numCols);
         assert(isMat);
         (void)isMat;
-        const auto uintRowQualType =
-            astContext.getExtVectorType(uintType, numCols);
-        const auto boolRowQualType =
-            astContext.getExtVectorType(boolType, numCols);
-        const SpirvType *resultType = spvContext.getMatrixType(
-            spvContext.getVectorType(spvContext.getBoolType(), numCols),
-            numRows);
-
-        llvm::SmallVector<SpirvInstruction *, 4> rows;
-        for (uint32_t i = 0; i < numRows; ++i) {
-          auto *row = spvBuilder.createCompositeExtract(
-              uintRowQualType, loadedInstr, {i}, expr->getLocStart());
-          rows.push_back(castToBool(row, uintRowQualType, boolRowQualType,
-                                    expr->getLocStart()));
-        }
-        loadedInstr = spvBuilder.createCompositeConstruct(resultType, rows,
-                                                          expr->getExprLoc());
+        const clang::Type *type = exprType.getCanonicalType().getTypePtr();
+        const RecordType *RT = cast<RecordType>(type);
+        const ClassTemplateSpecializationDecl *templateSpecDecl =
+            cast<ClassTemplateSpecializationDecl>(RT->getDecl());
+        ClassTemplateDecl *templateDecl =
+            templateSpecDecl->getSpecializedTemplate();
+        const auto fromType = getHLSLMatrixType(
+            astContext, theCompilerInstance.getSema(), templateDecl,
+            astContext.UnsignedIntTy, numRows, numCols);
+        loadedInstr =
+            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
       }
       // Now that it is converted to Bool, it has no layout rule.
       // This result-id should be evaluated as bool from here on out.
@@ -2825,10 +2815,9 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods(
   auto *address = spvBuilder.createBinaryOp(
       spv::Op::OpShiftRightLogical, astContext.UnsignedIntTy, offset,
       spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)));
-  auto *ptr = spvBuilder.createAccessChain(
-      spvContext.getPointerType(astContext.UnsignedIntTy,
-                                objectInfo->getStorageClass()),
-      objectInfo, {zero, address}, object->getLocStart());
+  auto *ptr =
+      spvBuilder.createAccessChain(astContext.UnsignedIntTy, objectInfo,
+                                   {zero, address}, object->getLocStart());
 
   const bool isCompareExchange =
       opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange;
@@ -3317,8 +3306,6 @@ SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
   // the address.
   auto *constUint0 =
       spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
-  const auto *ptrType = spvContext.getPointerType(
-      astContext.UnsignedIntTy, objectInfo->getStorageClass());
   if (doStore) {
     auto *values = doExpr(expr->getArg(1));
     auto *curStoreAddress = address;
@@ -3341,14 +3328,15 @@ SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
 
       // Store the word to the right address at the output.
       auto *storePtr = spvBuilder.createAccessChain(
-          ptrType, objectInfo, {constUint0, curStoreAddress},
+          astContext.UnsignedIntTy, objectInfo, {constUint0, curStoreAddress},
           object->getLocStart());
       spvBuilder.createStore(storePtr, curValue,
                              expr->getCallee()->getExprLoc());
     }
   } else {
     auto *loadPtr = spvBuilder.createAccessChain(
-        ptrType, objectInfo, {constUint0, address}, object->getLocStart());
+        astContext.UnsignedIntTy, objectInfo, {constUint0, address},
+        object->getLocStart());
     result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr,
                                    expr->getCallee()->getExprLoc());
     if (numWords > 1) {
@@ -3362,9 +3350,9 @@ SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
         auto *newAddress =
             spvBuilder.createBinaryOp(spv::Op::OpIAdd, addressType, address,
                                       offset, expr->getCallee()->getExprLoc());
-        loadPtr = spvBuilder.createAccessChain(ptrType, objectInfo,
-                                               {constUint0, newAddress},
-                                               object->getLocStart());
+        loadPtr = spvBuilder.createAccessChain(
+            astContext.UnsignedIntTy, objectInfo, {constUint0, newAddress},
+            object->getLocStart());
         values.push_back(
             spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr,
                                   expr->getCallee()->getExprLoc()));
@@ -3429,10 +3417,9 @@ SpirvEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
     return nullptr;
   }
 
-  const auto *counterPtrType =
-      spvContext.getPointerType(astContext.IntTy, spv::StorageClass::Uniform);
   auto *counterPtr = spvBuilder.createAccessChain(
-      counterPtrType, counterPair->get(spvBuilder, spvContext), {zero}, srcLoc);
+      astContext.IntTy, counterPair->get(spvBuilder, spvContext), {zero},
+      srcLoc);
 
   SpirvInstruction *index = nullptr;
   if (isInc) {
@@ -3724,8 +3711,6 @@ SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
   //   }
 
   const auto v2f32Type = astContext.getExtVectorType(astContext.FloatTy, 2);
-  const auto *ptrType =
-      spvContext.getPointerType(v2f32Type, spv::StorageClass::Function);
 
   // Creates a SPIR-V function scope variable of type float2[len].
   const auto createArray = [this, v2f32Type, loc](const Float2 *ptr,
@@ -3786,7 +3771,8 @@ SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
   //     position = pos2[index];
   //   }
   spvBuilder.setInsertPoint(then2BB);
-  auto *ac = spvBuilder.createAccessChain(ptrType, pos2Arr, {sampleIndex}, loc);
+  auto *ac =
+      spvBuilder.createAccessChain(v2f32Type, pos2Arr, {sampleIndex}, loc);
   spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                          loc);
   spvBuilder.createBranch(merge2BB, /*SourceLocation*/ {});
@@ -3806,7 +3792,7 @@ SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
   //     position = pos4[index];
   //   }
   spvBuilder.setInsertPoint(then4BB);
-  ac = spvBuilder.createAccessChain(ptrType, pos4Arr, {sampleIndex}, loc);
+  ac = spvBuilder.createAccessChain(v2f32Type, pos4Arr, {sampleIndex}, loc);
   spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                          loc);
   spvBuilder.createBranch(merge4BB, /*SourceLocation*/ {});
@@ -3826,7 +3812,7 @@ SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
   //     position = pos8[index];
   //   }
   spvBuilder.setInsertPoint(then8BB);
-  ac = spvBuilder.createAccessChain(ptrType, pos8Arr, {sampleIndex}, loc);
+  ac = spvBuilder.createAccessChain(v2f32Type, pos8Arr, {sampleIndex}, loc);
   spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                          loc);
   spvBuilder.createBranch(merge8BB, /*SourceLocation*/ {});
@@ -3846,7 +3832,7 @@ SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
   //     position = pos16[index];
   //   }
   spvBuilder.setInsertPoint(then16BB);
-  ac = spvBuilder.createAccessChain(ptrType, pos16Arr, {sampleIndex}, loc);
+  ac = spvBuilder.createAccessChain(v2f32Type, pos16Arr, {sampleIndex}, loc);
   spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                          loc);
   spvBuilder.createBranch(merge16BB, /*SourceLocation*/ {});
@@ -4599,8 +4585,7 @@ SpirvEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
         assert(!baseInfo->isRValue());
         // Load the element via access chain
         elem = spvBuilder.createAccessChain(
-            spvContext.getPointerType(elemType, baseInfo->getStorageClass()),
-            baseInfo, indexInstructions, baseExpr->getLocStart());
+            elemType, baseInfo, indexInstructions, baseExpr->getLocStart());
       } else {
         // The matrix is of size 1x1. No need to use access chain, base should
         // be the source pointer.
@@ -4678,9 +4663,8 @@ SpirvEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
       auto *index = spvBuilder.getConstantInt(
           astContext.IntTy, llvm::APInt(32, accessor.Swz0, true));
       // We need a lvalue here. Do not try to load.
-      return spvBuilder.createAccessChain(
-          spvContext.getPointerType(type, baseInfo->getStorageClass()),
-          baseInfo, {index}, baseExpr->getLocStart());
+      return spvBuilder.createAccessChain(type, baseInfo, {index},
+                                          baseExpr->getLocStart());
     } else { // E.g., (v + w).x;
       // The original base vector may not be a rvalue. Need to load it if
       // it is lvalue since ImplicitCastExpr (LValueToRValue) will be missing
@@ -5091,10 +5075,8 @@ void SpirvEmitter::storeValue(SpirvInstruction *lhsPtr,
     // Do separate load of each element via access chain
     llvm::SmallVector<SpirvInstruction *, 8> elements;
     for (uint32_t i = 0; i < arraySize; ++i) {
-      const auto *subRhsPtrType =
-          spvContext.getPointerType(elemType, rhsVal->getStorageClass());
       auto *subRhsPtr = spvBuilder.createAccessChain(
-          subRhsPtrType, rhsVal,
+          elemType, rhsVal,
           {spvBuilder.getConstantInt(astContext.IntTy,
                                      llvm::APInt(32, i, true))},
           loc);
@@ -5965,8 +5947,7 @@ SpirvEmitter::tryToAssignToMatrixElements(const Expr *lhs,
       assert(!base->isRValue());
       // Load the element via access chain
       lhsElemPtr = spvBuilder.createAccessChain(
-          spvContext.getPointerType(elemType, lhsElemPtr->getStorageClass()),
-          lhsElemPtr, indexInstructions, lhs->getLocStart());
+          elemType, lhsElemPtr, indexInstructions, lhs->getLocStart());
     }
 
     spvBuilder.createStore(lhsElemPtr, rhsElem, lhs->getLocStart());
@@ -6267,9 +6248,7 @@ SpirvInstruction *SpirvEmitter::turnIntoElementPtr(
     accessChainBase = var;
   }
 
-  base = spvBuilder.createAccessChain(
-      spvContext.getPointerType(elemType, accessChainBase->getStorageClass()),
-      accessChainBase, indices, loc);
+  base = spvBuilder.createAccessChain(elemType, accessChainBase, indices, loc);
 
   // Okay, this part seems weird, but it is intended:
   // If the base is originally a rvalue, the whole AST involving the base
@@ -10284,8 +10263,7 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
     hullMainOutputPatch = spvBuilder.addFnVar(
         hullMainRetType, /*SourceLocation*/ {}, "temp.var.hullMainRetVal");
     auto *tempLocation = spvBuilder.createAccessChain(
-        spvContext.getPointerType(retType, spv::StorageClass::Function),
-        hullMainOutputPatch, {outputControlPointId},
+        retType, hullMainOutputPatch, {outputControlPointId},
         hullMainFuncDecl->getLocation());
     spvBuilder.createStore(tempLocation, retVal,
                            hullMainFuncDecl->getLocation());