Browse Source

Fixed pointer arithmetic stride issues

Brian Fiete 4 years ago
parent
commit
e60bbdf64f

+ 6 - 6
IDEHelper/Backend/BeMCContext.cpp

@@ -16813,12 +16813,12 @@ void BeMCContext::Generate(BeFunction* function)
 								int relScale = 1;
 								if (mcIdx1.IsImmediate())
 								{
-									mcRelOffset = BeMCOperand::FromImmediate(mcIdx1.mImmediate * arrayType->mElementType->mSize);
+									mcRelOffset = BeMCOperand::FromImmediate(mcIdx1.mImmediate * arrayType->mElementType->GetStride());
 								}
 								else
 								{
 									mcRelOffset = mcIdx1;
-									relScale = arrayType->mElementType->mSize;
+									relScale = arrayType->mElementType->GetStride();
 								}
 
 								result = AllocRelativeVirtualReg(elementPtrType, result, mcRelOffset, relScale);
@@ -16847,13 +16847,13 @@ void BeMCContext::Generate(BeFunction* function)
 							{
 								auto arrayType = (BeSizedArrayType*)ptrType->mElementType;
 								elementType = arrayType->mElementType;
-								byteOffset = mcIdx1.mImmediate * elementType->mSize;
+								byteOffset = mcIdx1.mImmediate * elementType->GetStride();
 							}
 							else if (ptrType->mElementType->mTypeCode == BeTypeCode_Vector)
 							{
 								auto arrayType = (BeVectorType*)ptrType->mElementType;
 								elementType = arrayType->mElementType;
-								byteOffset = mcIdx1.mImmediate * elementType->mSize;
+								byteOffset = mcIdx1.mImmediate * elementType->GetStride();
 							}
 							else
 							{
@@ -16876,12 +16876,12 @@ void BeMCContext::Generate(BeFunction* function)
 						int relScale = 1;
 						if (mcIdx0.IsImmediate())
 						{
-							mcRelOffset = BeMCOperand::FromImmediate(mcIdx0.mImmediate * ptrType->mElementType->mSize);
+							mcRelOffset = BeMCOperand::FromImmediate(mcIdx0.mImmediate * ptrType->mElementType->GetStride());
 						}
 						else
 						{
 							mcRelOffset = mcIdx0;
-							relScale = ptrType->mElementType->mSize;
+							relScale = ptrType->mElementType->GetStride();
 						}
 
 						result = AllocRelativeVirtualReg(ptrType, result, mcRelOffset, relScale);

+ 43 - 0
IDEHelper/Compiler/BfIRCodeGen.cpp

@@ -599,6 +599,14 @@ BfIRTypeEntry& BfIRCodeGen::GetTypeEntry(int typeId)
 	return typeEntry;
 }
 
+BfIRTypeEntry* BfIRCodeGen::GetTypeEntry(llvm::Type* type)
+{
+	int typeId = 0;
+	if (!mTypeToTypeIdMap.TryGetValue(type, &typeId))
+		return NULL;
+	return &GetTypeEntry(typeId);
+}
+
 void BfIRCodeGen::SetResult(int id, llvm::Value* value)
 {
 	BfIRCodeGenEntry entry;
@@ -1383,6 +1391,24 @@ llvm::Type* BfIRCodeGen::GetSizeAlignedType(BfIRTypeEntry* typeEntry)
 	return typeEntry->mLLVMType;
 }
 
+llvm::Value* BfIRCodeGen::GetAlignedPtr(llvm::Value* val)
+{
+	if (auto ptrType = llvm::dyn_cast<llvm::PointerType>(val->getType()))
+	{
+		auto elemType = ptrType->getElementType();
+		auto typeEntry = GetTypeEntry(elemType);
+		if (typeEntry != NULL)
+		{
+			auto alignedType = GetSizeAlignedType(typeEntry);
+			if (alignedType != elemType)
+			{
+				return mIRBuilder->CreateBitCast(val, alignedType->getPointerTo());				
+			}
+		}
+	}
+	return NULL;
+}
+
 llvm::Value* BfIRCodeGen::FixGEP(llvm::Value* fromValue, llvm::Value* result)
 {
 	if (auto ptrType = llvm::dyn_cast<llvm::PointerType>(fromValue->getType()))
@@ -1663,6 +1689,7 @@ void BfIRCodeGen::HandleNextCmd()
 			typeEntry.mLLVMType = type;
 			if (typeEntry.mInstLLVMType == NULL)
 				typeEntry.mInstLLVMType = type;
+			mTypeToTypeIdMap[type] = typeId;
 		}
 		break;
 	case BfIRCmd_SetInstType:
@@ -2153,6 +2180,14 @@ void BfIRCodeGen::HandleNextCmd()
 		{
 			CMD_PARAM(llvm::Value*, val);
 			CMD_PARAM(int, idx0);
+
+			if (auto alignedPtr = GetAlignedPtr(val))
+			{
+				auto gepResult = mIRBuilder->CreateConstInBoundsGEP1_32(NULL, alignedPtr, idx0);
+				SetResult(curId, mIRBuilder->CreateBitCast(gepResult, val->getType()));
+				break;
+			}
+
 			SetResult(curId, mIRBuilder->CreateConstInBoundsGEP1_32(NULL, val, idx0));
 		}
 		break;
@@ -2168,6 +2203,14 @@ void BfIRCodeGen::HandleNextCmd()
 		{
 			CMD_PARAM(llvm::Value*, val);
 			CMD_PARAM(llvm::Value*, idx0);
+
+			if (auto alignedPtr = GetAlignedPtr(val))
+			{
+				auto gepResult = mIRBuilder->CreateInBoundsGEP(alignedPtr, idx0);
+				SetResult(curId, mIRBuilder->CreateBitCast(gepResult, val->getType()));
+				break;
+			}
+
 			SetResult(curId, mIRBuilder->CreateInBoundsGEP(val, idx0));
 		}
 		break;

+ 3 - 0
IDEHelper/Compiler/BfIRCodeGen.h

@@ -116,6 +116,7 @@ public:
 	Array<llvm::Constant*> mConfigConsts64;
 	Dictionary<llvm::Type*, llvm::Value*> mReflectDataMap;
 	Dictionary<llvm::Type*, llvm::Type*> mAlignedTypeToNormalType;
+	Dictionary<llvm::Type*, int> mTypeToTypeIdMap;
 
 public:		
 	void InitTarget();
@@ -123,6 +124,7 @@ public:
 	BfTypeCode GetTypeCode(llvm::Type* type, bool isSigned);
 	llvm::Type* GetLLVMType(BfTypeCode typeCode, bool& isSigned);
 	BfIRTypeEntry& GetTypeEntry(int typeId);
+	BfIRTypeEntry* GetTypeEntry(llvm::Type* type);
 	void SetResult(int id, llvm::Value* value);
 	void SetResult(int id, llvm::Type* value);	
 	void SetResult(int id, llvm::BasicBlock* value);
@@ -135,6 +137,7 @@ public:
 	bool TryMemCpy(llvm::Value* ptr, llvm::Value* val);	
 	bool TryVectorCpy(llvm::Value* ptr, llvm::Value* val);
 	llvm::Type* GetSizeAlignedType(BfIRTypeEntry* typeEntry);
+	llvm::Value* GetAlignedPtr(llvm::Value* val);
 	llvm::Value* FixGEP(llvm::Value* fromValue, llvm::Value* result);
 
 public: